diff --git a/.github/workflows/type-check.yml b/.github/workflows/type-check.yml index 8b367d1..93a7078 100644 --- a/.github/workflows/type-check.yml +++ b/.github/workflows/type-check.yml @@ -10,4 +10,4 @@ jobs: contents: read steps: - uses: actions/checkout@v6 - - run: pipx run uv tool run --with .[typecheck] ty check + - run: pipx run uv tool run --with .[typecheck,yaml] ty check diff --git a/README.md b/README.md index ec0d053..3e1a60c 100644 --- a/README.md +++ b/README.md @@ -31,3 +31,4 @@ Please refer to test files for detailed usage. - Native integration with [Sphinx](https://github.com/sphinx-doc/sphinx), [DP-GUI](https://github.com/deepmodeling/dpgui), and [Jupyter Notebook](https://jupyter.org/) - JSON encoder for `Argument` and `Variant` classes - Generate [JSON schema](https://json-schema.org/) from an `Argument`, which can be further integrated with JSON editors such as [Visual Studio Code](https://code.visualstudio.com/) +- Load dict values from external JSON/YAML files via the `$ref` key diff --git a/dargs/check.py b/dargs/check.py index 5e2a3c3..eeff6cf 100644 --- a/dargs/check.py +++ b/dargs/check.py @@ -10,6 +10,7 @@ def check( data: dict, strict: bool = True, trim_pattern: str = "_*", + allow_ref: bool = False, ) -> dict: """Check and normalize input data. @@ -23,6 +24,9 @@ def check( If True, raise an error if the key is not pre-defined, by default True trim_pattern : str, optional Pattern to trim the key, by default "_*" + allow_ref : bool, optional + If True, allow loading from external files via the ``$ref`` key, + by default False. 
Returns ------- @@ -34,6 +38,6 @@ def check( "base", dtype=dict, sub_fields=cast("list[Argument]", arginfo) ) - data = arginfo.normalize_value(data, trim_pattern=trim_pattern) - arginfo.check_value(data, strict=strict) + data = arginfo.normalize_value(data, trim_pattern=trim_pattern, allow_ref=allow_ref) + arginfo.check_value(data, strict=strict, allow_ref=allow_ref) return data diff --git a/dargs/cli.py b/dargs/cli.py index d2d6f08..3273463 100644 --- a/dargs/cli.py +++ b/dargs/cli.py @@ -52,6 +52,12 @@ def main_parser() -> argparse.ArgumentParser: default="_*", help="Pattern to trim the key", ) + parser_check.add_argument( + "--allow-ref", + action="store_true", + dest="allow_ref", + help="Allow loading from external files via the $ref key", + ) parser_check.set_defaults(entrypoint=check_cli) # doc subcommand @@ -92,6 +98,7 @@ def check_cli( func: str, jdata: list[IO], strict: bool, + allow_ref: bool = False, **kwargs: Any, ) -> None: """Normalize and check input data. @@ -104,6 +111,8 @@ def check_cli( File object that contains the JSON data strict : bool If True, raise an error if the key is not pre-defined + allow_ref : bool, optional + If True, allow loading from external files via the ``$ref`` key Returns ------- @@ -124,7 +133,7 @@ def check_cli( arginfo = func_obj() for jj in jdata: data = json.load(jj) - check(arginfo, data, strict=strict) + check(arginfo, data, strict=strict, allow_ref=allow_ref) def doc_cli( diff --git a/dargs/dargs.py b/dargs/dargs.py index bf0cf51..f067b01 100644 --- a/dargs/dargs.py +++ b/dargs/dargs.py @@ -21,6 +21,7 @@ import difflib import fnmatch import json +import os import re from copy import deepcopy from enum import Enum @@ -335,6 +336,7 @@ def traverse( sub_hook: HookArgKType = _DUMMYHOOK, variant_hook: HookVrntType = _DUMMYHOOK, path: list[str] | None = None, + allow_ref: bool = False, ) -> None: # first, do something with the key # then, take out the vaule and do something with it @@ -347,7 +349,7 @@ def traverse( newpath 
= [*path, self.name] # this is the key step that we traverse into the tree self.traverse_value( - value, key_hook, value_hook, sub_hook, variant_hook, newpath + value, key_hook, value_hook, sub_hook, variant_hook, newpath, allow_ref ) def traverse_value( @@ -358,6 +360,7 @@ def traverse_value( sub_hook: HookArgKType = _DUMMYHOOK, variant_hook: HookVrntType = _DUMMYHOOK, path: list[str] | None = None, + allow_ref: bool = False, ) -> None: # this is not private, and can be called directly # in the condition where there is no leading key @@ -365,7 +368,7 @@ def traverse_value( path = [] if not self.repeat and isinstance(value, dict): self._traverse_sub( - value, key_hook, value_hook, sub_hook, variant_hook, path + value, key_hook, value_hook, sub_hook, variant_hook, path, allow_ref ) elif self.repeat and isinstance(value, list): for idx, item in enumerate(value): @@ -376,6 +379,7 @@ def traverse_value( sub_hook, variant_hook, [*path, str(idx)], + allow_ref, ) elif self.repeat and isinstance(value, dict): for kk, item in value.items(): @@ -386,6 +390,7 @@ def traverse_value( sub_hook, variant_hook, [*path, kk], + allow_ref, ) def _traverse_sub( @@ -396,6 +401,7 @@ def _traverse_sub( sub_hook: HookArgKType = _DUMMYHOOK, variant_hook: HookVrntType = _DUMMYHOOK, path: list[str] | None = None, + allow_ref: bool = False, ) -> None: if path is None: path = [self.name] @@ -405,16 +411,21 @@ def _traverse_sub( f"key `{path[-1]}` gets wrong value type, " f"requires dict but {type(value).__name__} is given", ) + _resolve_ref(value, allow_ref) sub_hook(self, value, path) for subvrnt in self.sub_variants.values(): variant_hook(subvrnt, value, path) for subarg in self.flatten_sub(value, path).values(): - subarg.traverse(value, key_hook, value_hook, sub_hook, variant_hook, path) + subarg.traverse( + value, key_hook, value_hook, sub_hook, variant_hook, path, allow_ref + ) # above are general traverse part # below are type checking part - def check(self, argdict: dict, strict: bool = 
False) -> None: + def check( + self, argdict: dict, strict: bool = False, allow_ref: bool = False + ) -> None: """Check whether `argdict` meets the structure defined in self. Will recursively check nested dicts according to @@ -426,6 +437,10 @@ def check(self, argdict: dict, strict: bool = False) -> None: The arg dict to be checked strict : bool, optional If true, only keys defined in `Argument` are allowed. + allow_ref : bool, optional + If true, allow loading from external files via the ``$ref`` key. + A deep copy of ``argdict`` is made internally so the caller's + data is not mutated. """ if strict and len(argdict) != 1: raise ArgumentKeyError( @@ -434,14 +449,19 @@ def check(self, argdict: dict, strict: bool = False) -> None: "for check in strict mode at top level, " "use check_value if you are checking subfields", ) + if allow_ref: + argdict = deepcopy(argdict) self.traverse( argdict, key_hook=Argument._check_exist, value_hook=Argument._check_data, sub_hook=Argument._check_strict if strict else _DUMMYHOOK, + allow_ref=allow_ref, ) - def check_value(self, value: Any, strict: bool = False) -> None: + def check_value( + self, value: Any, strict: bool = False, allow_ref: bool = False + ) -> None: """Check the value without the leading key. Same as `check({self.name: value})`. @@ -453,12 +473,19 @@ def check_value(self, value: Any, strict: bool = False) -> None: The value to be checked strict : bool, optional If true, only keys defined in `Argument` are allowed. + allow_ref : bool, optional + If true, allow loading from external files via the ``$ref`` key. + A deep copy of ``value`` is made internally so the caller's + data is not mutated. 
""" + if allow_ref: + value = deepcopy(value) self.traverse_value( value, key_hook=Argument._check_exist, value_hook=Argument._check_data, sub_hook=Argument._check_strict if strict else _DUMMYHOOK, + allow_ref=allow_ref, ) def _check_exist(self, argdict: dict, path: list[str] | None = None) -> None: @@ -518,6 +545,7 @@ def normalize( do_default: bool = True, do_alias: bool = True, trim_pattern: str | None = None, + allow_ref: bool = False, ) -> dict: """Modify `argdict` so that it meets the Argument structure. @@ -537,6 +565,8 @@ def normalize( Whether to transform alias names. trim_pattern : str, optional If given, discard keys that matches the glob pattern. + allow_ref : bool, optional + If true, allow loading from external files via the ``$ref`` key. Returns ------- @@ -550,10 +580,15 @@ def normalize( argdict, key_hook=Argument._convert_alias, variant_hook=Variant._convert_choice_alias, + allow_ref=allow_ref, ) if do_default: - self.traverse(argdict, key_hook=Argument._assign_default) - self.traverse(argdict, key_hook=Argument._handle_empty_dict) + self.traverse( + argdict, key_hook=Argument._assign_default, allow_ref=allow_ref + ) + self.traverse( + argdict, key_hook=Argument._handle_empty_dict, allow_ref=allow_ref + ) if trim_pattern is not None: trim_by_pattern(argdict, trim_pattern, reserved=[self.name]) self.traverse( @@ -561,6 +596,7 @@ def normalize( sub_hook=lambda a, d, p: trim_by_pattern( d, trim_pattern, a.flatten_sub(d, p).keys() ), + allow_ref=allow_ref, ) return argdict @@ -571,6 +607,7 @@ def normalize_value( do_default: bool = True, do_alias: bool = True, trim_pattern: str | None = None, + allow_ref: bool = False, ) -> Any: """Modify the value so that it meets the Argument structure. @@ -588,6 +625,8 @@ def normalize_value( Whether to transform alias names. trim_pattern : str, optional If given, discard keys that matches the glob pattern. + allow_ref : bool, optional + If true, allow loading from external files via the ``$ref`` key. 
Returns -------
@@ -601,16 +640,22 @@ def normalize_value(
                 value,
                 key_hook=Argument._convert_alias,
                 variant_hook=Variant._convert_choice_alias,
+                allow_ref=allow_ref,
             )
         if do_default:
-            self.traverse_value(value, key_hook=Argument._assign_default)
-            self.traverse_value(value, key_hook=Argument._handle_empty_dict)
+            self.traverse_value(
+                value, key_hook=Argument._assign_default, allow_ref=allow_ref
+            )
+            self.traverse_value(
+                value, key_hook=Argument._handle_empty_dict, allow_ref=allow_ref
+            )
         if trim_pattern is not None:
             self.traverse_value(
                 value,
                 sub_hook=lambda a, d, p: trim_by_pattern(
                     d, trim_pattern, a.flatten_sub(d, p).keys()
                 ),
+                allow_ref=allow_ref,
             )
         return value

@@ -1065,6 +1110,114 @@ def trim_by_pattern(
             argdict.pop(key)


+def _load_ref(ref_path: str) -> dict:
+    """Load a dict from an external file referenced by ``$ref``.
+
+    Parameters
+    ----------
+    ref_path : str
+        Path to the external file. Supported extensions: ``.json``, ``.yml``, ``.yaml``.
+        The path is opened as given, so a relative path is resolved against
+        the current working directory.
+
+    Returns
+    -------
+    dict
+        The loaded dict from the external file.
+
+    Raises
+    ------
+    ValueError
+        If the file extension is not supported, or if the file does not contain a
+        top-level mapping/object.
+    ImportError
+        If pyyaml is not installed and a YAML file is requested.
+    OSError
+        If the file cannot be opened.
+    """
+    ext = os.path.splitext(ref_path)[1].lower()
+    if ext == ".json":
+        with open(ref_path, encoding="utf-8") as f:
+            loaded = json.load(f)
+    elif ext in (".yml", ".yaml"):
+        try:
+            import yaml
+        except ImportError as e:
+            raise ImportError(
+                "pyyaml is required to load YAML files referenced by $ref. "
+                "Install it with: pip install pyyaml"
+            ) from e
+        with open(ref_path, encoding="utf-8") as f:
+            loaded = yaml.safe_load(f)
+    else:
+        raise ValueError(
+            f"Unsupported file extension `{ext}` for $ref. "
+            "Supported extensions are: .json, .yml, .yaml"
+        )
+    if not isinstance(loaded, dict):
+        raise ValueError(
+            f"Referenced file {ref_path!r} must contain a mapping/object at the top "
+            f"level, but got {type(loaded).__name__!r}."
+        )
+    return loaded
+
+
+def _resolve_ref(d: dict, allow_ref: bool = False) -> None:
+    """Resolve the ``$ref`` key in a dict by loading from an external file.
+
+    If ``$ref`` is present in ``d``, its value is treated as a file path.
+    The file is loaded and its contents are merged into ``d``. Keys already
+    present in ``d`` (other than ``$ref``) take precedence over keys from the
+    loaded file, allowing local overrides. Chained ``$ref`` values in the
+    loaded content are resolved in turn. Cyclic references are detected and
+    raise a ``ValueError``.
+
+    The dict is modified **in place**. ``d`` is only rewritten after the
+    referenced file has been loaded successfully, so a failing first
+    ``$ref`` leaves ``d`` untouched.
+
+    Parameters
+    ----------
+    d : dict
+        The dict that may contain a ``$ref`` key.
+    allow_ref : bool, optional
+        If False (the default), raise a ``ValueError`` when ``$ref`` is found.
+        Set to True to enable loading from external files.
+
+    Raises
+    ------
+    ValueError
+        If ``$ref`` is found but ``allow_ref`` is False, if the ``$ref`` value
+        is not a string, or if a cyclic reference is detected.
+    """
+    if "$ref" not in d:
+        return
+    if not allow_ref:
+        raise ValueError(
+            "$ref is not allowed by default. "
+            "Pass allow_ref=True to enable loading from external files."
+        )
+    visited_refs: set[str] = set()
+    while "$ref" in d:
+        ref_path = d["$ref"]
+        # The value comes straight from user data: reject non-strings here
+        # rather than failing below with an opaque TypeError.
+        if not isinstance(ref_path, str):
+            raise ValueError(
+                "$ref must be a string path to an external file, "
+                f"but got {type(ref_path).__name__!r}."
+            )
+        if ref_path in visited_refs:
+            raise ValueError(f"Cyclic $ref detected for path: {ref_path!r}")
+        visited_refs.add(ref_path)
+        loaded = _load_ref(ref_path)
+        # Merge: loaded content as base, local keys take precedence
+        local = {k: v for k, v in d.items() if k != "$ref"}
+        merged = {**loaded, **local}
+        d.clear()
+        d.update(merged)
+
+
 def isinstance_annotation(value: Any, dtype: type | Any) -> bool:
     """Same as isinstance(), but supports arbitrary type annotations."""
     try:
diff --git a/dargs/notebook.py b/dargs/notebook.py
index 80cb4f5..c705dbf 100644
--- a/dargs/notebook.py
+++ b/dargs/notebook.py
@@ -27,6 +27,7 @@
 from IPython.display import HTML, display

 from dargs import Argument, Variant
+from dargs.dargs import _resolve_ref

 __all__ = ["JSON"]

@@ -90,7 +91,9 @@
 """


-def JSON(data: dict | str, arg: Argument | list[Argument]) -> None:
+def JSON(
+    data: dict | str, arg: Argument | list[Argument], allow_ref: bool = False
+) -> None:
     """Display JSON data with Argument in the Jupyter Notebook.

     Parameters
@@ -99,11 +102,15 @@ def JSON(data: dict | str, arg: Argument | list[Argument]) -> None:
         The JSON data to be displayed, either JSON string or a dict.
     arg : dargs.Argument or list[dargs.Argument]
         The Argument that describes the JSON data.
+    allow_ref : bool, optional
+        If true, allow loading from external files via the ``$ref`` key.
     """
-    display(HTML(print_html(data, arg)))
+    display(HTML(print_html(data, arg, allow_ref=allow_ref)))


-def print_html(data: Any, arg: Argument | list[Argument]) -> str:
+def print_html(
+    data: Any, arg: Argument | list[Argument], allow_ref: bool = False
+) -> str:
     """Print HTML string with Argument in the Jupyter Notebook.

     Parameters
@@ -112,6 +119,8 @@ def print_html(data: Any, arg: Argument | list[Argument]) -> str:
         The JSON data to be displayed, either JSON string or a dict.
     arg : dargs.Argument or list[dargs.Argument]
         The Argument that describes the JSON data.
+ allow_ref : bool, optional + If true, allow loading from external files via the ``$ref`` key. Returns ------- @@ -131,7 +140,7 @@ def print_html(data: Any, arg: Argument | list[Argument]) -> str: pass else: raise ValueError(f"Unknown type: {type(arg)}") - argdata = ArgumentData(data, arg) + argdata = ArgumentData(data, arg, allow_ref=allow_ref) buff = [css, r"""
""", argdata.print_html(), r"
"] return "".join(buff) @@ -149,14 +158,21 @@ class ArgumentData: The Argument that describes the data. repeat : bool, optional The argument is repeat + allow_ref : bool, optional + If true, allow loading from external files via the ``$ref`` key. """ def __init__( - self, data: dict, arg: Argument | Variant, repeat: bool = False + self, + data: dict, + arg: Argument | Variant, + repeat: bool = False, + allow_ref: bool = False, ) -> None: self.data = data self.arg = arg self.repeat = repeat + self.allow_ref = allow_ref self.subdata = [] self._init_subdata() @@ -167,22 +183,31 @@ def _init_subdata(self) -> None: and isinstance(self.arg, Argument) and not (self.arg.repeat and not self.repeat) ): + # Work on a copy to avoid mutating the caller's data + data = self.data.copy() + _resolve_ref(data, self.allow_ref) sub_fields = self.arg.sub_fields.copy() # extend subfiles with sub_variants for vv in self.arg.sub_variants.values(): - choice = self.data.get(vv.flag_name, vv.default_tag) + choice = data.get(vv.flag_name, vv.default_tag) if choice and choice in vv.choice_dict: sub_fields.update(vv.choice_dict[choice].sub_fields) - for kk in self.data: + for kk in data: if kk in sub_fields: - self.subdata.append(ArgumentData(self.data[kk], sub_fields[kk])) + self.subdata.append( + ArgumentData(data[kk], sub_fields[kk], allow_ref=self.allow_ref) + ) elif kk in self.arg.sub_variants: self.subdata.append( - ArgumentData(self.data[kk], self.arg.sub_variants[kk]) + ArgumentData( + data[kk], + self.arg.sub_variants[kk], + allow_ref=self.allow_ref, + ) ) else: - self.subdata.append(ArgumentData(self.data[kk], kk)) + self.subdata.append(ArgumentData(data[kk], kk)) elif ( isinstance(self.data, list) and isinstance(self.arg, Argument) @@ -190,7 +215,9 @@ def _init_subdata(self) -> None: and not self.repeat ): for dd in self.data: - self.subdata.append(ArgumentData(dd, self.arg, repeat=True)) + self.subdata.append( + ArgumentData(dd, self.arg, repeat=True, allow_ref=self.allow_ref) + ) 
elif ( isinstance(self.data, dict) and isinstance(self.arg, Argument) @@ -198,7 +225,9 @@ def _init_subdata(self) -> None: and not self.repeat ): for dd in self.data.values(): - self.subdata.append(ArgumentData(dd, self.arg, repeat=True)) + self.subdata.append( + ArgumentData(dd, self.arg, repeat=True, allow_ref=self.allow_ref) + ) def print_html(self, _level: int = 0, _last_one: bool = True) -> str: """Print the data with Argument in HTML format. diff --git a/docs/index.rst b/docs/index.rst index cdf7bd9..8158bdf 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -16,6 +16,7 @@ Welcome to dargs's documentation! dpgui nb json_schema + ref api/api credits diff --git a/docs/ref.md b/docs/ref.md new file mode 100644 index 0000000..8cccfc6 --- /dev/null +++ b/docs/ref.md @@ -0,0 +1,37 @@ +## Loading from external files with `$ref` + +Any dict that is processed by `check`, `check_value`, `normalize`, or `normalize_value` +may include a `"$ref"` key whose value is a path to an external file. +Before validation or normalization, dargs will load that file and merge its +contents into the dict, with any keys already present in the dict taking +precedence (local overrides). + +Loading from external files is **disabled by default** for security. +Pass `allow_ref=True` to the relevant method to enable this feature: + +```python +argument.check(data, allow_ref=True) +argument.normalize(data, allow_ref=True) +argument.check_value(value, allow_ref=True) +argument.normalize_value(value, allow_ref=True) +``` + +Supported file formats: + +- **JSON** (`.json`) — no extra dependencies required. +- **YAML** (`.yml` / `.yaml`) — requires [pyyaml](https://pypi.org/project/pyyaml/). + Install it with `pip install pyyaml` or `pip install dargs[yaml]`. 
+
+Example — split a large config into reusable pieces:
+
+```json
+{
+  "model": {
+    "$ref": "model_defaults.json",
+    "hidden_size": 256
+  }
+}
+```
+
+The contents of `model_defaults.json` are loaded first, then `"hidden_size": 256`
+overrides (or adds to) the loaded values before the dict is validated or normalized.
diff --git a/pyproject.toml b/pyproject.toml
index f1d92e6..d63fc50 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,12 +31,16 @@ repository = "https://github.com/deepmodeling/dargs"
 test = [
     "ipython",
     "jsonschema",
+    "pyyaml",
 ]
 typecheck = [
     "ty==0.0.17",
     "sphinx",
     "ipython",
 ]
+yaml = [
+    "pyyaml",
+]

 [project.scripts]
 dargs = "dargs.cli:main"
diff --git a/tests/test_ref.py b/tests/test_ref.py
new file mode 100644
index 0000000..7c04f0c
--- /dev/null
+++ b/tests/test_ref.py
@@ -0,0 +1,230 @@
+"""Tests for $ref loading from external JSON/YAML files."""
+
+from __future__ import annotations
+
+import importlib.util
+import json
+import os
+import tempfile
+import unittest
+
+from dargs import Argument
+
+
+class TestRef(unittest.TestCase):
+    def setUp(self) -> None:
+        # addCleanup removes the directory even if a test fails, so no
+        # hand-written tearDown is needed.
+        tmpdir = tempfile.TemporaryDirectory()
+        self.addCleanup(tmpdir.cleanup)
+        self._tmpdir = tmpdir.name
+
+    def _tmpfile(self, name: str) -> str:
+        return os.path.join(self._tmpdir, name)
+
+    def _write_json(self, name: str, data: object) -> str:
+        path = self._tmpfile(name)
+        with open(path, "w", encoding="utf-8") as f:
+            json.dump(data, f)
+        return path
+
+    def _write_yaml(self, name: str, text: str) -> str:
+        path = self._tmpfile(name)
+        with open(path, "w", encoding="utf-8") as f:
+            f.write(text)
+        return path
+
+    def test_ref_not_allowed_by_default(self) -> None:
+        """$ref raises ValueError when allow_ref is not set (secure by default)."""
+        ref_path = self._write_json("ref_default.json", {"sub1": 1})
+        ca = Argument("base", dict, [Argument("sub1", int)])
+        with self.assertRaises(ValueError):
+            ca.check({"base": {"$ref": ref_path}})
+        with self.assertRaises(ValueError):
+            ca.normalize({"base": {"$ref": ref_path}})
+        with self.assertRaises(ValueError):
+            ca.check_value({"$ref": ref_path})
+        with self.assertRaises(ValueError):
+            ca.normalize_value({"$ref": ref_path})
+
+    def test_ref_json_check(self) -> None:
+        """$ref to a JSON file is resolved before check."""
+        ref_path = self._write_json("ref_test.json", {"sub1": 1, "sub2": "hello"})
+        ca = Argument(
+            "base",
+            dict,
+            [
+                Argument("sub1", int),
+                Argument("sub2", str),
+            ],
+        )
+        ca.check({"base": {"$ref": ref_path}}, allow_ref=True)
+
+    def test_ref_json_normalize(self) -> None:
+        """$ref to a JSON file is resolved before normalize."""
+        ref_path = self._write_json("ref_norm.json", {"sub1": 1})
+        ca = Argument(
+            "base",
+            dict,
+            [
+                Argument("sub1", int),
+                Argument("sub2", str, optional=True, default="default"),
+            ],
+        )
+        result = ca.normalize({"base": {"$ref": ref_path}}, allow_ref=True)
+        self.assertEqual(result["base"]["sub1"], 1)
+        self.assertEqual(result["base"]["sub2"], "default")
+
+    def test_ref_local_override(self) -> None:
+        """Keys in the dict alongside $ref override keys from the loaded file."""
+        ref_path = self._write_json(
+            "ref_override.json", {"sub1": 1, "sub2": "from_file"}
+        )
+        ca = Argument(
+            "base",
+            dict,
+            [
+                Argument("sub1", int),
+                Argument("sub2", str),
+            ],
+        )
+        result = ca.normalize(
+            {"base": {"$ref": ref_path, "sub2": "local"}}, allow_ref=True
+        )
+        self.assertEqual(result["base"]["sub1"], 1)
+        self.assertEqual(result["base"]["sub2"], "local")
+
+    def test_ref_yaml(self) -> None:
+        """$ref to a YAML file is resolved when pyyaml is installed."""
+        if importlib.util.find_spec("yaml") is None:
+            self.skipTest("pyyaml not installed")
+        ref_path = self._write_yaml("ref_test.yaml", "sub1: 42\nsub2: yaml_val\n")
+        ca = Argument(
+            "base",
+            dict,
+            [
+                Argument("sub1", int),
+                Argument("sub2", str),
+            ],
+        )
+        ca.check({"base": {"$ref": ref_path}}, allow_ref=True)
+
+    def test_ref_yml_extension(self) -> None:
+        """$ref works with .yml extension as well."""
+        if importlib.util.find_spec("yaml") is None:
+            self.skipTest("pyyaml not installed")
+        ref_path = self._write_yaml("ref_test.yml", "sub1: 7\nsub2: yml_val\n")
+        ca = Argument(
+            "base",
+            dict,
+            [
+                Argument("sub1", int),
+                Argument("sub2", str),
+            ],
+        )
+        ca.check({"base": {"$ref": ref_path}}, allow_ref=True)
+
+    def test_ref_unsupported_extension(self) -> None:
+        """$ref with unsupported extension raises ValueError."""
+        ref_path = self._tmpfile("ref_test.toml")
+        with open(ref_path, "w", encoding="utf-8") as f:
+            f.write("sub1 = 1\n")
+        ca = Argument("base", dict, [Argument("sub1", int)])
+        with self.assertRaises(ValueError):
+            ca.check({"base": {"$ref": ref_path}}, allow_ref=True)
+
+    def test_ref_check_value(self) -> None:
+        """$ref is resolved when using check_value."""
+        ref_path = self._write_json("ref_val.json", {"sub1": 5, "sub2": "v"})
+        ca = Argument(
+            "base",
+            dict,
+            [
+                Argument("sub1", int),
+                Argument("sub2", str),
+            ],
+        )
+        ca.check_value({"$ref": ref_path}, allow_ref=True)
+
+    def test_ref_normalize_value(self) -> None:
+        """$ref is resolved when using normalize_value."""
+        ref_path = self._write_json("ref_normval.json", {"sub1": 99})
+        ca = Argument(
+            "base",
+            dict,
+            [
+                Argument("sub1", int),
+                Argument("sub2", str, optional=True, default="d"),
+            ],
+        )
+        result = ca.normalize_value({"$ref": ref_path}, allow_ref=True)
+        self.assertEqual(result["sub1"], 99)
+        self.assertEqual(result["sub2"], "d")
+
+    def test_ref_check_no_mutation(self) -> None:
+        """check() with allow_ref=True does not mutate the caller's data."""
+        ref_path = self._write_json("ref_nomut.json", {"sub1": 1, "sub2": "v"})
+        ca = Argument(
+            "base",
+            dict,
+            [
+                Argument("sub1", int),
+                Argument("sub2", str),
+            ],
+        )
+        original = {"base": {"$ref": ref_path}}
+        ca.check(original, allow_ref=True)
+        # The caller's dict must be left exactly as it was passed in
+        self.assertEqual(original, {"base": {"$ref": ref_path}})
+
+    def test_ref_check_value_no_mutation(self) -> None:
+        """check_value() with allow_ref=True does not mutate the caller's data."""
+        ref_path = self._write_json("ref_nomut_val.json", {"sub1": 1, "sub2": "v"})
+        ca = Argument(
+            "base",
+            dict,
+            [
+                Argument("sub1", int),
+                Argument("sub2", str),
+            ],
+        )
+        original = {"$ref": ref_path}
+        ca.check_value(original, allow_ref=True)
+        self.assertEqual(original, {"$ref": ref_path})
+
+    def test_ref_non_dict_content(self) -> None:
+        """$ref pointing to a non-dict file raises ValueError."""
+        ref_path = self._write_json("ref_list.json", [1, 2, 3])
+        ca = Argument("base", dict, [Argument("sub1", int)])
+        with self.assertRaises(ValueError):
+            ca.check({"base": {"$ref": ref_path}}, allow_ref=True)
+
+    def test_ref_cyclic_detection(self) -> None:
+        """Cyclic $ref raises ValueError."""
+        # Write a file that points back to itself
+        ref_path = self._tmpfile("ref_cyclic.json")
+        with open(ref_path, "w", encoding="utf-8") as f:
+            json.dump({"$ref": ref_path}, f)
+        ca = Argument("base", dict, [Argument("sub1", int, optional=True)])
+        with self.assertRaisesRegex(ValueError, r"Cyclic \$ref"):
+            ca.check({"base": {"$ref": ref_path}}, allow_ref=True)
+
+    def test_ref_chained(self) -> None:
+        """A $ref that loads a file containing another $ref is fully resolved."""
+        inner_path = self._write_json("ref_inner.json", {"sub1": 7, "sub2": "inner"})
+        outer_path = self._write_json("ref_outer.json", {"$ref": inner_path})
+        ca = Argument(
+            "base",
+            dict,
+            [
+                Argument("sub1", int),
+                Argument("sub2", str),
+            ],
+        )
+        result = ca.normalize({"base": {"$ref": outer_path}}, allow_ref=True)
+        self.assertEqual(result["base"]["sub1"], 7)
+        self.assertEqual(result["base"]["sub2"], "inner")
+
+
+if __name__ == "__main__":
+    unittest.main()