-
Notifications
You must be signed in to change notification settings - Fork 55
feat[next]: Precompilation with static domains #2483
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
dc6dbb7
c3e0453
46a855e
57db00a
694d927
bf02bb8
c2d0e9d
3b6c14d
9bc34cd
bb49222
5a5adf5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,7 +28,7 @@ | |
| ) | ||
| from gt4py.next.instrumentation import hook_machinery, metrics | ||
| from gt4py.next.otf import arguments, stages | ||
| from gt4py.next.type_system import type_specifications as ts | ||
| from gt4py.next.type_system import type_info, type_specifications as ts | ||
| from gt4py.next.utils import tree_map | ||
|
|
||
|
|
||
|
|
@@ -605,11 +605,21 @@ def _compile_variant( | |
| def compile( | ||
| self, | ||
| offset_providers: list[common.OffsetProvider | common.OffsetProviderType], | ||
| static_domains: dict[str, dict[common.Dimension, tuple[int, int]]] | None = None, | ||
| **static_args: list[ScalarOrTupleOfScalars], | ||
| ) -> None: | ||
| """ | ||
| Compiles the program for all combinations of static arguments and the given 'OffsetProviderType'. | ||
|
|
||
| Args: | ||
| offset_providers: List of offset providers to compile for. | ||
| static_domains: Per-field domain specification. | ||
| Example:: | ||
| { | ||
| "input_field": {IDim: (0, 55), JDim: (0, 30)}, | ||
| "out": {IHalfDim: (0, 52), JDim: (0, 30)}, | ||
| } | ||
|
|
||
| Note: In case you want to compile for specific combinations of static arguments (instead | ||
| of the combinatoral), you can call compile multiples times. | ||
|
|
||
|
|
@@ -619,17 +629,78 @@ def compile( | |
| pool.compile(static_arg0=[0], static_arg1=[2]).compile(static_arg=[1], static_arg1=[3]) | ||
| will compile for (0,2), (1,3) | ||
| """ | ||
|
|
||
| def _build_field_domain_descriptors( | ||
| program_type: ts_ffront.ProgramType, | ||
| static_domains: dict[str, dict[common.Dimension, tuple[int, int]]], | ||
| ) -> dict[str, arguments.FieldDomainDescriptor]: | ||
| field_domain_descriptors: dict[str, arguments.FieldDomainDescriptor] = {} | ||
| assert program_type.definition.pos_only_args == [] | ||
|
|
||
| matched_keys = set() | ||
| param_types = ( | ||
| program_type.definition.pos_or_kw_args | program_type.definition.kw_only_args | ||
| ) | ||
| for arg_name, arg_type_ in param_types.items(): | ||
| for el_type_, path in type_info.primitive_constituents( | ||
| arg_type_, with_path_arg=True | ||
| ): | ||
| if isinstance(el_type_, ts.FieldType): | ||
| path_as_expr = "".join(map(lambda idx: f"[{idx}]", path)) | ||
| field_expr = f"{arg_name}{path_as_expr}" | ||
|
|
||
| if field_expr not in static_domains: | ||
| raise ValueError( | ||
| f"Missing static domain for field '{field_expr}'. " | ||
| f"Expected domains for all field parameters. " | ||
| f"Available keys in static_domains: {list(static_domains.keys())}" | ||
| ) | ||
|
|
||
| field_domain = static_domains[field_expr] | ||
| matched_keys.add(field_expr) | ||
|
|
||
| expected_dims = set(el_type_.dims) | ||
| provided_dims = set(field_domain.keys()) | ||
| if expected_dims != provided_dims: | ||
| raise ValueError( | ||
| f"Domain dimension mismatch for field '{field_expr}': " | ||
| f"field has dimensions {set(d.value for d in expected_dims)}, " | ||
| f"but static_domains provides {set(d.value for d in provided_dims)}." | ||
| ) | ||
|
|
||
| domain_ranges = {dim: field_domain[dim] for dim in el_type_.dims} | ||
| field_domain_descriptors[field_expr] = arguments.FieldDomainDescriptor( | ||
| common.domain(domain_ranges) | ||
| ) | ||
|
|
||
| extra_keys = set(static_domains.keys()) - matched_keys | ||
| if extra_keys: | ||
| raise ValueError( | ||
| f"static_domains contains keys that do not correspond to any field parameter: " | ||
| f"{list(extra_keys)}. Valid field expressions are: {list(matched_keys)}" | ||
| ) | ||
|
|
||
| return field_domain_descriptors | ||
|
|
||
| for offset_provider in offset_providers: # not included in product for better type checking | ||
| for static_values in itertools.product(*static_args.values()): | ||
| argument_descriptor_dict: dict[ | ||
| type[arguments.ArgStaticDescriptor], | ||
| dict[str, arguments.ArgStaticDescriptor], | ||
| ] = { | ||
| arguments.StaticArg: dict( | ||
| zip( | ||
| static_args.keys(), | ||
| [arguments.StaticArg(value=v) for v in static_values], | ||
| strict=True, | ||
| ) | ||
| ), | ||
| } | ||
| if static_domains: | ||
| argument_descriptor_dict[arguments.FieldDomainDescriptor] = ( | ||
| _build_field_domain_descriptors(self.program_type, static_domains) # type: ignore[assignment] | ||
| ) | ||
SF-N marked this conversation as resolved.
Show resolved
Hide resolved
Comment on lines
+700
to
+702
|
||
| self._compile_variant( | ||
| argument_descriptors={ | ||
| arguments.StaticArg: dict( | ||
| zip( | ||
| static_args.keys(), | ||
| [arguments.StaticArg(value=v) for v in static_values], | ||
| strict=True, | ||
| ) | ||
| ), | ||
| }, | ||
| argument_descriptor_dict, | ||
| offset_provider=offset_provider, | ||
| ) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
static_domainsis checked via truthiness (if static_domains:), so passing an empty dict (which is still a non-Nonevalue per the public API) will silently skip buildingFieldDomainDescriptors. This can lead to compiling a variant without any domain descriptors even though the caller explicitly providedstatic_domains. Consider switching this condition toif static_domains is not None:(and let_build_field_domain_descriptorsraise for missing dims as needed).