Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions traitlets/tests/test_traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2319,6 +2319,125 @@ class C(HasTraits):
assert C.d.default_value_repr() == '{}'


class Bounded(HasTraits):
"""From ipywidgets which was broken by cross-validation of defaults"""
max = Integer(10)
min = Integer(0)
value = Integer(5)

@validate('value')
def _validate_value(self, proposal):
"""Cap and floor value"""
value = proposal.value
if self.min > value or self.max < value:
value = min(max(value, self.min), self.max)
return value

@validate('min')
def _validate_min(self, proposal):
"""Enforce min <= value <= max"""
min = proposal.value
if min > self.max:
raise TraitError('setting min > max')
if min > self.value:
self.value = min
return min

@validate('max')
def _validate_max(self, proposal):
"""Enforce min <= value <= max"""
max = proposal.value
if max < self.min:
raise TraitError('setting max < min')
if max < self.value:
self.value = max
return max


def test_cross_validate_cycles():
obj = Bounded()
assert obj.min == 0
assert obj.max == 10
assert obj.value == 5

obj = Bounded(value=25, max=50)
assert obj.max == 50
assert obj.value == 25

obj = Bounded(max=2)
assert obj.max == 2
assert obj.value == 2

obj = Bounded(min=7)
assert obj.min == 7
assert obj.value == 7

obj = Bounded(min=7, value=8)
assert obj.min == 7
assert obj.value == 8

obj = Bounded(value=8, max=4)
assert obj.value == 4
assert obj.max == 4

obj = Bounded(value=-5, max=-1, min=-10)
assert obj.value == -5
assert obj.max == -1


def test_cross_validate_defaults():
class A(HasTraits):
x = Any(())
y = Any(5)

@validate('x')
def f(self, proposal):
return (1, self.y)

assert A().x == (1, 5)
assert A(y=2).x == (1, 2)

def test_cross_validate_default_cycles():
a_called_with = []
b_called_with = []
class A(HasTraits):
a = Integer(0)
@validate('a')
def _validate_a(self, proposal):
a = proposal.value
a_called_with.append((a, self.b))
return 2

b = Integer(0)
@validate('b')
def _validate_b(self, proposal):
b = proposal.value
b_called_with.append((b, self.a))
return 3

a = A()
assert (a.a, a.b) == (2, 3)
assert a_called_with == [(0, 3)]
assert b_called_with == [(0, 0)]

a_called_with[:] = []
b_called_with[:] = []

a = A()
assert (a.b, a.a) == (3, 2)
assert a_called_with == [(0, 0)]
assert b_called_with == [(0, 2)]

a_called_with[:] = []
b_called_with[:] = []

# passing in constructor prevents validator from being called with defaults
a = A(a=2, b=3)
assert (a.b, a.a) == (3, 2)
assert a_called_with == [(2, 3)]
assert b_called_with == [(3, 2)]


class TransitionalClass(HasTraits):

d = Any()
Expand Down Expand Up @@ -2481,6 +2600,7 @@ class SuperHasTraits(HasTraits, SuperRecorder):
assert obj.super_args == ('a1' , 'a2')
assert obj.super_kwargs == {'b': 10 , 'c': 'x'}


def test_super_bad_args():
class SuperHasTraits(HasTraits):
a = Integer()
Expand Down
14 changes: 12 additions & 2 deletions traitlets/traitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,18 @@ def get(self, obj, cls=None):
raise TraitError("No default value found for "
"the '%s' trait named '%s' of %r" % (
type(self).__name__, self.name, obj))
value = self._validate(obj, default)
# store initial value without running through cross-validation
# to avoid infinite recursion on cyclical validators
with obj.cross_validation_lock:
value = self._validate(obj, default)
obj._trait_values[self.name] = value
# can run cross-validation now, after storing initial value
# this runs internal trait validation twice,
# but better to avoid diverging from _validate
if not obj._cross_validation_lock:
value = self._validate(obj, default)
obj._trait_values[self.name] = value

obj.notify_change(Bunch(
name=self.name,
value=value,
Expand Down Expand Up @@ -582,7 +592,7 @@ def _validate(self, obj, value):
return value
if hasattr(self, 'validate'):
value = self.validate(obj, value)
if obj._cross_validation_lock is False:
if not obj._cross_validation_lock:
value = self._cross_validate(obj, value)
return value

Expand Down