Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions traitlets/tests/test_traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2527,12 +2527,6 @@ class AB(A, B):
class BA(B, A):
pass

assert 'trait' in Base._trait_default_generators
assert 'trait' not in A._trait_default_generators
assert 'trait' in B._trait_default_generators
assert 'trait' not in AB._trait_default_generators
assert 'trait' not in BA._trait_default_generators

assert A().trait == 'base'
assert A().attr == 'base'
assert BA().trait == 'B'
Expand All @@ -2547,3 +2541,37 @@ def __init__(__self, cls, self):
pass

x = X(cls=None, self=None)


def test_override_default():
class C(HasTraits):
a = Unicode('hard default')
def _a_default(self):
return 'default method'

C._a_default = lambda self: 'overridden'
c = C()
assert c.a == 'overridden'

def test_override_default_decorator():
class C(HasTraits):
a = Unicode('hard default')
@default('a')
def _a_default(self):
return 'default method'

C._a_default = lambda self: 'overridden'
c = C()
assert c.a == 'overridden'

def test_override_default_instance():
class C(HasTraits):
a = Unicode('hard default')
@default('a')
def _a_default(self):
return 'default method'

c = C()
c._a_default = lambda self: 'overridden'
assert c.a == 'overridden'

41 changes: 23 additions & 18 deletions traitlets/traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,16 +417,6 @@ class TraitType(BaseDescriptor):
info_text = 'any value'
default_value = Undefined

def class_init(self, cls, name):
super(TraitType, self).class_init(cls, name)
if self.name is not None and self.name not in cls._trait_default_generators:
cls._trait_default_generators[self.name] = self.default

def subclass_init(self, cls):
if '_%s_default' % self.name in cls.__dict__:
method = getattr(cls, '_%s_default' % self.name)
cls._trait_default_generators[self.name] = method

def __init__(self, default_value=Undefined, allow_none=False, read_only=None, help=None,
config=None, **kwargs):
"""Declare a traitlet.
Expand Down Expand Up @@ -1501,16 +1491,30 @@ def trait_values(self, **metadata):
"""
return {name: getattr(self, name) for name in self.trait_names(**metadata)}

@classmethod
def _get_trait_default_generator(cls, name):
def _get_trait_default_generator(self, name):
"""Return default generator for a given trait

Walk the MRO to resolve the correct default generator according to inheritance.
"""
for c in cls.mro():
method_name = '_%s_default' % name
if method_name in self.__dict__:
return getattr(self, method_name)
cls = self.__class__
trait = getattr(cls, name)
assert isinstance(trait, TraitType)
# truncate mro to the class on which the trait is defined
mro = cls.mro()
try:
mro = mro[:mro.index(trait.this_class) + 1]
except ValueError:
# this_class not in mro
pass
for c in mro:
if method_name in c.__dict__:
return getattr(c, method_name)
if name in c.__dict__.get('_trait_default_generators', {}):
return c._trait_default_generators[name]
raise KeyError("No default generator for trait %r found in %r" % (name, cls.mro()))
return trait.default

def trait_defaults(self, *names, **metadata):
"""Return a trait's default value or a dictionary of them
Expand All @@ -1519,13 +1523,14 @@ def trait_defaults(self, *names, **metadata):
-----
Dynamically generated default values may
depend on the current state of the object."""
if len(names) == 1 and len(metadata) == 0:
return self._get_trait_default_generator(names[0])(self)

for n in names:
if not has_trait(self, n):
if not self.has_trait(n):
raise TraitError("'%s' is not a trait of '%s' "
"instances" % (n, type(self).__name__))

if len(names) == 1 and len(metadata) == 0:
return self._get_trait_default_generator(names[0])(self)

trait_names = self.trait_names(**metadata)
trait_names.extend(names)

Expand Down