diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index a7b703564a..b896e49f38 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -153,6 +153,7 @@ def compile( | common.OffsetProvider | list[common.OffsetProviderType | common.OffsetProvider] | None = None, + static_domains: dict[str, dict[common.Dimension, tuple[int, int]]] | None = None, **static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]], ) -> Self: """ @@ -194,12 +195,25 @@ def compile( for op in offset_provider ) - self._compiled_programs.compile(offset_providers=offset_provider, **static_args) + if self.compilation_options.static_domains and static_domains is None: + raise ValueError( + "Static domains option is enabled, but no static domain information was " + "provided. Missing required argument 'static_domains'." + ) + if not self.compilation_options.static_domains and static_domains is not None: + raise ValueError( + "Static domains may not be provided when 'static_domains' option is disabled." + ) + + self._compiled_programs.compile( + offset_providers=offset_provider, static_domains=static_domains, **static_args + ) return self def _field_domain_descriptor_mapping_from_func_type(func_type: ts.FunctionType) -> list[str]: static_domain_args = [] + assert func_type.pos_only_args == [] param_types = func_type.pos_or_kw_args | func_type.kw_only_args for name, type_ in param_types.items(): for el_type_, path in type_info.primitive_constituents(type_, with_path_arg=True): @@ -483,6 +497,7 @@ def compile( | common.OffsetProvider | list[common.OffsetProviderType | common.OffsetProvider] | None = None, + static_domains: dict[str, dict[common.Dimension, tuple[int, int]]] | None = None, **static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]], ) -> Self: raise NotImplementedError("Compilation of programs with bound arguments is not implemented") diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index e46d8219a2..902c2ca19d 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -28,7 +28,7 @@ ) from gt4py.next.instrumentation import hook_machinery, metrics from gt4py.next.otf import arguments, stages -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.utils import tree_map @@ -605,11 +605,21 @@ def _compile_variant( def compile( self, offset_providers: list[common.OffsetProvider | common.OffsetProviderType], + static_domains: dict[str, dict[common.Dimension, tuple[int, int]]] | None = None, **static_args: list[ScalarOrTupleOfScalars], ) -> None: """ Compiles the program for all combinations of static arguments and the given 'OffsetProviderType'. + Args: + offset_providers: List of offset providers to compile for. + static_domains: Per-field domain specification. + Example:: + { + "input_field": {IDim: (0, 55), JDim: (0, 30)}, + "out": {IHalfDim: (0, 52), JDim: (0, 30)}, + } + Note: In case you want to compile for specific combinations of static arguments (instead of the combinatoral), you can call compile multiples times. @@ -619,17 +629,78 @@ def compile( pool.compile(static_arg0=[0], static_arg1=[2]).compile(static_arg=[1], static_arg1=[3]) will compile for (0,2), (1,3) """ + + def _build_field_domain_descriptors( + program_type: ts_ffront.ProgramType, + static_domains: dict[str, dict[common.Dimension, tuple[int, int]]], + ) -> dict[str, arguments.FieldDomainDescriptor]: + field_domain_descriptors: dict[str, arguments.FieldDomainDescriptor] = {} + assert program_type.definition.pos_only_args == [] + + matched_keys = set() + param_types = ( + program_type.definition.pos_or_kw_args | program_type.definition.kw_only_args + ) + for arg_name, arg_type_ in param_types.items(): + for el_type_, path in type_info.primitive_constituents( + arg_type_, with_path_arg=True + ): + if isinstance(el_type_, ts.FieldType): + path_as_expr = "".join(map(lambda idx: f"[{idx}]", path)) + field_expr = f"{arg_name}{path_as_expr}" + + if field_expr not in static_domains: + raise ValueError( + f"Missing static domain for field '{field_expr}'. " + f"Expected domains for all field parameters. " + f"Available keys in static_domains: {list(static_domains.keys())}" + ) + + field_domain = static_domains[field_expr] + matched_keys.add(field_expr) + + expected_dims = set(el_type_.dims) + provided_dims = set(field_domain.keys()) + if expected_dims != provided_dims: + raise ValueError( + f"Domain dimension mismatch for field '{field_expr}': " + f"field has dimensions {set(d.value for d in expected_dims)}, " + f"but static_domains provides {set(d.value for d in provided_dims)}." + ) + + domain_ranges = {dim: field_domain[dim] for dim in el_type_.dims} + field_domain_descriptors[field_expr] = arguments.FieldDomainDescriptor( + common.domain(domain_ranges) + ) + + extra_keys = set(static_domains.keys()) - matched_keys + if extra_keys: + raise ValueError( + f"static_domains contains keys that do not correspond to any field parameter: " + f"{list(extra_keys)}. Valid field expressions are: {list(matched_keys)}" + ) + + return field_domain_descriptors + for offset_provider in offset_providers: # not included in product for better type checking for static_values in itertools.product(*static_args.values()): + argument_descriptor_dict: dict[ + type[arguments.ArgStaticDescriptor], + dict[str, arguments.ArgStaticDescriptor], + ] = { + arguments.StaticArg: dict( + zip( + static_args.keys(), + [arguments.StaticArg(value=v) for v in static_values], + strict=True, + ) + ), + } + if static_domains: + argument_descriptor_dict[arguments.FieldDomainDescriptor] = ( + _build_field_domain_descriptors(self.program_type, static_domains) # type: ignore[assignment] + ) self._compile_variant( - argument_descriptors={ - arguments.StaticArg: dict( - zip( - static_args.keys(), - [arguments.StaticArg(value=v) for v in static_values], - strict=True, - ) - ), - }, + argument_descriptor_dict, offset_provider=offset_provider, ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index 18c6c26ff4..7b6fa0eab0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -35,7 +35,6 @@ skip_value_mesh, ) -from gt4py.next.otf import arguments _raise_on_compile = mock.Mock() _raise_on_compile.compile.side_effect = AssertionError("This function should never be called.") @@ -62,6 +61,7 @@ def testee(a: cases.IField, b: cases.IField, out: cases.IField): testee_op(a, b, out=out) wrap_in_program = request.param + if wrap_in_program: return testee else: @@ -944,7 +944,8 @@ def test_wait_for_compilation(cartesian_case, compile_testee, compile_testee_dom compile_testee_domain.compile(offset_provider=cartesian_case.offset_provider) -def test_compile_variants_decorator_static_domains(cartesian_case): +@pytest.mark.parametrize("precompile", [True, False], ids=["precompile", "run"]) +def test_compile_variants_decorator_static_domains(cartesian_case, precompile): if cartesian_case.backend is None: pytest.skip("Embedded compiled program doesn't make sense.") @@ -970,24 +971,32 @@ def identity_like(inp: tuple[cases.IField, cases.IField, float]): def testee(inp: tuple[cases.IField, cases.IField, float], out: NamedTupleNamedCollection): identity_like(inp, out=out) - inp = cases.allocate(cartesian_case, testee, "inp")() - out = cases.allocate(cartesian_case, testee, "out")() - - testee(inp, out, offset_provider={}) - assert np.allclose(inp[0].ndarray, out[0].ndarray) - assert np.allclose(inp[1].ndarray, out[1].ndarray) - - assert testee._compiled_programs.argument_descriptor_mapping[ - arguments.FieldDomainDescriptor - ] == ["inp[0]", "inp[1]", "out[0]", "out[1]"] - assert captured_cargs.argument_descriptor_contexts[arguments.FieldDomainDescriptor] == { - "inp": ( - arguments.FieldDomainDescriptor(inp[0].domain), - arguments.FieldDomainDescriptor(inp[1].domain), - None, - ), - "out": ( - arguments.FieldDomainDescriptor(out[0].domain), - arguments.FieldDomainDescriptor(out[1].domain), - ), - } + with mock.patch.object(compiled_program, "_async_compilation_pool", None): + inp = cases.allocate(cartesian_case, testee, "inp")() + out = cases.allocate(cartesian_case, testee, "out")() + if precompile: + testee.compile( + offset_provider=cartesian_case.offset_provider, + static_domains={ + dim: (0, size) for dim, size in cartesian_case.default_sizes.items() + }, + ) + else: + testee(inp, out, offset_provider={}) + assert np.allclose(inp[0].ndarray, out[0].ndarray) + assert np.allclose(inp[1].ndarray, out[1].ndarray) + + assert testee._compiled_programs.argument_descriptor_mapping[ + arguments.FieldDomainDescriptor + ] == ["inp[0]", "inp[1]", "out[0]", "out[1]"] + assert captured_cargs.argument_descriptor_contexts[arguments.FieldDomainDescriptor] == { + "inp": ( + arguments.FieldDomainDescriptor(inp[0].domain), + arguments.FieldDomainDescriptor(inp[1].domain), + None, + ), + "out": ( + arguments.FieldDomainDescriptor(out[0].domain), + arguments.FieldDomainDescriptor(out[1].domain), + ), + }