diff --git a/.github/workflows/type-check.yml b/.github/workflows/type-check.yml index 8b367d1..93a7078 100644 --- a/.github/workflows/type-check.yml +++ b/.github/workflows/type-check.yml @@ -10,4 +10,4 @@ jobs: contents: read steps: - uses: actions/checkout@v6 - - run: pipx run uv tool run --with .[typecheck] ty check + - run: pipx run uv tool run --with .[typecheck,yaml] ty check diff --git a/README.md b/README.md index ec0d053..3e1a60c 100644 --- a/README.md +++ b/README.md @@ -31,3 +31,4 @@ Please refer to test files for detailed usage. - Native integration with [Sphinx](https://github.com/sphinx-doc/sphinx), [DP-GUI](https://github.com/deepmodeling/dpgui), and [Jupyter Notebook](https://jupyter.org/) - JSON encoder for `Argument` and `Variant` classes - Generate [JSON schema](https://json-schema.org/) from an `Argument`, which can be further integrated with JSON editors such as [Visual Studio Code](https://code.visualstudio.com/) +- Load dict values from external JSON/YAML files via the `$ref` key diff --git a/dargs/check.py b/dargs/check.py index 5e2a3c3..eeff6cf 100644 --- a/dargs/check.py +++ b/dargs/check.py @@ -10,6 +10,7 @@ def check( data: dict, strict: bool = True, trim_pattern: str = "_*", + allow_ref: bool = False, ) -> dict: """Check and normalize input data. @@ -23,6 +24,9 @@ def check( If True, raise an error if the key is not pre-defined, by default True trim_pattern : str, optional Pattern to trim the key, by default "_*" + allow_ref : bool, optional + If True, allow loading from external files via the ``$ref`` key, + by default False. Returns ------- @@ -34,6 +38,6 @@ def check( "base", dtype=dict, sub_fields=cast("list[Argument]", arginfo) ) - data = arginfo.normalize_value(data, trim_pattern=trim_pattern) - arginfo.check_value(data, strict=strict) + data = arginfo.normalize_value(data, trim_pattern=trim_pattern, allow_ref=allow_ref) + arginfo.check_value(data, strict=strict, allow_ref=allow_ref) return data diff --git a/dargs/cli.py b/dargs/cli.py index d2d6f08..3273463 100644 --- a/dargs/cli.py +++ b/dargs/cli.py @@ -52,6 +52,12 @@ def main_parser() -> argparse.ArgumentParser: default="_*", help="Pattern to trim the key", ) + parser_check.add_argument( + "--allow-ref", + action="store_true", + dest="allow_ref", + help="Allow loading from external files via the $ref key", + ) parser_check.set_defaults(entrypoint=check_cli) # doc subcommand @@ -92,6 +98,7 @@ def check_cli( func: str, jdata: list[IO], strict: bool, + allow_ref: bool = False, **kwargs: Any, ) -> None: """Normalize and check input data. @@ -104,6 +111,8 @@ def check_cli( File object that contains the JSON data strict : bool If True, raise an error if the key is not pre-defined + allow_ref : bool, optional + If True, allow loading from external files via the ``$ref`` key Returns ------- @@ -124,7 +133,7 @@ def check_cli( arginfo = func_obj() for jj in jdata: data = json.load(jj) - check(arginfo, data, strict=strict) + check(arginfo, data, strict=strict, allow_ref=allow_ref) def doc_cli( diff --git a/dargs/dargs.py b/dargs/dargs.py index bf0cf51..f067b01 100644 --- a/dargs/dargs.py +++ b/dargs/dargs.py @@ -21,6 +21,7 @@ import difflib import fnmatch import json +import os import re from copy import deepcopy from enum import Enum @@ -335,6 +336,7 @@ def traverse( sub_hook: HookArgKType = _DUMMYHOOK, variant_hook: HookVrntType = _DUMMYHOOK, path: list[str] | None = None, + allow_ref: bool = False, ) -> None: # first, do something with the key # then, take out the vaule and do something with it @@ -347,7 +349,7 @@ def traverse( newpath = [*path, self.name] # this is the key step that we traverse into the tree self.traverse_value( - value, key_hook, value_hook, sub_hook, variant_hook, newpath + value, key_hook, value_hook, sub_hook, variant_hook, newpath, allow_ref ) def traverse_value( @@ -358,6 +360,7 @@ def traverse_value( sub_hook: HookArgKType = _DUMMYHOOK, variant_hook: HookVrntType = _DUMMYHOOK, path: list[str] | None = None, + allow_ref: bool = False, ) -> None: # this is not private, and can be called directly # in the condition where there is no leading key @@ -365,7 +368,7 @@ def traverse_value( path = [] if not self.repeat and isinstance(value, dict): self._traverse_sub( - value, key_hook, value_hook, sub_hook, variant_hook, path + value, key_hook, value_hook, sub_hook, variant_hook, path, allow_ref ) elif self.repeat and isinstance(value, list): for idx, item in enumerate(value): @@ -376,6 +379,7 @@ def traverse_value( sub_hook, variant_hook, [*path, str(idx)], + allow_ref, ) elif self.repeat and isinstance(value, dict): for kk, item in value.items(): @@ -386,6 +390,7 @@ def traverse_value( sub_hook, variant_hook, [*path, kk], + allow_ref, ) def _traverse_sub( @@ -396,6 +401,7 @@ def _traverse_sub( sub_hook: HookArgKType = _DUMMYHOOK, variant_hook: HookVrntType = _DUMMYHOOK, path: list[str] | None = None, + allow_ref: bool = False, ) -> None: if path is None: path = [self.name] @@ -405,16 +411,21 @@ def _traverse_sub( f"key `{path[-1]}` gets wrong value type, " f"requires dict but {type(value).__name__} is given", ) + _resolve_ref(value, allow_ref) sub_hook(self, value, path) for subvrnt in self.sub_variants.values(): variant_hook(subvrnt, value, path) for subarg in self.flatten_sub(value, path).values(): - subarg.traverse(value, key_hook, value_hook, sub_hook, variant_hook, path) + subarg.traverse( + value, key_hook, value_hook, sub_hook, variant_hook, path, allow_ref + ) # above are general traverse part # below are type checking part - def check(self, argdict: dict, strict: bool = False) -> None: + def check( + self, argdict: dict, strict: bool = False, allow_ref: bool = False + ) -> None: """Check whether `argdict` meets the structure defined in self. Will recursively check nested dicts according to @@ -426,6 +437,10 @@ def check(self, argdict: dict, strict: bool = False) -> None: The arg dict to be checked strict : bool, optional If true, only keys defined in `Argument` are allowed. + allow_ref : bool, optional + If true, allow loading from external files via the ``$ref`` key. + A deep copy of ``argdict`` is made internally so the caller's + data is not mutated. """ if strict and len(argdict) != 1: raise ArgumentKeyError( @@ -434,14 +449,19 @@ def check(self, argdict: dict, strict: bool = False) -> None: "for check in strict mode at top level, " "use check_value if you are checking subfields", ) + if allow_ref: + argdict = deepcopy(argdict) self.traverse( argdict, key_hook=Argument._check_exist, value_hook=Argument._check_data, sub_hook=Argument._check_strict if strict else _DUMMYHOOK, + allow_ref=allow_ref, ) - def check_value(self, value: Any, strict: bool = False) -> None: + def check_value( + self, value: Any, strict: bool = False, allow_ref: bool = False + ) -> None: """Check the value without the leading key. Same as `check({self.name: value})`. @@ -453,12 +473,19 @@ def check_value(self, value: Any, strict: bool = False) -> None: The value to be checked strict : bool, optional If true, only keys defined in `Argument` are allowed. + allow_ref : bool, optional + If true, allow loading from external files via the ``$ref`` key. + A deep copy of ``value`` is made internally so the caller's + data is not mutated. """ + if allow_ref: + value = deepcopy(value) self.traverse_value( value, key_hook=Argument._check_exist, value_hook=Argument._check_data, sub_hook=Argument._check_strict if strict else _DUMMYHOOK, + allow_ref=allow_ref, ) def _check_exist(self, argdict: dict, path: list[str] | None = None) -> None: @@ -518,6 +545,7 @@ def normalize( do_default: bool = True, do_alias: bool = True, trim_pattern: str | None = None, + allow_ref: bool = False, ) -> dict: """Modify `argdict` so that it meets the Argument structure. @@ -537,6 +565,8 @@ def normalize( Whether to transform alias names. trim_pattern : str, optional If given, discard keys that matches the glob pattern. + allow_ref : bool, optional + If true, allow loading from external files via the ``$ref`` key. Returns ------- @@ -550,10 +580,15 @@ def normalize( argdict, key_hook=Argument._convert_alias, variant_hook=Variant._convert_choice_alias, + allow_ref=allow_ref, ) if do_default: - self.traverse(argdict, key_hook=Argument._assign_default) - self.traverse(argdict, key_hook=Argument._handle_empty_dict) + self.traverse( + argdict, key_hook=Argument._assign_default, allow_ref=allow_ref + ) + self.traverse( + argdict, key_hook=Argument._handle_empty_dict, allow_ref=allow_ref + ) if trim_pattern is not None: trim_by_pattern(argdict, trim_pattern, reserved=[self.name]) self.traverse( @@ -561,6 +596,7 @@ def normalize( sub_hook=lambda a, d, p: trim_by_pattern( d, trim_pattern, a.flatten_sub(d, p).keys() ), + allow_ref=allow_ref, ) return argdict @@ -571,6 +607,7 @@ def normalize_value( do_default: bool = True, do_alias: bool = True, trim_pattern: str | None = None, + allow_ref: bool = False, ) -> Any: """Modify the value so that it meets the Argument structure. @@ -588,6 +625,8 @@ def normalize_value( Whether to transform alias names. trim_pattern : str, optional If given, discard keys that matches the glob pattern. + allow_ref : bool, optional + If true, allow loading from external files via the ``$ref`` key. Returns ------- @@ -601,16 +640,22 @@ def normalize_value( value, key_hook=Argument._convert_alias, variant_hook=Variant._convert_choice_alias, + allow_ref=allow_ref, ) if do_default: - self.traverse_value(value, key_hook=Argument._assign_default) - self.traverse_value(value, key_hook=Argument._handle_empty_dict) + self.traverse_value( + value, key_hook=Argument._assign_default, allow_ref=allow_ref + ) + self.traverse_value( + value, key_hook=Argument._handle_empty_dict, allow_ref=allow_ref + ) if trim_pattern is not None: self.traverse_value( value, sub_hook=lambda a, d, p: trim_by_pattern( d, trim_pattern, a.flatten_sub(d, p).keys() ), + allow_ref=allow_ref, ) return value @@ -1065,6 +1110,100 @@ def trim_by_pattern( argdict.pop(key) +def _load_ref(ref_path: str) -> dict: + """Load a dict from an external file referenced by ``$ref``. + + Parameters + ---------- + ref_path : str + Path to the external file. Supported extensions: ``.json``, ``.yml``, ``.yaml``. + + Returns + ------- + dict + The loaded dict from the external file. + + Raises + ------ + ValueError + If the file extension is not supported, or if the file does not contain a + top-level mapping/object. + ImportError + If pyyaml is not installed and a YAML file is requested. + """ + ext = os.path.splitext(ref_path)[1].lower() + if ext == ".json": + with open(ref_path, encoding="utf-8") as f: + loaded = json.load(f) + elif ext in (".yml", ".yaml"): + try: + import yaml + except ImportError as e: + raise ImportError( + "pyyaml is required to load YAML files referenced by $ref. " + "Install it with: pip install pyyaml" + ) from e + with open(ref_path, encoding="utf-8") as f: + loaded = yaml.safe_load(f) + else: + raise ValueError( + f"Unsupported file extension `{ext}` for $ref. " + "Supported extensions are: .json, .yml, .yaml" + ) + if not isinstance(loaded, dict): + raise ValueError( + f"Referenced file {ref_path!r} must contain a mapping/object at the top " + f"level, but got {type(loaded).__name__!r}." + ) + return loaded + + +def _resolve_ref(d: dict, allow_ref: bool = False) -> None: + """Resolve the ``$ref`` key in a dict by loading from an external file. + + If ``$ref`` is present in ``d``, its value is treated as a file path. + The file is loaded and its contents are merged into ``d``. Keys already + present in ``d`` (other than ``$ref``) take precedence over keys from the + loaded file, allowing local overrides. Chained ``$ref`` values in the + loaded content are resolved in turn. Cyclic references are detected and + raise a ``ValueError``. + + The dict is modified **in place**. + + Parameters + ---------- + d : dict + The dict that may contain a ``$ref`` key. + allow_ref : bool, optional + If False (the default), raise a ``ValueError`` when ``$ref`` is found. + Set to True to enable loading from external files. + + Raises + ------ + ValueError + If ``$ref`` is found but ``allow_ref`` is False, or if a cyclic + reference is detected. + """ + if "$ref" not in d: + return + if not allow_ref: + raise ValueError( + "$ref is not allowed by default. " + "Pass allow_ref=True to enable loading from external files." + ) + visited_refs: set[str] = set() + while "$ref" in d: + ref_path = d.pop("$ref") + if ref_path in visited_refs: + raise ValueError(f"Cyclic $ref detected for path: {ref_path!r}") + visited_refs.add(ref_path) + loaded = _load_ref(ref_path) + # Merge: loaded content as base, local keys take precedence + merged = {**loaded, **d} + d.clear() + d.update(merged) + + def isinstance_annotation(value: Any, dtype: type | Any) -> bool: """Same as isinstance(), but supports arbitrary type annotations.""" try: diff --git a/dargs/notebook.py b/dargs/notebook.py index 80cb4f5..c705dbf 100644 --- a/dargs/notebook.py +++ b/dargs/notebook.py @@ -27,6 +27,7 @@ from IPython.display import HTML, display from dargs import Argument, Variant +from dargs.dargs import _resolve_ref __all__ = ["JSON"] @@ -90,7 +91,9 @@ """ -def JSON(data: dict | str, arg: Argument | list[Argument]) -> None: +def JSON( + data: dict | str, arg: Argument | list[Argument], allow_ref: bool = False +) -> None: """Display JSON data with Argument in the Jupyter Notebook. Parameters @@ -99,11 +102,15 @@ def JSON(data: dict | str, arg: Argument | list[Argument]) -> None: The JSON data to be displayed, either JSON string or a dict. arg : dargs.Argument or list[dargs.Argument] The Argument that describes the JSON data. + allow_ref : bool, optional + If true, allow loading from external files via the ``$ref`` key. """ - display(HTML(print_html(data, arg))) + display(HTML(print_html(data, arg, allow_ref=allow_ref))) -def print_html(data: Any, arg: Argument | list[Argument]) -> str: +def print_html( + data: Any, arg: Argument | list[Argument], allow_ref: bool = False +) -> str: """Print HTML string with Argument in the Jupyter Notebook. Parameters @@ -112,6 +119,8 @@ def print_html(data: Any, arg: Argument | list[Argument]) -> str: The JSON data to be displayed, either JSON string or a dict. arg : dargs.Argument or list[dargs.Argument] The Argument that describes the JSON data. + allow_ref : bool, optional + If true, allow loading from external files via the ``$ref`` key. Returns ------- @@ -131,7 +140,7 @@ def print_html(data: Any, arg: Argument | list[Argument]) -> str: pass else: raise ValueError(f"Unknown type: {type(arg)}") - argdata = ArgumentData(data, arg) + argdata = ArgumentData(data, arg, allow_ref=allow_ref) buff = [css, r"""