diff --git a/traitlets/tests/test_traitlets.py b/traitlets/tests/test_traitlets.py index 778bbf88..2087955a 100644 --- a/traitlets/tests/test_traitlets.py +++ b/traitlets/tests/test_traitlets.py @@ -2527,12 +2527,6 @@ class AB(A, B): class BA(B, A): pass - assert 'trait' in Base._trait_default_generators - assert 'trait' not in A._trait_default_generators - assert 'trait' in B._trait_default_generators - assert 'trait' not in AB._trait_default_generators - assert 'trait' not in BA._trait_default_generators - assert A().trait == 'base' assert A().attr == 'base' assert BA().trait == 'B' @@ -2547,3 +2541,37 @@ def __init__(__self, cls, self): pass x = X(cls=None, self=None) + + +def test_override_default(): + class C(HasTraits): + a = Unicode('hard default') + def _a_default(self): + return 'default method' + + C._a_default = lambda self: 'overridden' + c = C() + assert c.a == 'overridden' + +def test_override_default_decorator(): + class C(HasTraits): + a = Unicode('hard default') + @default('a') + def _a_default(self): + return 'default method' + + C._a_default = lambda self: 'overridden' + c = C() + assert c.a == 'overridden' + +def test_override_default_instance(): + class C(HasTraits): + a = Unicode('hard default') + @default('a') + def _a_default(self): + return 'default method' + + c = C() + c._a_default = lambda self: 'overridden' + assert c.a == 'overridden' + diff --git a/traitlets/traitlets.py b/traitlets/traitlets.py index 2d4ae73b..803d047e 100644 --- a/traitlets/traitlets.py +++ b/traitlets/traitlets.py @@ -417,16 +417,6 @@ class TraitType(BaseDescriptor): info_text = 'any value' default_value = Undefined - def class_init(self, cls, name): - super(TraitType, self).class_init(cls, name) - if self.name is not None and self.name not in cls._trait_default_generators: - cls._trait_default_generators[self.name] = self.default - - def subclass_init(self, cls): - if '_%s_default' % self.name in cls.__dict__: - method = getattr(cls, '_%s_default' % self.name) - cls._trait_default_generators[self.name] = method - def __init__(self, default_value=Undefined, allow_none=False, read_only=None, help=None, config=None, **kwargs): """Declare a traitlet. @@ -1501,16 +1491,30 @@ def trait_values(self, **metadata): """ return {name: getattr(self, name) for name in self.trait_names(**metadata)} - @classmethod - def _get_trait_default_generator(cls, name): + def _get_trait_default_generator(self, name): """Return default generator for a given trait Walk the MRO to resolve the correct default generator according to inheritance. """ - for c in cls.mro(): + method_name = '_%s_default' % name + if method_name in self.__dict__: + return getattr(self, method_name) + cls = self.__class__ + trait = getattr(cls, name) + assert isinstance(trait, TraitType) + # truncate mro to the class on which the trait is defined + mro = cls.mro() + try: + mro = mro[:mro.index(trait.this_class) + 1] + except ValueError: + # this_class not in mro + pass + for c in mro: + if method_name in c.__dict__: + return getattr(c, method_name) if name in c.__dict__.get('_trait_default_generators', {}): return c._trait_default_generators[name] - raise KeyError("No default generator for trait %r found in %r" % (name, cls.mro())) + return trait.default def trait_defaults(self, *names, **metadata): """Return a trait's default value or a dictionary of them @@ -1519,13 +1523,14 @@ def trait_defaults(self, *names, **metadata): ----- Dynamically generated default values may depend on the current state of the object.""" - if len(names) == 1 and len(metadata) == 0: - return self._get_trait_default_generator(names[0])(self) - for n in names: - if not has_trait(self, n): + if not self.has_trait(n): raise TraitError("'%s' is not a trait of '%s' " "instances" % (n, type(self).__name__)) + + if len(names) == 1 and len(metadata) == 0: + return self._get_trait_default_generator(names[0])(self) + trait_names = self.trait_names(**metadata) trait_names.extend(names)