From 85b37805e816b89b29cdfd85684a14aadfd52cdf Mon Sep 17 00:00:00 2001 From: Ayush Agrawal Date: Sat, 4 Apr 2026 20:07:27 -0700 Subject: [PATCH] chore: Handle recursive JSON schema references in type conversion. Fixes #2181 PiperOrigin-RevId: 894740183 --- google/genai/_transformers.py | 27 +++++++++++++-- .../genai/tests/transformers/test_schema.py | 33 +++++++++++++++++++ google/genai/tests/types/test_types.py | 24 ++++++++++++++ google/genai/types.py | 19 ++++++++--- 4 files changed, 97 insertions(+), 6 deletions(-) diff --git a/google/genai/_transformers.py b/google/genai/_transformers.py index 0e1a9c41c..b9867d3e6 100644 --- a/google/genai/_transformers.py +++ b/google/genai/_transformers.py @@ -665,6 +665,7 @@ def process_schema( defs: Optional[_common.StringDict] = None, *, order_properties: bool = True, + visited_dicts_path: Optional[set[int]] = None, ) -> None: """Updates the schema and each sub-schema inplace to be API-compatible. @@ -726,6 +727,13 @@ def process_schema( 'type': 'array' } """ + if visited_dicts_path is None: + visited_dicts_path = set() + + if id(schema) in visited_dicts_path: + return + visited_dicts_path.add(id(schema)) + if schema.get('title') == 'PlaceholderLiteralEnum': del schema['title'] @@ -750,7 +758,11 @@ def process_schema( # directly referencing another '$ref': # https://json-schema.org/understanding-json-schema/structuring#recursion process_schema( - sub_schema, client, defs, order_properties=order_properties + sub_schema, + client, + defs, + order_properties=order_properties, + visited_dicts_path=visited_dicts_path, ) handle_null_fields(schema) @@ -765,11 +777,21 @@ def _recurse(sub_schema: _common.StringDict) -> _common.StringDict: """Returns the processed `sub_schema`, resolving its '$ref' if any.""" if (ref := sub_schema.pop('$ref', None)) is not None: sub_schema = defs[ref.split('defs/')[-1]] - process_schema(sub_schema, client, defs, order_properties=order_properties) + if id(sub_schema) in visited_dicts_path: + return {} + + process_schema( + sub_schema, + client, + defs, + order_properties=order_properties, + visited_dicts_path=visited_dicts_path, + ) return sub_schema if (any_of := schema.get('anyOf')) is not None: schema['anyOf'] = [_recurse(sub_schema) for sub_schema in any_of] + visited_dicts_path.remove(id(schema)) return schema_type = schema.get('type') @@ -809,6 +831,7 @@ def _recurse(sub_schema: _common.StringDict) -> _common.StringDict: if (prefixes := schema.get('prefixItems')) is not None: schema['prefixItems'] = [_recurse(prefix) for prefix in prefixes] + visited_dicts_path.remove(id(schema)) def _process_enum( enum: EnumMeta, client: Optional[_api_client.BaseApiClient] diff --git a/google/genai/tests/transformers/test_schema.py b/google/genai/tests/transformers/test_schema.py index 69754362b..8655a01cc 100644 --- a/google/genai/tests/transformers/test_schema.py +++ b/google/genai/tests/transformers/test_schema.py @@ -607,6 +607,39 @@ def test_process_schema_order_properties_propagates_into_any_of( assert schema == schema_without_property_ordering +@pytest.mark.parametrize('use_vertex', [True, False]) +def test_process_schema_with_cycle(client): + schema = { + 'type': 'OBJECT', + 'properties': { + 'recursive': {'$ref': '#/$defs/RecursiveObject'}, + }, + '$defs': { + 'RecursiveObject': { + 'type': 'OBJECT', + 'properties': { + 'self': {'$ref': '#/$defs/RecursiveObject'}, + } + } + } + } + + _transformers.process_schema(schema, client) + + expected = { + 'type': 'OBJECT', + 'properties': { + 'recursive': { + 'type': 'OBJECT', + 'properties': { + 'self': {} + } + } + } + } + assert schema == expected + + @pytest.mark.parametrize('use_vertex', [True, False]) def test_t_schema_does_not_change_property_ordering_if_set(client): """Tests t_schema doesn't overwrite the property_ordering field if already set.""" diff --git a/google/genai/tests/types/test_types.py b/google/genai/tests/types/test_types.py index 531a5c5ad..099ff2a2d 100644 --- a/google/genai/tests/types/test_types.py +++ b/google/genai/tests/types/test_types.py @@ -2687,6 +2687,30 @@ def func_under_test(a: int) -> str: assert actual_schema_vertex == expected_schema_vertex +def test_convert_json_schema_with_cycle(): + json_schema_dict = { + 'type': 'object', + 'properties': { + 'foo': {'$ref': '#/$defs/Foo'} + }, + '$defs': { + 'Foo': { + 'type': 'object', + 'properties': { + 'foo': {'$ref': '#/$defs/Foo'} + } + } + } + } + + json_schema = types.JSONSchema(**json_schema_dict) + schema = types.Schema.from_json_schema(json_schema=json_schema) + + assert schema.type == types.Type.OBJECT + assert schema.properties['foo'].type == types.Type.OBJECT + assert schema.properties['foo'].properties['foo'] == types.Schema() + + def test_case_insensitive_enum(): assert types.Type('STRING') == types.Type.STRING assert types.Type('string') == types.Type.STRING diff --git a/google/genai/types.py b/google/genai/types.py index 4bc37eb98..4d0976c45 100644 --- a/google/genai/types.py +++ b/google/genai/types.py @@ -2911,14 +2911,20 @@ def convert_json_schema( root_json_schema_dict: dict[str, Any], api_option: Literal['VERTEX_AI', 'GEMINI_API'], raise_error_on_unsupported_field: bool, + visited_refs: Optional[set[str]] = None, ) -> 'Schema': + if visited_refs is None: + visited_refs = set() + schema = Schema() json_schema_dict = current_json_schema.model_dump() - if json_schema_dict.get('ref'): - json_schema_dict = _resolve_ref( - json_schema_dict['ref'], root_json_schema_dict - ) + ref = json_schema_dict.get('ref') + if ref: + if ref in visited_refs: + return Schema() + visited_refs.add(ref) + json_schema_dict = _resolve_ref(ref, root_json_schema_dict) raise_error_if_cannot_convert( json_schema_dict=json_schema_dict, @@ -2985,6 +2991,7 @@ def convert_json_schema( root_json_schema_dict=root_json_schema_dict, api_option=api_option, raise_error_on_unsupported_field=raise_error_on_unsupported_field, + visited_refs=visited_refs, ) setattr(schema, field_name, schema_field_value) elif field_name in list_schema_field_names: @@ -2994,6 +3001,7 @@ def convert_json_schema( root_json_schema_dict=root_json_schema_dict, api_option=api_option, raise_error_on_unsupported_field=raise_error_on_unsupported_field, + visited_refs=visited_refs, ) for this_field_value in field_value ] @@ -3007,6 +3015,7 @@ def convert_json_schema( root_json_schema_dict=root_json_schema_dict, api_option=api_option, raise_error_on_unsupported_field=raise_error_on_unsupported_field, + visited_refs=visited_refs, ) for key, value in field_value.items() } @@ -3051,6 +3060,8 @@ def convert_json_schema( if default_value is not None: schema.default = default_value + if ref: + visited_refs.remove(ref) return schema # This is the initial call to the recursive function.