-
-
Notifications
You must be signed in to change notification settings - Fork 174
Open
Labels
question (User queries)
Description
Sorry, it took me a while to find time to look at this again. It looks like the state sensitivity in post-event simulations is now always wrong, regardless of whether `ClipStepSizeController` is used (although the computed gradient values still appear to depend on it).
This is with the latest JAX (0.9.0.1) and Diffrax (0.7.1).
I understand that for this particular problem we could simply use `ClipStepSizeController` without adding the event to `diffeqsolve`, but that approach does not work for the other settings we are looking at.
# Reproduction script: gradient of a post-event state with respect to the
# event time, using a diffrax event followed by a restarted solve.
import jax
# Enable 64-bit floats. This is done immediately after importing jax, before
# any arrays are created.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import optimistix as optx
import diffrax
# Discontinuity time of the vector field; the checks at the end of the script
# also pass this value as `event_time`.
jump_time = 0.98
controller = diffrax.PIDController(rtol=1e-6, atol=1e-6)
# Wrap the adaptive controller so steps are clipped at the known jump time.
# Dropping this line gives the "without ClipStepSizeController" variant.
controller = diffrax.ClipStepSizeController(controller, jump_ts=[jump_time])
def solve(event_time):
    """Integrate up to an event at ``event_time``, then restart to t=2.

    The scalar ODE has rate 1 while ``t < event_time`` and rate 0 afterwards,
    starting from y=0 at t=0. The solve runs in two legs:

    1. From t=0 towards t=2, terminated by an event whose condition function
       is ``event_time - t``. The root is located with Newton's method.
    2. From the last saved time and state of leg 1, on to t=2, with no event.

    Both legs use Heun's method, the module-level ``controller`` and
    ``max_steps=100``.

    Analytically y(2) == event_time, so d y(2) / d event_time == 1.

    Returns:
        ``(sol_event, sol)``: the diffrax solutions of leg 1 and leg 2.
    """

    def rate(t, y, args):
        # Piecewise-constant vector field: 1 strictly before event_time, else 0.
        return jnp.array([jnp.select([jnp.less(t, event_time), True], [1.0, 0.0])])

    def event_condition(t, y, args, **kwargs):
        # Crosses zero exactly at t == event_time.
        return event_time - t

    term = diffrax.ODETerm(rate)
    solver = diffrax.Heun()
    common = dict(stepsize_controller=controller, max_steps=100)

    first_leg = diffrax.diffeqsolve(
        term,
        solver,
        t0=0,
        t1=2,
        dt0=None,
        y0=jnp.array([0.0]),
        event=diffrax.Event(
            cond_fn=event_condition,
            root_finder=optx.Newton(atol=1e-4, rtol=1e-4),
        ),
        **common,
    )

    # Restart from wherever the first leg stopped (the event time and state).
    second_leg = diffrax.diffeqsolve(
        term,
        solver,
        t0=first_leg.ts[-1],
        t1=2,
        dt0=None,
        y0=first_leg.ys[-1],
        **common,
    )
    return first_leg, second_leg
def compute_end_state(event_time):
    """Return the scalar state at the end of the post-event (second) solve.

    Runs ``solve(event_time)`` and returns component 0 of the last saved state
    of the second solution. The pre-event solution is discarded.
    """
    _, sol = solve(event_time)
    return sol.ys[-1, 0]
def compute_event_time(event_time):
    """Return the final saved time of the event-terminated (first) solve.

    Runs ``solve(event_time)`` and returns the last entry of the first
    solution's ``ts``. The post-event solution is discarded.
    """
    sol_event, _ = solve(event_time)
    return sol_event.ts[-1]
# Forward values: y grows at rate 1 from 0 until the event and then stays
# constant, so the final state and the event time both equal jump_time.
assert jnp.isclose(compute_end_state(jump_time), jump_time)  # pass
assert jnp.isclose(compute_event_time(jump_time), jump_time)  # pass
# The event time moves one-for-one with the event parameter.
assert jnp.isclose(jax.grad(compute_event_time)(jump_time), 1.0)  # pass

# Since y(2) == event_time, d y(2) / d event_time should be exactly 1.
# Compute this (expensive) gradient once and reuse it for the printout and
# the check, instead of differentiating through both solves twice.
end_state_grad = jax.grad(compute_end_state)(jump_time)
print(end_state_grad)
# 0.62716916609091 without `ClipStepSizeController`
# 2.220446049250313e-16 with `ClipStepSizeController`
assert jnp.isclose(end_state_grad, 1.0)  # fail
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
question (User queries)