diff --git a/traitlets/tests/test_traitlets.py b/traitlets/tests/test_traitlets.py index 2d9665ba..976df4a0 100644 --- a/traitlets/tests/test_traitlets.py +++ b/traitlets/tests/test_traitlets.py @@ -2319,6 +2319,125 @@ class C(HasTraits): assert C.d.default_value_repr() == '{}' +class Bounded(HasTraits): + """From ipywidgets which was broken by cross-validation of defaults""" + max = Integer(10) + min = Integer(0) + value = Integer(5) + + @validate('value') + def _validate_value(self, proposal): + """Cap and floor value""" + value = proposal.value + if self.min > value or self.max < value: + value = min(max(value, self.min), self.max) + return value + + @validate('min') + def _validate_min(self, proposal): + """Enforce min <= value <= max""" + min = proposal.value + if min > self.max: + raise TraitError('setting min > max') + if min > self.value: + self.value = min + return min + + @validate('max') + def _validate_max(self, proposal): + """Enforce min <= value <= max""" + max = proposal.value + if max < self.min: + raise TraitError('setting max < min') + if max < self.value: + self.value = max + return max + + +def test_cross_validate_cycles(): + obj = Bounded() + assert obj.min == 0 + assert obj.max == 10 + assert obj.value == 5 + + obj = Bounded(value=25, max=50) + assert obj.max == 50 + assert obj.value == 25 + + obj = Bounded(max=2) + assert obj.max == 2 + assert obj.value == 2 + + obj = Bounded(min=7) + assert obj.min == 7 + assert obj.value == 7 + + obj = Bounded(min=7, value=8) + assert obj.min == 7 + assert obj.value == 8 + + obj = Bounded(value=8, max=4) + assert obj.value == 4 + assert obj.max == 4 + + obj = Bounded(value=-5, max=-1, min=-10) + assert obj.value == -5 + assert obj.max == -1 + + +def test_cross_validate_defaults(): + class A(HasTraits): + x = Any(()) + y = Any(5) + + @validate('x') + def f(self, proposal): + return (1, self.y) + + assert A().x == (1, 5) + assert A(y=2).x == (1, 2) + +def test_cross_validate_default_cycles(): + a_called_with = [] + b_called_with = [] + class A(HasTraits): + a = Integer(0) + @validate('a') + def _validate_a(self, proposal): + a = proposal.value + a_called_with.append((a, self.b)) + return 2 + + b = Integer(0) + @validate('b') + def _validate_b(self, proposal): + b = proposal.value + b_called_with.append((b, self.a)) + return 3 + + a = A() + assert (a.a, a.b) == (2, 3) + assert a_called_with == [(0, 3)] + assert b_called_with == [(0, 0)] + + a_called_with[:] = [] + b_called_with[:] = [] + + a = A() + assert (a.b, a.a) == (3, 2) + assert a_called_with == [(0, 0)] + assert b_called_with == [(0, 2)] + + a_called_with[:] = [] + b_called_with[:] = [] + + # passing in constructor prevents validator from being called with defaults + a = A(a=2, b=3) + assert (a.b, a.a) == (3, 2) + assert a_called_with == [(2, 3)] + assert b_called_with == [(3, 2)] + + class TransitionalClass(HasTraits): d = Any() @@ -2481,6 +2600,7 @@ class SuperHasTraits(HasTraits, SuperRecorder): assert obj.super_args == ('a1' , 'a2') assert obj.super_kwargs == {'b': 10 , 'c': 'x'} + def test_super_bad_args(): class SuperHasTraits(HasTraits): a = Integer() diff --git a/traitlets/traitlets.py b/traitlets/traitlets.py index df06228b..e5362922 100644 --- a/traitlets/traitlets.py +++ b/traitlets/traitlets.py @@ -519,8 +519,18 @@ def get(self, obj, cls=None): raise TraitError("No default value found for " "the '%s' trait named '%s' of %r" % ( type(self).__name__, self.name, obj)) - value = self._validate(obj, default) + # store initial value without running through cross-validation + # to avoid infinite recursion on cyclical validators + with obj.cross_validation_lock: + value = self._validate(obj, default) obj._trait_values[self.name] = value + # can run cross-validation now, after storing initial value + # this runs internal trait validation twice, + # but better to avoid diverging from _validate + if not obj._cross_validation_lock: + value = self._validate(obj, default) + obj._trait_values[self.name] = value + obj.notify_change(Bunch( name=self.name, value=value, @@ -582,7 +592,7 @@ def _validate(self, obj, value): return value if hasattr(self, 'validate'): value = self.validate(obj, value) - if obj._cross_validation_lock is False: + if not obj._cross_validation_lock: value = self._cross_validate(obj, value) return value