diff --git a/crates/core/src/functions.rs b/crates/core/src/functions.rs
index 6996dca94..feb6b0a97 100644
--- a/crates/core/src/functions.rs
+++ b/crates/core/src/functions.rs
@@ -93,6 +93,13 @@ fn array_cat(exprs: Vec<PyExpr>) -> PyExpr {
     array_concat(exprs)
 }
 
+#[pyfunction]
+fn make_map(keys: Vec<PyExpr>, values: Vec<PyExpr>) -> PyExpr {
+    let keys = keys.into_iter().map(|x| x.into()).collect();
+    let values = values.into_iter().map(|x| x.into()).collect();
+    datafusion::functions_nested::map::map(keys, values).into()
+}
+
 #[pyfunction]
 #[pyo3(signature = (array, element, index=None))]
 fn array_position(array: PyExpr, element: PyExpr, index: Option<i64>) -> PyExpr {
@@ -666,6 +673,12 @@ array_fn!(cardinality, array);
 array_fn!(flatten, array);
 array_fn!(range, start stop step);
 
+// Map Functions
+array_fn!(map_keys, map);
+array_fn!(map_values, map);
+array_fn!(map_extract, map key);
+array_fn!(map_entries, map);
+
 aggregate_function!(array_agg);
 aggregate_function!(max);
 aggregate_function!(min);
@@ -1126,6 +1139,13 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(flatten))?;
     m.add_wrapped(wrap_pyfunction!(cardinality))?;
 
+    // Map Functions
+    m.add_wrapped(wrap_pyfunction!(make_map))?;
+    m.add_wrapped(wrap_pyfunction!(map_keys))?;
+    m.add_wrapped(wrap_pyfunction!(map_values))?;
+    m.add_wrapped(wrap_pyfunction!(map_extract))?;
+    m.add_wrapped(wrap_pyfunction!(map_entries))?;
+
     // Window Functions
     m.add_wrapped(wrap_pyfunction!(lead))?;
     m.add_wrapped(wrap_pyfunction!(lag))?;
diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py
index 3c8d2bcee..f265f7f4c 100644
--- a/python/datafusion/functions.py
+++ b/python/datafusion/functions.py
@@ -139,6 +139,7 @@
     "degrees",
     "dense_rank",
     "digest",
+    "element_at",
     "empty",
     "encode",
     "ends_with",
@@ -202,7 +203,12 @@
     "make_array",
     "make_date",
     "make_list",
+    "make_map",
     "make_time",
+    "map_entries",
+    "map_extract",
+    "map_keys",
+    "map_values",
     "max",
     "md5",
     "mean",
@@ -3374,6 +3380,158 @@ def empty(array: Expr) -> Expr:
     return array_empty(array)
 
 
+# map functions
+
+
+def make_map(*args: Any) -> Expr:
+    """Returns a map expression.
+
+    Supports three calling conventions:
+
+    - ``make_map({"a": 1, "b": 2})`` — from a Python dictionary.
+    - ``make_map([keys], [values])`` — from a list of keys and a list of
+      their associated values. Both lists must be the same length.
+    - ``make_map(k1, v1, k2, v2, ...)`` — from alternating keys and their
+      associated values.
+
+    Keys and values that are not already :py:class:`~datafusion.expr.Expr`
+    are automatically converted to literal expressions.
+
+    Examples:
+        From a dictionary:
+
+        >>> ctx = dfn.SessionContext()
+        >>> df = ctx.from_pydict({"a": [1]})
+        >>> result = df.select(
+        ...     dfn.functions.make_map({"a": 1, "b": 2}).alias("m"))
+        >>> result.collect_column("m")[0].as_py()
+        [('a', 1), ('b', 2)]
+
+        From two lists:
+
+        >>> df = ctx.from_pydict({"key": ["x", "y"], "val": [10, 20]})
+        >>> df = df.select(
+        ...     dfn.functions.make_map(
+        ...         [dfn.col("key")], [dfn.col("val")]
+        ...     ).alias("m"))
+        >>> df.collect_column("m")[0].as_py()
+        [('x', 10)]
+
+        From alternating keys and values:
+
+        >>> df = ctx.from_pydict({"a": [1]})
+        >>> result = df.select(
+        ...     dfn.functions.make_map("x", 1, "y", 2).alias("m"))
+        >>> result.collect_column("m")[0].as_py()
+        [('x', 1), ('y', 2)]
+    """
+    if len(args) == 1 and isinstance(args[0], dict):
+        key_list = list(args[0].keys())
+        value_list = list(args[0].values())
+    elif (
+        len(args) == 2  # noqa: PLR2004
+        and isinstance(args[0], list)
+        and isinstance(args[1], list)
+    ):
+        if len(args[0]) != len(args[1]):
+            msg = "make_map requires key and value lists to be the same length"
+            raise ValueError(msg)
+        key_list = args[0]
+        value_list = args[1]
+    elif len(args) >= 2 and len(args) % 2 == 0:  # noqa: PLR2004
+        key_list = list(args[0::2])
+        value_list = list(args[1::2])
+    else:
+        msg = (
+            "make_map expects a dict, two lists, or an even number of "
+            "key-value arguments"
+        )
+        raise ValueError(msg)
+
+    key_exprs = [k if isinstance(k, Expr) else Expr.literal(k) for k in key_list]
+    val_exprs = [v if isinstance(v, Expr) else Expr.literal(v) for v in value_list]
+    return Expr(f.make_map([k.expr for k in key_exprs], [v.expr for v in val_exprs]))
+
+
+def map_keys(map: Expr) -> Expr:
+    """Returns a list of all keys in the map.
+
+    Examples:
+        >>> ctx = dfn.SessionContext()
+        >>> df = ctx.from_pydict({"a": [1]})
+        >>> df = df.select(
+        ...     dfn.functions.make_map({"x": 1, "y": 2}).alias("m"))
+        >>> result = df.select(
+        ...     dfn.functions.map_keys(dfn.col("m")).alias("keys"))
+        >>> result.collect_column("keys")[0].as_py()
+        ['x', 'y']
+    """
+    return Expr(f.map_keys(map.expr))
+
+
+def map_values(map: Expr) -> Expr:
+    """Returns a list of all values in the map.
+
+    Examples:
+        >>> ctx = dfn.SessionContext()
+        >>> df = ctx.from_pydict({"a": [1]})
+        >>> df = df.select(
+        ...     dfn.functions.make_map({"x": 1, "y": 2}).alias("m"))
+        >>> result = df.select(
+        ...     dfn.functions.map_values(dfn.col("m")).alias("vals"))
+        >>> result.collect_column("vals")[0].as_py()
+        [1, 2]
+    """
+    return Expr(f.map_values(map.expr))
+
+
+def map_extract(map: Expr, key: Expr) -> Expr:
+    """Returns the value for a given key in the map.
+
+    Returns ``[None]`` if the key is absent.
+
+    Examples:
+        >>> ctx = dfn.SessionContext()
+        >>> df = ctx.from_pydict({"a": [1]})
+        >>> df = df.select(
+        ...     dfn.functions.make_map({"x": 1, "y": 2}).alias("m"))
+        >>> result = df.select(
+        ...     dfn.functions.map_extract(
+        ...         dfn.col("m"), dfn.lit("x")
+        ...     ).alias("val"))
+        >>> result.collect_column("val")[0].as_py()
+        [1]
+    """
+    return Expr(f.map_extract(map.expr, key.expr))
+
+
+def map_entries(map: Expr) -> Expr:
+    """Returns a list of all entries (key-value struct pairs) in the map.
+
+    Examples:
+        >>> ctx = dfn.SessionContext()
+        >>> df = ctx.from_pydict({"a": [1]})
+        >>> df = df.select(
+        ...     dfn.functions.make_map({"x": 1, "y": 2}).alias("m"))
+        >>> result = df.select(
+        ...     dfn.functions.map_entries(dfn.col("m")).alias("entries"))
+        >>> result.collect_column("entries")[0].as_py()
+        [{'key': 'x', 'value': 1}, {'key': 'y', 'value': 2}]
+    """
+    return Expr(f.map_entries(map.expr))
+
+
+def element_at(map: Expr, key: Expr) -> Expr:
+    """Returns the value for a given key in the map.
+
+    Returns ``[None]`` if the key is absent.
+
+    See Also:
+        This is an alias for :py:func:`map_extract`.
+    """
+    return map_extract(map, key)
+
+
 # aggregate functions
 def approx_distinct(
     expression: Expr,
diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py
index 08420826d..00698f5aa 100644
--- a/python/tests/test_functions.py
+++ b/python/tests/test_functions.py
@@ -668,6 +668,106 @@ def test_array_function_obj_tests(stmt, py_expr):
     assert a == b
 
 
+@pytest.mark.parametrize(
+    ("args", "expected"),
+    [
+        pytest.param(
+            ({"x": 1, "y": 2},),
+            [("x", 1), ("y", 2)],
+            id="dict",
+        ),
+        pytest.param(
+            ({"x": literal(1), "y": literal(2)},),
+            [("x", 1), ("y", 2)],
+            id="dict_with_exprs",
+        ),
+        pytest.param(
+            ("x", 1, "y", 2),
+            [("x", 1), ("y", 2)],
+            id="variadic_pairs",
+        ),
+        pytest.param(
+            (literal("x"), literal(1), literal("y"), literal(2)),
+            [("x", 1), ("y", 2)],
+            id="variadic_with_exprs",
+        ),
+    ],
+)
+def test_make_map(args, expected):
+    ctx = SessionContext()
+    batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
+    df = ctx.create_dataframe([[batch]])
+
+    result = df.select(f.make_map(*args).alias("m")).collect()[0].column(0)
+    assert result[0].as_py() == expected
+
+
+def test_make_map_from_two_lists():
+    ctx = SessionContext()
+    batch = pa.RecordBatch.from_arrays(
+        [
+            pa.array(["k1", "k2", "k3"]),
+            pa.array([10, 20, 30]),
+        ],
+        names=["keys", "vals"],
+    )
+    df = ctx.create_dataframe([[batch]])
+
+    m = f.make_map([column("keys")], [column("vals")])
+    result = df.select(f.map_keys(m).alias("k")).collect()[0].column(0)
+    assert result.to_pylist() == [["k1"], ["k2"], ["k3"]]
+
+    result = df.select(f.map_values(m).alias("v")).collect()[0].column(0)
+    assert result.to_pylist() == [[10], [20], [30]]
+
+
+def test_make_map_odd_args_raises():
+    with pytest.raises(ValueError, match="make_map expects"):
+        f.make_map("x", 1, "y")
+
+
+def test_make_map_mismatched_lengths():
+    with pytest.raises(ValueError, match="same length"):
+        f.make_map(["a", "b"], [1])
+
+
+@pytest.mark.parametrize(
+    ("func", "expected"),
+    [
+        pytest.param(f.map_keys, ["x", "y"], id="map_keys"),
+        pytest.param(f.map_values, [1, 2], id="map_values"),
+        pytest.param(
+            lambda m: f.map_extract(m, literal("x")),
+            [1],
+            id="map_extract",
+        ),
+        pytest.param(
+            lambda m: f.map_extract(m, literal("z")),
+            [None],
+            id="map_extract_missing_key",
+        ),
+        pytest.param(
+            f.map_entries,
+            [{"key": "x", "value": 1}, {"key": "y", "value": 2}],
+            id="map_entries",
+        ),
+        pytest.param(
+            lambda m: f.element_at(m, literal("y")),
+            [2],
+            id="element_at",
+        ),
+    ],
+)
+def test_map_functions(func, expected):
+    ctx = SessionContext()
+    batch = pa.RecordBatch.from_arrays([pa.array([1])], names=["a"])
+    df = ctx.create_dataframe([[batch]])
+
+    m = f.make_map({"x": 1, "y": 2})
+    result = df.select(func(m).alias("out")).collect()[0].column(0)
+    assert result[0].as_py() == expected
+
+
 @pytest.mark.parametrize(
     ("function", "expected_result"),
     [