From dc6dbb73339258fe515d6e373024754656010ab3 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 17 Feb 2026 15:13:39 +0100 Subject: [PATCH 1/6] Add static_domains for pre-compilation --- src/gt4py/next/ffront/decorator.py | 5 +- src/gt4py/next/otf/compiled_program.py | 49 ++++++++++++---- .../ffront_tests/test_compiled_program.py | 57 ++++++++++++++++++- 3 files changed, 97 insertions(+), 14 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index f1efeb7353..1de2bc2591 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -153,6 +153,7 @@ def compile( | common.OffsetProvider | list[common.OffsetProviderType | common.OffsetProvider] | None = None, + static_domains: Optional[dict[common.Domain, int] | None] = None, **static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]], ) -> Self: """ @@ -194,7 +195,9 @@ def compile( for op in offset_provider ) - self._compiled_programs.compile(offset_providers=offset_provider, **static_args) + self._compiled_programs.compile( + offset_providers=offset_provider, static_domains=static_domains, **static_args + ) return self diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index e46d8219a2..8b90b0c49c 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -15,7 +15,7 @@ import itertools import warnings from collections.abc import Callable, Hashable, Sequence -from typing import Any, Generic, TypeAlias, TypeVar +from typing import Any, Generic, Optional, TypeAlias, TypeVar from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping, utils as eve_utils @@ -28,7 +28,7 @@ ) from gt4py.next.instrumentation import hook_machinery, metrics from gt4py.next.otf import arguments, stages -from gt4py.next.type_system import type_specifications as ts +from gt4py.next.type_system import type_info, type_specifications as ts from gt4py.next.utils import tree_map @@ -605,6 +605,7 @@ def _compile_variant( def compile( self, offset_providers: list[common.OffsetProvider | common.OffsetProviderType], + static_domains: Optional[dict[common.Domain, int] | None] = None, **static_args: list[ScalarOrTupleOfScalars], ) -> None: """ @@ -619,17 +620,43 @@ def compile( pool.compile(static_arg0=[0], static_arg1=[2]).compile(static_arg=[1], static_arg1=[3]) will compile for (0,2), (1,3) """ + + def _build_field_domain_descriptors(program_type, static_domains): + def _create_field_descriptor(field_type): + domain_ranges = { + dim: static_domains[dim] for dim in field_type.dims + } # TODO: improve error message + return arguments.FieldDomainDescriptor(common.domain(domain_ranges)) + + field_domain_descriptors = {} + for arg_name, arg_type_ in program_type.definition.pos_or_kw_args.items(): + for el_type_, path in type_info.primitive_constituents( + arg_type_, with_path_arg=True + ): + if isinstance(el_type_, ts.FieldType): + path_as_expr = "".join(map(lambda idx: f"[{idx}]", path)) + field_domain_descriptors[f"{arg_name}{path_as_expr}"] = ( + _create_field_descriptor(el_type_) + ) + + return field_domain_descriptors + for offset_provider in offset_providers: # not included in product for better type checking for static_values in itertools.product(*static_args.values()): + argument_descriptor_dict = { + arguments.StaticArg: dict( + zip( + static_args.keys(), + [arguments.StaticArg(value=v) for v in static_values], + strict=True, + ) + ), + } + if static_domains: + argument_descriptor_dict[arguments.FieldDomainDescriptor] = ( + _build_field_domain_descriptors(self.program_type, static_domains) + ) self._compile_variant( - argument_descriptors={ - arguments.StaticArg: dict( - zip( - static_args.keys(), - [arguments.StaticArg(value=v) for v in static_values], - strict=True, - ) - ), - }, + argument_descriptor_dict, offset_provider=offset_provider, ) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index 18c6c26ff4..04e4c8e4d5 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -35,7 +35,6 @@ skip_value_mesh, ) -from gt4py.next.otf import arguments _raise_on_compile = mock.Mock() _raise_on_compile.compile.side_effect = AssertionError("This function should never be called.") @@ -49,7 +48,7 @@ class NamedTupleNamedCollection(NamedTuple): @pytest.fixture( params=[ pytest.param(True, id="program"), - pytest.param(False, id="field-operator"), + # pytest.param(False, id="field-operator"), ] ) def compile_testee(request, cartesian_case): @@ -62,6 +61,7 @@ def testee(a: cases.IField, b: cases.IField, out: cases.IField): testee_op(a, b, out=out) wrap_in_program = request.param + print(f"HUHU{id(testee)}", flush=True) if wrap_in_program: return testee else: @@ -991,3 +991,56 @@ def testee(inp: tuple[cases.IField, cases.IField, float], out: NamedTupleNamedCo arguments.FieldDomainDescriptor(out[1].domain), ), } + + +def test_compile_with_static_domains(compile_variants_field_operator, cartesian_case): + if cartesian_case.backend is None: + pytest.skip("Embedded compiled program doesn't make sense.") + + captured_cargs: Optional[arguments.CompileTimeArgs] = None + + class CaptureCompileTimeArgsBackend: + def __getattr__(self, name): + return getattr(cartesian_case.backend, name) + + def compile(self, program, compile_time_args): + nonlocal captured_cargs + captured_cargs = compile_time_args + + return cartesian_case.backend.compile(program, compile_time_args) + + @gtx.field_operator + def identity_like(inp: tuple[cases.IField, cases.IField, float]): + return inp[0], inp[1] + + # the float argument here is merely to test that static domains work for tuple arguments + # of inhomogeneous types + @gtx.program(backend=CaptureCompileTimeArgsBackend(), static_domains=True) + def testee( + inp: tuple[cases.IField, cases.IField, float], out: tuple[cases.IField, cases.IField] + ): + identity_like(inp, out=out) + + inp = cases.allocate(cartesian_case, testee, "inp")() + out = cases.allocate(cartesian_case, testee, "out")() + + testee.compile( + offset_provider=cartesian_case.offset_provider, + static_domains={dim: (0, size) for dim, size in cartesian_case.default_sizes.items()}, + ) + + assert testee._compiled_programs.argument_descriptor_mapping[ + arguments.FieldDomainDescriptor + ] == ["inp[0]", "inp[1]", "out[0]", "out[1]"] + + assert captured_cargs.argument_descriptor_contexts[arguments.FieldDomainDescriptor] == { + "inp": ( + arguments.FieldDomainDescriptor(inp[0].domain), + arguments.FieldDomainDescriptor(inp[1].domain), + None, + ), + "out": ( + arguments.FieldDomainDescriptor(out[0].domain), + arguments.FieldDomainDescriptor(out[1].domain), + ), + } From c3e045381c2fe3da1e3de3325b754391e167457f Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Tue, 17 Feb 2026 15:31:09 +0100 Subject: [PATCH 2/6] Minor --- src/gt4py/next/otf/compiled_program.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 8b90b0c49c..a7b0fa88fa 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -19,6 +19,7 @@ from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping, utils as eve_utils +from gt4py.eve.extended_typing import MaybeNestedInTuple from gt4py.next import backend as gtx_backend, common, config, errors, utils as gtx_utils from gt4py.next.ffront import ( stages as ffront_stages, @@ -605,7 +606,7 @@ def _compile_variant( def compile( self, offset_providers: list[common.OffsetProvider | common.OffsetProviderType], - static_domains: Optional[dict[common.Domain, int] | None] = None, + static_domains: Optional[dict[common.Dimension, tuple[int, int]] | None] = None, **static_args: list[ScalarOrTupleOfScalars], ) -> None: """ @@ -621,8 +622,13 @@ def compile( will compile for (0,2), (1,3) """ - def _build_field_domain_descriptors(program_type, static_domains): - def _create_field_descriptor(field_type): + def _build_field_domain_descriptors( + program_type: ts_ffront.ProgramType, + static_domains: dict[common.Dimension, tuple[int, int]], + ) -> dict[str, MaybeNestedInTuple[arguments.FieldDomainDescriptor]]: + def _create_field_descriptor( + field_type: ts.FieldType, + ) -> arguments.FieldDomainDescriptor: domain_ranges = { dim: static_domains[dim] for dim in field_type.dims } # TODO: improve error message From 46a855ef8e6184eb8f20e8646098bf1624ebabf7 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 18 Feb 2026 14:08:58 +0100 Subject: [PATCH 3/6] Some fixes --- src/gt4py/next/ffront/decorator.py | 14 ++- src/gt4py/next/otf/compiled_program.py | 22 +++-- .../ffront_tests/test_compiled_program.py | 88 ++++++++++--------- 3 files changed, 72 insertions(+), 52 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 1de2bc2591..c8a08ef105 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -153,7 +153,7 @@ def compile( | common.OffsetProvider | list[common.OffsetProviderType | common.OffsetProvider] | None = None, - static_domains: Optional[dict[common.Domain, int] | None] = None, + static_domains: dict[common.Dimension, tuple[int, int]] | None = None, **static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]], ) -> Self: """ @@ -195,6 +195,16 @@ def compile( for op in offset_provider ) + if self.compilation_options.static_domains and static_domains is None: + raise ValueError( + "Static domains option is enabled, but no static domain information was " + "provided. Missing required argument 'static_domains'." + ) + if not self.compilation_options.static_domains and static_domains is not None: + raise ValueError( + "Static domains may not be provided when 'static_domains' option is disabled." + ) + self._compiled_programs.compile( offset_providers=offset_provider, static_domains=static_domains, **static_args ) @@ -203,6 +213,7 @@ def compile( def _field_domain_descriptor_mapping_from_func_type(func_type: ts.FunctionType) -> list[str]: static_domain_args = [] + assert func_type.pos_only_args == [] param_types = func_type.pos_or_kw_args | func_type.kw_only_args for name, type_ in param_types.items(): for el_type_, path in type_info.primitive_constituents(type_, with_path_arg=True): @@ -486,6 +497,7 @@ def compile( | common.OffsetProvider | list[common.OffsetProviderType | common.OffsetProvider] | None = None, + static_domains: dict[common.Dimension, tuple[int, int]] | None = None, **static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]], ) -> Self: raise NotImplementedError("Compilation of programs with bound arguments is not implemented") diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index a7b0fa88fa..2df22cca0c 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -15,11 +15,10 @@ import itertools import warnings from collections.abc import Callable, Hashable, Sequence -from typing import Any, Generic, Optional, TypeAlias, TypeVar +from typing import Any, Generic, TypeAlias, TypeVar from gt4py._core import definitions as core_defs from gt4py.eve import extended_typing as xtyping, utils as eve_utils -from gt4py.eve.extended_typing import MaybeNestedInTuple from gt4py.next import backend as gtx_backend, common, config, errors, utils as gtx_utils from gt4py.next.ffront import ( stages as ffront_stages, @@ -606,7 +605,7 @@ def _compile_variant( def compile( self, offset_providers: list[common.OffsetProvider | common.OffsetProviderType], - static_domains: Optional[dict[common.Dimension, tuple[int, int]] | None] = None, + static_domains: dict[common.Dimension, tuple[int, int]] | None = None, **static_args: list[ScalarOrTupleOfScalars], ) -> None: """ @@ -625,7 +624,7 @@ def compile( def _build_field_domain_descriptors( program_type: ts_ffront.ProgramType, static_domains: dict[common.Dimension, tuple[int, int]], - ) -> dict[str, MaybeNestedInTuple[arguments.FieldDomainDescriptor]]: + ) -> dict[str, arguments.FieldDomainDescriptor]: def _create_field_descriptor( field_type: ts.FieldType, ) -> arguments.FieldDomainDescriptor: @@ -634,8 +633,12 @@ def _create_field_descriptor( } # TODO: improve error message return arguments.FieldDomainDescriptor(common.domain(domain_ranges)) - field_domain_descriptors = {} - for arg_name, arg_type_ in program_type.definition.pos_or_kw_args.items(): + field_domain_descriptors: dict[str, arguments.FieldDomainDescriptor] = {} + assert program_type.definition.pos_only_args == [] + param_types = ( + program_type.definition.pos_or_kw_args | program_type.definition.kw_only_args + ) + for arg_name, arg_type_ in param_types.items(): for el_type_, path in type_info.primitive_constituents( arg_type_, with_path_arg=True ): @@ -649,7 +652,10 @@ def _create_field_descriptor( for offset_provider in offset_providers: # not included in product for better type checking for static_values in itertools.product(*static_args.values()): - argument_descriptor_dict = { + argument_descriptor_dict: dict[ + type[arguments.ArgStaticDescriptor], + dict[str, arguments.ArgStaticDescriptor], + ] = { arguments.StaticArg: dict( zip( static_args.keys(), @@ -660,7 +666,7 @@ def _create_field_descriptor( } if static_domains: argument_descriptor_dict[arguments.FieldDomainDescriptor] = ( - _build_field_domain_descriptors(self.program_type, static_domains) + _build_field_domain_descriptors(self.program_type, static_domains) # type: ignore[assignment] ) self._compile_variant( argument_descriptor_dict, diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index 04e4c8e4d5..ef5d9b2d58 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -48,7 +48,7 @@ class NamedTupleNamedCollection(NamedTuple): @pytest.fixture( params=[ pytest.param(True, id="program"), - # pytest.param(False, id="field-operator"), + pytest.param(False, id="field-operator"), ] ) def compile_testee(request, cartesian_case): @@ -61,7 +61,7 @@ def testee(a: cases.IField, b: cases.IField, out: cases.IField): testee_op(a, b, out=out) wrap_in_program = request.param - print(f"HUHU{id(testee)}", flush=True) + if wrap_in_program: return testee else: @@ -970,27 +970,28 @@ def identity_like(inp: tuple[cases.IField, cases.IField, float]): def testee(inp: tuple[cases.IField, cases.IField, float], out: NamedTupleNamedCollection): identity_like(inp, out=out) - inp = cases.allocate(cartesian_case, testee, "inp")() - out = cases.allocate(cartesian_case, testee, "out")() + with mock.patch.object(compiled_program, "_async_compilation_pool", None): + inp = cases.allocate(cartesian_case, testee, "inp")() + out = cases.allocate(cartesian_case, testee, "out")() - testee(inp, out, offset_provider={}) - assert np.allclose(inp[0].ndarray, out[0].ndarray) - assert np.allclose(inp[1].ndarray, out[1].ndarray) - - assert testee._compiled_programs.argument_descriptor_mapping[ - arguments.FieldDomainDescriptor - ] == ["inp[0]", "inp[1]", "out[0]", "out[1]"] - assert captured_cargs.argument_descriptor_contexts[arguments.FieldDomainDescriptor] == { - "inp": ( - arguments.FieldDomainDescriptor(inp[0].domain), - arguments.FieldDomainDescriptor(inp[1].domain), - None, - ), - "out": ( - arguments.FieldDomainDescriptor(out[0].domain), - arguments.FieldDomainDescriptor(out[1].domain), - ), - } + testee(inp, out, offset_provider={}) + assert np.allclose(inp[0].ndarray, out[0].ndarray) + assert np.allclose(inp[1].ndarray, out[1].ndarray) + + assert testee._compiled_programs.argument_descriptor_mapping[ + arguments.FieldDomainDescriptor + ] == ["inp[0]", "inp[1]", "out[0]", "out[1]"] + assert captured_cargs.argument_descriptor_contexts[arguments.FieldDomainDescriptor] == { + "inp": ( + arguments.FieldDomainDescriptor(inp[0].domain), + arguments.FieldDomainDescriptor(inp[1].domain), + None, + ), + "out": ( + arguments.FieldDomainDescriptor(out[0].domain), + arguments.FieldDomainDescriptor(out[1].domain), + ), + } def test_compile_with_static_domains(compile_variants_field_operator, cartesian_case): @@ -1021,26 +1022,27 @@ def testee( ): identity_like(inp, out=out) - inp = cases.allocate(cartesian_case, testee, "inp")() - out = cases.allocate(cartesian_case, testee, "out")() + with mock.patch.object(compiled_program, "_async_compilation_pool", None): + inp = cases.allocate(cartesian_case, testee, "inp")() + out = cases.allocate(cartesian_case, testee, "out")() - testee.compile( - offset_provider=cartesian_case.offset_provider, - static_domains={dim: (0, size) for dim, size in cartesian_case.default_sizes.items()}, - ) + testee.compile( + offset_provider=cartesian_case.offset_provider, + static_domains={dim: (0, size) for dim, size in cartesian_case.default_sizes.items()}, + ) - assert testee._compiled_programs.argument_descriptor_mapping[ - arguments.FieldDomainDescriptor - ] == ["inp[0]", "inp[1]", "out[0]", "out[1]"] - - assert captured_cargs.argument_descriptor_contexts[arguments.FieldDomainDescriptor] == { - "inp": ( - arguments.FieldDomainDescriptor(inp[0].domain), - arguments.FieldDomainDescriptor(inp[1].domain), - None, - ), - "out": ( - arguments.FieldDomainDescriptor(out[0].domain), - arguments.FieldDomainDescriptor(out[1].domain), - ), - } + assert testee._compiled_programs.argument_descriptor_mapping[ + arguments.FieldDomainDescriptor + ] == ["inp[0]", "inp[1]", "out[0]", "out[1]"] + + assert captured_cargs.argument_descriptor_contexts[arguments.FieldDomainDescriptor] == { + "inp": ( + arguments.FieldDomainDescriptor(inp[0].domain), + arguments.FieldDomainDescriptor(inp[1].domain), + None, + ), + "out": ( + arguments.FieldDomainDescriptor(out[0].domain), + arguments.FieldDomainDescriptor(out[1].domain), + ), + } From 694d927fb1b9dcacdac50168fb46768b6c3452a7 Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 18 Feb 2026 14:32:12 +0100 Subject: [PATCH 4/6] Refactor test and add error message --- src/gt4py/next/otf/compiled_program.py | 18 ++--- .../ffront_tests/test_compiled_program.py | 70 +++---------------- 2 files changed, 20 insertions(+), 68 deletions(-) diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 2df22cca0c..932e108943 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -625,14 +625,6 @@ def _build_field_domain_descriptors( program_type: ts_ffront.ProgramType, static_domains: dict[common.Dimension, tuple[int, int]], ) -> dict[str, arguments.FieldDomainDescriptor]: - def _create_field_descriptor( - field_type: ts.FieldType, - ) -> arguments.FieldDomainDescriptor: - domain_ranges = { - dim: static_domains[dim] for dim in field_type.dims - } # TODO: improve error message - return arguments.FieldDomainDescriptor(common.domain(domain_ranges)) - field_domain_descriptors: dict[str, arguments.FieldDomainDescriptor] = {} assert program_type.definition.pos_only_args == [] param_types = ( @@ -644,8 +636,16 @@ def _create_field_descriptor( ): if isinstance(el_type_, ts.FieldType): path_as_expr = "".join(map(lambda idx: f"[{idx}]", path)) + if missing_dims := [ + dim for dim in el_type_.dims if dim not in static_domains + ]: + raise ValueError( + f"Missing domain specification for dimension(s) {missing_dims} for {arg_name}{path_as_expr}. " + f"Field has dimensions {list(el_type_.dims)}, but static_domains only contains {list(static_domains.keys())}." + ) + domain_ranges = {dim: static_domains[dim] for dim in el_type_.dims} field_domain_descriptors[f"{arg_name}{path_as_expr}"] = ( - _create_field_descriptor(el_type_) + arguments.FieldDomainDescriptor(common.domain(domain_ranges)) ) return field_domain_descriptors diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index ef5d9b2d58..6f8097d608 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -944,7 +944,8 @@ def test_wait_for_compilation(cartesian_case, compile_testee, compile_testee_dom compile_testee_domain.compile(offset_provider=cartesian_case.offset_provider) -def test_compile_variants_decorator_static_domains(cartesian_case): +@pytest.mark.parametrize("precompile", [True, False], ids=["precompile", "run"]) +def test_compile_variants_decorator_static_domains(cartesian_case, precompile): if cartesian_case.backend is None: pytest.skip("Embedded compiled program doesn't make sense.") @@ -973,68 +974,19 @@ def testee(inp: tuple[cases.IField, cases.IField, float], out: NamedTupleNamedCo with mock.patch.object(compiled_program, "_async_compilation_pool", None): inp = cases.allocate(cartesian_case, testee, "inp")() out = cases.allocate(cartesian_case, testee, "out")() - - testee(inp, out, offset_provider={}) - assert np.allclose(inp[0].ndarray, out[0].ndarray) - assert np.allclose(inp[1].ndarray, out[1].ndarray) - - assert testee._compiled_programs.argument_descriptor_mapping[ - arguments.FieldDomainDescriptor - ] == ["inp[0]", "inp[1]", "out[0]", "out[1]"] - assert captured_cargs.argument_descriptor_contexts[arguments.FieldDomainDescriptor] == { - "inp": ( - arguments.FieldDomainDescriptor(inp[0].domain), - arguments.FieldDomainDescriptor(inp[1].domain), - None, - ), - "out": ( - arguments.FieldDomainDescriptor(out[0].domain), - arguments.FieldDomainDescriptor(out[1].domain), - ), - } - - -def test_compile_with_static_domains(compile_variants_field_operator, cartesian_case): - if cartesian_case.backend is None: - pytest.skip("Embedded compiled program doesn't make sense.") - - captured_cargs: Optional[arguments.CompileTimeArgs] = None - - class CaptureCompileTimeArgsBackend: - def __getattr__(self, name): - return getattr(cartesian_case.backend, name) - - def compile(self, program, compile_time_args): - nonlocal captured_cargs - captured_cargs = compile_time_args - - return cartesian_case.backend.compile(program, compile_time_args) - - @gtx.field_operator - def identity_like(inp: tuple[cases.IField, cases.IField, float]): - return inp[0], inp[1] - - # the float argument here is merely to test that static domains work for tuple arguments - # of inhomogeneous types - @gtx.program(backend=CaptureCompileTimeArgsBackend(), static_domains=True) - def testee( - inp: tuple[cases.IField, cases.IField, float], out: tuple[cases.IField, cases.IField] - ): - identity_like(inp, out=out) - - with mock.patch.object(compiled_program, "_async_compilation_pool", None): - inp = cases.allocate(cartesian_case, testee, "inp")() - out = cases.allocate(cartesian_case, testee, "out")() - - testee.compile( - offset_provider=cartesian_case.offset_provider, - static_domains={dim: (0, size) for dim, size in cartesian_case.default_sizes.items()}, - ) + if precompile: + testee.compile( + offset_provider=cartesian_case.offset_provider, + static_domains={dim: (0, size) for dim, size in cartesian_case.default_sizes.items()}, + ) + else: + testee(inp, out, offset_provider={}) + assert np.allclose(inp[0].ndarray, out[0].ndarray) + assert np.allclose(inp[1].ndarray, out[1].ndarray) assert testee._compiled_programs.argument_descriptor_mapping[ arguments.FieldDomainDescriptor ] == ["inp[0]", "inp[1]", "out[0]", "out[1]"] - assert captured_cargs.argument_descriptor_contexts[arguments.FieldDomainDescriptor] == { "inp": ( arguments.FieldDomainDescriptor(inp[0].domain), From c2d0e9dc87f06d07c9f27050aa6445feb7f2a71b Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Wed, 18 Feb 2026 14:35:46 +0100 Subject: [PATCH 5/6] Formatting --- .../feature_tests/ffront_tests/test_compiled_program.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index 6f8097d608..7b6fa0eab0 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -977,7 +977,9 @@ def testee(inp: tuple[cases.IField, cases.IField, float], out: NamedTupleNamedCo if precompile: testee.compile( offset_provider=cartesian_case.offset_provider, - static_domains={dim: (0, size) for dim, size in cartesian_case.default_sizes.items()}, + static_domains={ + dim: (0, size) for dim, size in cartesian_case.default_sizes.items() + }, ) else: testee(inp, out, offset_provider={}) From 9bc34cd3d8a42c7c973617dc62a102f174fdf6fc Mon Sep 17 00:00:00 2001 From: Sara Faghih-Naini Date: Thu, 5 Mar 2026 12:30:49 +0100 Subject: [PATCH 6/6] Update static_domains type to support per-field domain specifications --- src/gt4py/next/ffront/decorator.py | 4 +- src/gt4py/next/otf/compiled_program.py | 52 +++++++++++++++++++++----- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index c8a08ef105..ad19a14c2e 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -153,7 +153,7 @@ def compile( | common.OffsetProvider | list[common.OffsetProviderType | common.OffsetProvider] | None = None, - static_domains: dict[common.Dimension, tuple[int, int]] | None = None, + static_domains: dict[str, dict[common.Dimension, tuple[int, int]]] | None = None, **static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]], ) -> Self: """ @@ -497,7 +497,7 @@ def compile( | common.OffsetProvider | list[common.OffsetProviderType | common.OffsetProvider] | None = None, - static_domains: dict[common.Dimension, tuple[int, int]] | None = None, + static_domains: dict[str, dict[common.Dimension, tuple[int, int]]] | None = None, **static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]], ) -> Self: raise NotImplementedError("Compilation of programs with bound arguments is not implemented") diff --git a/src/gt4py/next/otf/compiled_program.py b/src/gt4py/next/otf/compiled_program.py index 932e108943..902c2ca19d 100644 --- a/src/gt4py/next/otf/compiled_program.py +++ b/src/gt4py/next/otf/compiled_program.py @@ -605,12 +605,21 @@ def _compile_variant( def compile( self, offset_providers: list[common.OffsetProvider | common.OffsetProviderType], - static_domains: dict[common.Dimension, tuple[int, int]] | None = None, + static_domains: dict[str, dict[common.Dimension, tuple[int, int]]] | None = None, **static_args: list[ScalarOrTupleOfScalars], ) -> None: """ Compiles the program for all combinations of static arguments and the given 'OffsetProviderType'. + Args: + offset_providers: List of offset providers to compile for. + static_domains: Per-field domain specification. + Example:: + { + "input_field": {IDim: (0, 55), JDim: (0, 30)}, + "out": {IHalfDim: (0, 52), JDim: (0, 30)}, + } + Note: In case you want to compile for specific combinations of static arguments (instead of the combinatoral), you can call compile multiples times. @@ -623,10 +632,12 @@ def compile( def _build_field_domain_descriptors( program_type: ts_ffront.ProgramType, - static_domains: dict[common.Dimension, tuple[int, int]], + static_domains: dict[str, dict[common.Dimension, tuple[int, int]]], ) -> dict[str, arguments.FieldDomainDescriptor]: field_domain_descriptors: dict[str, arguments.FieldDomainDescriptor] = {} assert program_type.definition.pos_only_args == [] + + matched_keys = set() param_types = ( program_type.definition.pos_or_kw_args | program_type.definition.kw_only_args ) @@ -636,18 +647,39 @@ def _build_field_domain_descriptors( ): if isinstance(el_type_, ts.FieldType): path_as_expr = "".join(map(lambda idx: f"[{idx}]", path)) - if missing_dims := [ - dim for dim in el_type_.dims if dim not in static_domains - ]: + field_expr = f"{arg_name}{path_as_expr}" + + if field_expr not in static_domains: raise ValueError( - f"Missing domain specification for dimension(s) {missing_dims} for {arg_name}{path_as_expr}. " - f"Field has dimensions {list(el_type_.dims)}, but static_domains only contains {list(static_domains.keys())}." + f"Missing static domain for field '{field_expr}'. " + f"Expected domains for all field parameters. " + f"Available keys in static_domains: {list(static_domains.keys())}" ) - domain_ranges = {dim: static_domains[dim] for dim in el_type_.dims} - field_domain_descriptors[f"{arg_name}{path_as_expr}"] = ( - arguments.FieldDomainDescriptor(common.domain(domain_ranges)) + + field_domain = static_domains[field_expr] + matched_keys.add(field_expr) + + expected_dims = set(el_type_.dims) + provided_dims = set(field_domain.keys()) + if expected_dims != provided_dims: + raise ValueError( + f"Domain dimension mismatch for field '{field_expr}': " + f"field has dimensions {set(d.value for d in expected_dims)}, " + f"but static_domains provides {set(d.value for d in provided_dims)}." + ) + + domain_ranges = {dim: field_domain[dim] for dim in el_type_.dims} + field_domain_descriptors[field_expr] = arguments.FieldDomainDescriptor( + common.domain(domain_ranges) ) + extra_keys = set(static_domains.keys()) - matched_keys + if extra_keys: + raise ValueError( + f"static_domains contains keys that do not correspond to any field parameter: " + f"{list(extra_keys)}. Valid field expressions are: {list(matched_keys)}" + ) + return field_domain_descriptors for offset_provider in offset_providers: # not included in product for better type checking