Skip to content
17 changes: 16 additions & 1 deletion src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def compile(
| common.OffsetProvider
| list[common.OffsetProviderType | common.OffsetProvider]
| None = None,
static_domains: dict[str, dict[common.Dimension, tuple[int, int]]] | None = None,
**static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]],
) -> Self:
"""
Expand Down Expand Up @@ -194,12 +195,25 @@ def compile(
for op in offset_provider
)

self._compiled_programs.compile(offset_providers=offset_provider, **static_args)
if self.compilation_options.static_domains and static_domains is None:
raise ValueError(
"Static domains option is enabled, but no static domain information was "
"provided. Missing required argument 'static_domains'."
)
if not self.compilation_options.static_domains and static_domains is not None:
raise ValueError(
"Static domains may not be provided when 'static_domains' option is disabled."
)

self._compiled_programs.compile(
offset_providers=offset_provider, static_domains=static_domains, **static_args
)
return self


def _field_domain_descriptor_mapping_from_func_type(func_type: ts.FunctionType) -> list[str]:
static_domain_args = []
assert func_type.pos_only_args == []
param_types = func_type.pos_or_kw_args | func_type.kw_only_args
for name, type_ in param_types.items():
for el_type_, path in type_info.primitive_constituents(type_, with_path_arg=True):
Expand Down Expand Up @@ -483,6 +497,7 @@ def compile(
| common.OffsetProvider
| list[common.OffsetProviderType | common.OffsetProvider]
| None = None,
static_domains: dict[str, dict[common.Dimension, tuple[int, int]]] | None = None,
**static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]],
) -> Self:
raise NotImplementedError("Compilation of programs with bound arguments is not implemented")
Expand Down
91 changes: 81 additions & 10 deletions src/gt4py/next/otf/compiled_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
)
from gt4py.next.instrumentation import hook_machinery, metrics
from gt4py.next.otf import arguments, stages
from gt4py.next.type_system import type_specifications as ts
from gt4py.next.type_system import type_info, type_specifications as ts
from gt4py.next.utils import tree_map


Expand Down Expand Up @@ -605,11 +605,21 @@ def _compile_variant(
def compile(
self,
offset_providers: list[common.OffsetProvider | common.OffsetProviderType],
static_domains: dict[str, dict[common.Dimension, tuple[int, int]]] | None = None,
**static_args: list[ScalarOrTupleOfScalars],
) -> None:
"""
Compiles the program for all combinations of static arguments and the given 'OffsetProviderType'.

Args:
offset_providers: List of offset providers to compile for.
static_domains: Per-field domain specification.
Example::
{
"input_field": {IDim: (0, 55), JDim: (0, 30)},
"out": {IHalfDim: (0, 52), JDim: (0, 30)},
}

Note: In case you want to compile for specific combinations of static arguments (instead
of the combinatoral), you can call compile multiples times.

Expand All @@ -619,17 +629,78 @@ def compile(
pool.compile(static_arg0=[0], static_arg1=[2]).compile(static_arg=[1], static_arg1=[3])
will compile for (0,2), (1,3)
"""

def _build_field_domain_descriptors(
program_type: ts_ffront.ProgramType,
static_domains: dict[str, dict[common.Dimension, tuple[int, int]]],
) -> dict[str, arguments.FieldDomainDescriptor]:
field_domain_descriptors: dict[str, arguments.FieldDomainDescriptor] = {}
assert program_type.definition.pos_only_args == []

matched_keys = set()
param_types = (
program_type.definition.pos_or_kw_args | program_type.definition.kw_only_args
)
for arg_name, arg_type_ in param_types.items():
for el_type_, path in type_info.primitive_constituents(
arg_type_, with_path_arg=True
):
if isinstance(el_type_, ts.FieldType):
path_as_expr = "".join(map(lambda idx: f"[{idx}]", path))
field_expr = f"{arg_name}{path_as_expr}"

if field_expr not in static_domains:
raise ValueError(
f"Missing static domain for field '{field_expr}'. "
f"Expected domains for all field parameters. "
f"Available keys in static_domains: {list(static_domains.keys())}"
)

field_domain = static_domains[field_expr]
matched_keys.add(field_expr)

expected_dims = set(el_type_.dims)
provided_dims = set(field_domain.keys())
if expected_dims != provided_dims:
raise ValueError(
f"Domain dimension mismatch for field '{field_expr}': "
f"field has dimensions {set(d.value for d in expected_dims)}, "
f"but static_domains provides {set(d.value for d in provided_dims)}."
)

domain_ranges = {dim: field_domain[dim] for dim in el_type_.dims}
field_domain_descriptors[field_expr] = arguments.FieldDomainDescriptor(
common.domain(domain_ranges)
)

extra_keys = set(static_domains.keys()) - matched_keys
if extra_keys:
raise ValueError(
f"static_domains contains keys that do not correspond to any field parameter: "
f"{list(extra_keys)}. Valid field expressions are: {list(matched_keys)}"
)

return field_domain_descriptors

for offset_provider in offset_providers: # not included in product for better type checking
for static_values in itertools.product(*static_args.values()):
argument_descriptor_dict: dict[
type[arguments.ArgStaticDescriptor],
dict[str, arguments.ArgStaticDescriptor],
] = {
arguments.StaticArg: dict(
zip(
static_args.keys(),
[arguments.StaticArg(value=v) for v in static_values],
strict=True,
)
),
}
if static_domains:
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

static_domains is checked via truthiness (if static_domains:), so passing an empty dict (which is still a non-None value per the public API) will silently skip building FieldDomainDescriptors. This can lead to compiling a variant without any domain descriptors even though the caller explicitly provided static_domains. Consider switching this condition to if static_domains is not None: (and let _build_field_domain_descriptors raise for missing dims as needed).

Suggested change
if static_domains:
if static_domains is not None:

Copilot uses AI. Check for mistakes.
argument_descriptor_dict[arguments.FieldDomainDescriptor] = (
_build_field_domain_descriptors(self.program_type, static_domains) # type: ignore[assignment]
)
Comment on lines +700 to +702
Copy link

Copilot AI Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a # type: ignore[assignment] on the FieldDomainDescriptor insertion because _build_field_domain_descriptors returns dict[str, FieldDomainDescriptor] while argument_descriptor_dict is typed as dict[str, ArgStaticDescriptor]. Consider widening the helper’s return type (e.g., to dict[str, arguments.ArgStaticDescriptor]) or casting at the assignment site to avoid relying on type: ignore here.

Copilot uses AI. Check for mistakes.
self._compile_variant(
argument_descriptors={
arguments.StaticArg: dict(
zip(
static_args.keys(),
[arguments.StaticArg(value=v) for v in static_values],
strict=True,
)
),
},
argument_descriptor_dict,
offset_provider=offset_provider,
)
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
skip_value_mesh,
)

from gt4py.next.otf import arguments

_raise_on_compile = mock.Mock()
_raise_on_compile.compile.side_effect = AssertionError("This function should never be called.")
Expand All @@ -62,6 +61,7 @@ def testee(a: cases.IField, b: cases.IField, out: cases.IField):
testee_op(a, b, out=out)

wrap_in_program = request.param

if wrap_in_program:
return testee
else:
Expand Down Expand Up @@ -944,7 +944,8 @@ def test_wait_for_compilation(cartesian_case, compile_testee, compile_testee_dom
compile_testee_domain.compile(offset_provider=cartesian_case.offset_provider)


def test_compile_variants_decorator_static_domains(cartesian_case):
@pytest.mark.parametrize("precompile", [True, False], ids=["precompile", "run"])
def test_compile_variants_decorator_static_domains(cartesian_case, precompile):
if cartesian_case.backend is None:
pytest.skip("Embedded compiled program doesn't make sense.")

Expand All @@ -970,24 +971,32 @@ def identity_like(inp: tuple[cases.IField, cases.IField, float]):
def testee(inp: tuple[cases.IField, cases.IField, float], out: NamedTupleNamedCollection):
identity_like(inp, out=out)

inp = cases.allocate(cartesian_case, testee, "inp")()
out = cases.allocate(cartesian_case, testee, "out")()

testee(inp, out, offset_provider={})
assert np.allclose(inp[0].ndarray, out[0].ndarray)
assert np.allclose(inp[1].ndarray, out[1].ndarray)

assert testee._compiled_programs.argument_descriptor_mapping[
arguments.FieldDomainDescriptor
] == ["inp[0]", "inp[1]", "out[0]", "out[1]"]
assert captured_cargs.argument_descriptor_contexts[arguments.FieldDomainDescriptor] == {
"inp": (
arguments.FieldDomainDescriptor(inp[0].domain),
arguments.FieldDomainDescriptor(inp[1].domain),
None,
),
"out": (
arguments.FieldDomainDescriptor(out[0].domain),
arguments.FieldDomainDescriptor(out[1].domain),
),
}
with mock.patch.object(compiled_program, "_async_compilation_pool", None):
inp = cases.allocate(cartesian_case, testee, "inp")()
out = cases.allocate(cartesian_case, testee, "out")()
if precompile:
testee.compile(
offset_provider=cartesian_case.offset_provider,
static_domains={
dim: (0, size) for dim, size in cartesian_case.default_sizes.items()
},
)
else:
testee(inp, out, offset_provider={})
assert np.allclose(inp[0].ndarray, out[0].ndarray)
assert np.allclose(inp[1].ndarray, out[1].ndarray)

assert testee._compiled_programs.argument_descriptor_mapping[
arguments.FieldDomainDescriptor
] == ["inp[0]", "inp[1]", "out[0]", "out[1]"]
assert captured_cargs.argument_descriptor_contexts[arguments.FieldDomainDescriptor] == {
"inp": (
arguments.FieldDomainDescriptor(inp[0].domain),
arguments.FieldDomainDescriptor(inp[1].domain),
None,
),
"out": (
arguments.FieldDomainDescriptor(out[0].domain),
arguments.FieldDomainDescriptor(out[1].domain),
),
}
Loading