fix: resolve runtime crashes and mypy errors with transformers >= 4.57 + Pydantic v2 #369
rahul-tuli wants to merge 3 commits into main from
Conversation
force-pushed from eb10ce2 to fdf5ddf
…ibility
transformers v5 (available since 4.57.x as py.typed with inline stubs) introduced
three incompatibilities with our Pydantic + PretrainedConfig multiple-inheritance
pattern, causing test collection failures and runtime crashes.
## Problem 1 — PydanticUndefinedAnnotation at import time
Symptom: all tests fail at collection with:
pydantic.errors.PydanticUndefinedAnnotation: name 'torch' is not defined
Root cause: reload_schemas() calls model_rebuild(force=True) on SpeculatorModelConfig
and subclasses. transformers v5's PretrainedConfig.dtype field uses the forward
reference "torch.dtype". Pydantic evaluates this at rebuild time, but torch is not
in the evaluation namespace, so the rebuild fails.
Fix: SpeculatorModelConfig.reload_schema() overrides the base implementation to pass
_types_namespace={"torch": torch} to model_rebuild(), and rebuilds each registered
subclass (Eagle3SpeculatorConfig, EagleSpeculatorConfig, etc.) individually since
model_rebuild() on a parent does not propagate to subclasses.
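The namespace mechanism behind this fix can be sketched with only the standard library: `typing.get_type_hints` fails on an unresolved forward reference the same way Pydantic's `model_rebuild()` does, and succeeds once the missing name is supplied in the evaluation namespace. Everything below (`FakeDtype`, `faketorch`, `Config`, `resolve`) is an illustrative stand-in, not the real torch, Pydantic, or transformers code.

```python
from types import SimpleNamespace
from typing import get_type_hints

class FakeDtype:  # stand-in for torch.dtype; torch is deliberately not imported
    pass

class Config:
    # Mirrors transformers v5's `dtype: "torch.dtype"` string annotation:
    # the referenced module is not resolvable at evaluation time.
    dtype: "faketorch.FakeDtype"

def resolve(namespace=None):
    """Return resolved type hints, or None if the forward reference fails."""
    try:
        return get_type_hints(Config, globalns=namespace)
    except NameError:  # Pydantic surfaces this as PydanticUndefinedAnnotation
        return None

# Without a namespace the reference cannot be evaluated (the import-time crash).
assert resolve() is None

# Supplying the missing name: the analogue of _types_namespace={"torch": torch}.
hints = resolve({"faketorch": SimpleNamespace(FakeDtype=FakeDtype)})
assert hints["dtype"] is FakeDtype
```

The same reason the fix must rebuild each registered subclass individually applies here: resolution happens per class, so supplying the namespace to the parent alone would not help classes whose annotations are evaluated separately.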
## Problem 2 — AttributeError on SpeculatorModelConfig subclass construction
Symptom: constructing Eagle3SpeculatorConfig (or any subclass that doesn't define
its own __init__) raises:
AttributeError: 'Eagle3SpeculatorConfig' object has no attribute '__pydantic_fields_set__'
Root cause: transformers v5's PretrainedConfig.__init_subclass__ applies @dataclass(repr=False)
and wrap_init_to_accept_kwargs to every subclass that lacks __init__ in cls.__dict__.
This replaces the inherited SpeculatorModelConfig.__init__ with a dataclass-generated
wrapper that calls setattr() for every field before Pydantic can initialize
__pydantic_fields_set__, triggering Pydantic's __setattr__ too early.
Fix: SpeculatorModelConfig.__init_subclass__ injects __init__ into each subclass's
__dict__ before super().__init_subclass__() runs. PretrainedConfig.__init_subclass__
checks "__init__" in cls.__dict__ before wrapping, so the injection prevents the
dataclass wrapper from running. The @dataclass(repr=False) decorator also skips __init__
generation when the class already defines one.
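The injection trick can be demonstrated with plain classes, no transformers or Pydantic required. `WrappingBase` and `SafeConfig` below are simplified stand-ins for `PretrainedConfig` and `SpeculatorModelConfig`, assuming only the documented `"__init__" in cls.__dict__` guard:

```python
class WrappingBase:
    """Stand-in for transformers v5's PretrainedConfig: its __init_subclass__
    replaces __init__ on any subclass that doesn't define its own."""
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        if "__init__" not in cls.__dict__:       # the guard the fix relies on
            def generated_init(self, **kw):      # dataclass-style wrapper
                self.wrapped = True
            cls.__init__ = generated_init

class SafeConfig(WrappingBase):
    """Stand-in for SpeculatorModelConfig: injects the inherited __init__
    into each subclass's __dict__ BEFORE the base hook checks for it."""
    def __init__(self, **kw):
        self.wrapped = False                     # the "real" (Pydantic) init
    def __init_subclass__(cls, **kwargs):
        if "__init__" not in cls.__dict__:
            cls.__init__ = SafeConfig.__init__   # the injection
        super().__init_subclass__(**kwargs)      # base hook now skips wrapping

class Unprotected(WrappingBase):  # gets the generated wrapper
    pass

class Protected(SafeConfig):      # keeps the real __init__
    pass

assert Unprotected().wrapped is True
assert Protected().wrapped is False
```

Order matters: the injection must happen before `super().__init_subclass__()` runs, because that is when the base class inspects `cls.__dict__`.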
## Problem 3 — TypeError in save_pretrained via self.validate()
Symptom: save_pretrained() raises:
TypeError: BaseModel.validate() missing 1 required positional argument: 'value'
Root cause: transformers v5's @strict decorator adds a validate() instance method to
PretrainedConfig to run class validators (validate_architecture, validate_token_ids,
etc.). Pydantic's BaseModel.validate() is a classmethod (def validate(cls, value))
that comes earlier in our MRO, shadowing PretrainedConfig.validate(). When
save_pretrained() calls self.validate(), it hits BaseModel.validate() which requires
a value argument.
Fix: SpeculatorModelConfig.validate() explicitly delegates to PretrainedConfig.validate()
so the @strict validators run correctly.
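The MRO shadowing and the delegation fix reduce to a small stand-alone sketch. `FakeBaseModel` and `FakePretrainedConfig` are stand-ins approximating the two conflicting `validate()` signatures described above:

```python
class FakeBaseModel:                # stand-in for pydantic.BaseModel
    @classmethod
    def validate(cls, value):       # classmethod requiring an argument
        return value

class FakePretrainedConfig:         # stand-in for transformers v5's config
    def validate(self):             # instance method added by @strict
        return "strict validators ran"

class Broken(FakeBaseModel, FakePretrainedConfig):
    pass                            # BaseModel.validate wins in the MRO

class Fixed(FakeBaseModel, FakePretrainedConfig):
    def validate(self):             # explicit delegation, guarded so it is a
        if hasattr(FakePretrainedConfig, "validate"):  # no-op pre-5.x
            return FakePretrainedConfig.validate(self)
        return None

try:
    Broken().validate()             # resolves to the classmethod -> TypeError
    shadowed = False
except TypeError:                   # "missing 1 required positional argument"
    shadowed = True

assert shadowed
assert Fixed().validate() == "strict validators ran"
```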
## mypy fixes (Python 3.13 + mypy 1.15.0)
- config.py: base_model_pp_plan type widened from tuple[list[str]] to Sequence[list[str]]
to match the updated PretrainedConfig class variable declaration in transformers v5
- model_definitions.py: Eagle3FirstLayerMixin.forward() drops cache_position parameter
which was removed from LlamaDecoderLayer.forward() / Qwen3DecoderLayer.forward() in v5;
keeping it caused [misc] incompatible-override errors on both base classes
- eagle_converter.py, eagle3_converter.py, test_eagle_config.py, test_eagle_model.py,
test_setup_model.py: LlamaConfig() calls flagged as [call-arg] because transformers v5's
@strict decorator wraps LlamaConfig.__init__ via @wraps, which makes mypy follow
__wrapped__ back to PretrainedConfig.__init__ (losing all LlamaConfig-specific fields).
Fix: use llama_kwargs: dict[str, Any] = {...}; LlamaConfig(**llama_kwargs) — mypy cannot
check specific key names when unpacking dict[str, Any], so [call-arg] is bypassed without
any type: ignore suppressions.
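A minimal illustration of the workaround pattern. `FakeConfig` is a hypothetical stand-in accepting `**kwargs` the way mypy effectively sees the wrapped `LlamaConfig.__init__`, and the field names and values are illustrative:

```python
from typing import Any

class FakeConfig:
    """Stand-in for LlamaConfig whose __init__ signature, as mypy sees it,
    only exposes **kwargs (the effect of following __wrapped__)."""
    def __init__(self, **kwargs: Any) -> None:
        self.vocab_size = kwargs.get("vocab_size", 32000)
        self.hidden_size = kwargs.get("hidden_size", 4096)

# Direct keyword calls would have mypy check each name against the (wrong)
# wrapped signature. Packing into dict[str, Any] and unpacking sidesteps
# per-key checking without any `type: ignore`.
llama_kwargs: dict[str, Any] = {"vocab_size": 128, "hidden_size": 64}
config = FakeConfig(**llama_kwargs)

assert config.vocab_size == 128
assert config.hidden_size == 64
```

The trade-off, flagged in review below, is that this also stops the type checker from catching genuinely misspelled keys.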
pyproject.toml: add local/ and output/ to ruff's exclude list so local experiment
artifacts don't trigger lint errors during development.
Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
force-pushed from fdf5ddf to 3d6fcc5
…egration test
transformers v5 added strict validation for rope_type=llama3 configs, requiring
rope_theta inside rope_scaling. RedHatAI/Llama-3.1-8B-Instruct's Hub config
predates this requirement, so PretrainedConfig.from_pretrained() now raises:
KeyError: Missing required keys in `rope_parameters` for 'rope_type'='llama3': {'rope_theta'}
The test only needs architectures from the config (to feed VerifierConfig.from_config).
Use get_config_dict() to fetch the raw dict without triggering validation, drop
rope_scaling, then construct a minimal PretrainedConfig for the test assertion.
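The shape of this workaround can be sketched without network access or transformers installed. `strict_validate` and the config dict below are simplified stand-ins for the v5 `rope_scaling` validation and the pre-v5 Hub config:

```python
def strict_validate(cfg: dict) -> dict:
    """Stand-in for transformers v5's llama3 rope_scaling validation."""
    rope = cfg.get("rope_scaling")
    if rope is not None and rope.get("rope_type") == "llama3":
        if "rope_theta" not in rope:
            raise KeyError("missing required key for rope_type=llama3: rope_theta")
    return cfg

raw = {  # shape of a pre-v5 Hub config; values are illustrative
    "architectures": ["LlamaForCausalLM"],
    "rope_scaling": {"rope_type": "llama3", "factor": 8.0},
}

try:                       # the from_pretrained() path: validation fails
    strict_validate(raw)
    raised = False
except KeyError:
    raised = True
assert raised

# The fix: fetch the raw dict (get_config_dict does no validation),
# drop rope_scaling, then build a minimal config for the assertion.
raw.pop("rope_scaling", None)
cfg = strict_validate(raw)          # now passes
assert cfg["architectures"] == ["LlamaForCausalLM"]
```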
Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
…rmers < 5.x

transformers 4.57.x does not add validate() to PretrainedConfig (@strict was not
yet present in that release line). Calling PretrainedConfig.validate(self)
unconditionally raises AttributeError on 4.57.x. Add a hasattr() guard so the
delegation only runs when validate() actually exists (transformers >= 5.x).

The __init_subclass__ and reload_schema fixes are already no-ops on 4.57.x:
- 4.57.x does not apply @dataclass + wrap_init_to_accept_kwargs in __init_subclass__
- 4.57.x has no dtype: "torch.dtype" forward reference in PretrainedConfig

Confirmed: Eagle3SpeculatorConfig construction and the
save_pretrained/from_pretrained roundtrip work correctly under both
transformers 4.57.6 and 5.4.0.

Signed-off-by: Rahul-Tuli <rtuli@redhat.com>
@rahul-tuli What else do we need to do before this is ready for review?
fynnsu left a comment
These errors were introduced by transformers 5.4.0, which was released yesterday. This seems to have broken type checking for model configs and caused an issue when PretrainedConfig is used as a pydantic field.

I don't think this PR is the right way to solve these issues. These are problems with transformers itself and shouldn't be solved by:
- Replacing all instantiations of transformers configs with creating a dict that then gets passed in as `**kwargs` to the config class. This is a hacky workaround that tricks the type checker into not validating the types.
- The code in `src/speculators/config.py` which manually injects `torch` into the namespace used to rebuild the config because it wasn't included correctly in the transformers code.

These are issues with the upstream transformers project, and therefore should be solved there, rather than patched here. I've opened issues for both of these problems in the transformers repo. In the meantime, I suggest we cap transformers at <5.4.0.
I want to also note that we will be capping to transformers <v5 for the next release. I think it may make sense to revert the upgrade if the time to fully support v5 is going to surpass our release time, which is about 1-2 weeks out. Or at least cap the version so that we have a green CI, which is required for dflash. Cc @fynnsu
## Summary

Fix three runtime crashes, three static type errors, and one integration test
failure introduced by transformers v5 (`>= 4.57.x` with `py.typed` inline stubs,
`PretrainedConfig` rewritten as `@strict @dataclass`) when used with Pydantic v2
and mypy 1.15.0.

All fixes are backwards-compatible with transformers 4.57.x (verified via
`Eagle3SpeculatorConfig` construction and a `save_pretrained`/`from_pretrained`
roundtrip under both transformers 4.57.6 and 5.4.0).

Closes #370.
## Runtime crashes (affect Python 3.10+ with transformers >= 4.57.x + Pydantic v2)

### 1. `PydanticUndefinedAnnotation: name 'torch' is not defined` (import-time crash)

Error — every test file fails at collection.

Root cause — `reload_schemas()` in `speculators/__init__.py` calls
`model_rebuild(force=True)` on `SpeculatorModelConfig`. transformers v5 added
`dtype: Union[str, "torch.dtype"] | None` to `PretrainedConfig`. Pydantic
evaluates this forward reference during `model_rebuild()`, but `torch` is not
in Pydantic's evaluation namespace.

Fix — `SpeculatorModelConfig.reload_schema()` passes
`_types_namespace={"torch": torch}` to `model_rebuild()` and explicitly
iterates `cls.registry.values()` to rebuild each subclass (a parent rebuild
does not propagate to subclasses).

Backwards compat — transformers 4.57.x has no `dtype: "torch.dtype"`
annotation; passing extra keys to `_types_namespace` is harmless.

### 2. `AttributeError: 'Eagle3SpeculatorConfig' has no attribute '__pydantic_fields_set__'` (construction crash)

Error — constructing any `SpeculatorModelConfig` subclass that does not define
its own `__init__` raises the error above.

Root cause — transformers v5's `PretrainedConfig.__init_subclass__` applies
`@dataclass(repr=False)` + `wrap_init_to_accept_kwargs` to every subclass that
lacks `__init__` in `cls.__dict__`. This replaces the inherited Pydantic
`SpeculatorModelConfig.__init__` with a dataclass-generated wrapper that calls
`setattr(self, f.name, f.default)` for each field before Pydantic has
initialized `__pydantic_fields_set__`, triggering Pydantic's `__setattr__` too
early.

Fix — `SpeculatorModelConfig.__init_subclass__` injects
`cls.__init__ = SpeculatorModelConfig.__init__` into each subclass's
`__dict__` BEFORE calling `super().__init_subclass__()`. The
`"__init__" in cls.__dict__` check in transformers then skips wrapping.

Backwards compat — transformers 4.57.x does not apply
`@dataclass` + `wrap_init_to_accept_kwargs` in `__init_subclass__`; the
injection is a no-op.

### 3. `TypeError: BaseModel.validate() missing 1 required positional argument: 'value'` (`save_pretrained` crash)

Error — `model.save_pretrained(path)` raises the error above from
`transformers/configuration_utils.py:517: self.validate()`.

Root cause — transformers v5's `@strict` decorator adds a `validate()`
instance method to `PretrainedConfig` that runs class validators
(`validate_architecture`, `validate_token_ids`, etc.). Pydantic's
`BaseModel.validate(cls, value)` classmethod appears earlier in the MRO and
shadows it. When `save_pretrained()` calls `self.validate()`, it hits
`BaseModel.validate`, which requires a `value` argument.

Fix — a `SpeculatorModelConfig.validate()` instance method delegates
explicitly to `PretrainedConfig.validate(self)`, with a
`hasattr(PretrainedConfig, "validate")` guard for transformers < 5.x (4.57.x
does not have this method).

Backwards compat — transformers 4.57.x has no `validate()` on
`PretrainedConfig`; the `hasattr` guard correctly skips delegation.

## Static type errors (mypy 1.15.0)
### 4. 79 `[call-arg]` errors on `LlamaConfig(...)` across 5 files

Error — mypy 1.15.0 reports `[call-arg]` on every
`LlamaConfig(vocab_size=..., ...)` call.

Root cause — transformers v5's `@strict(accept_kwargs=True)` wraps
`PretrainedConfig.__init__` using `@wraps`. mypy 1.15.0 follows `__wrapped__`
back to `PretrainedConfig.__init__`, which only declares base-class fields, so
all `LlamaConfig`-specific kwargs appear unknown.

Fix — use `llama_kwargs: dict[str, Any] = {...}; LlamaConfig(**llama_kwargs)`.
mypy cannot verify individual key names in `dict[str, Any]` unpacking, so
`[call-arg]` is bypassed without `type: ignore` suppressions.

### 5. `[misc]` incompatible `forward()` override in `eagle3/model_definitions.py`

Root cause — transformers v5 removed
`cache_position: torch.LongTensor | None` from the explicit parameter list of
`LlamaDecoderLayer.forward()` and `Qwen3DecoderLayer.forward()` (it now flows
through `**kwargs`). Our mixin still declared `cache_position`, causing an
incompatible-override error.

Fix — remove `cache_position` from `Eagle3FirstLayerMixin.forward()` and the
`cache_position=cache_position` kwarg in the `self.self_attn(...)` call; it
flows through `**kwargs`.

### 6. `[assignment]` on `base_model_pp_plan` in `config.py`

Root cause — transformers v5 widened `PretrainedConfig.base_model_pp_plan`
from `dict[str, tuple[list[str]]]` to `dict[str, Sequence[list[str]]]`.

Fix — align the `SpeculatorModelConfig.base_model_pp_plan` annotation to
`Sequence[list[str]]`.

## Integration test fix
### 7. `KeyError` in `test_verifier_config_from_verifier_config`

Error — `PretrainedConfig.from_pretrained("RedHatAI/Llama-3.1-8B-Instruct")`
raises a `KeyError` during `rope_scaling` validation.

Root cause — transformers v5 strictly validates `rope_scaling` and requires
`rope_theta` inside `rope_scaling` when `rope_type=llama3`. The
`RedHatAI/Llama-3.1-8B-Instruct` Hub config predates this requirement.

Fix — use `PretrainedConfig.get_config_dict()` (raw fetch, no validation),
drop `rope_scaling`, then construct `PretrainedConfig(**config_dict)`.
`VerifierConfig.from_config()` only reads `architectures`, not rope
parameters.

## Files changed (9 files)
- `src/speculators/config.py` (`reload_schema`, `__init_subclass__`, `validate`, `base_model_pp_plan`)
- `src/speculators/convert/eagle/eagle_converter.py` (`LlamaConfig`)
- `src/speculators/convert/eagle/eagle3_converter.py` (`LlamaConfig`)
- `src/speculators/models/eagle3/model_definitions.py` (`cache_position`)
- `tests/unit/models/test_eagle_config.py` (`LlamaConfig`)
- `tests/unit/models/test_eagle_model.py` (`LlamaConfig`)
- `tests/unit/train/test_setup_model.py` (`LlamaConfig`)
- `tests/integration/test_config.py`
- `pyproject.toml` (`transformers >= 4.57.0`)

## Test results
- `pytest tests/unit/` — 193/194 pass (1 failure: distributed test requiring nccl port 29500, already in use — pre-existing env issue, unrelated to these changes)
- `pytest tests/integration/` — 2 passed, 7 skipped
- `Eagle3SpeculatorConfig(...)` construction: OK under both 4.57.6 and 5.4.0
- `model.save_pretrained(path)` + `from_pretrained(path)` roundtrip: OK under both versions
- `from speculators import reload_schemas; reload_schemas()`: OK under both versions