Skip to content

Gradients of state variables in post-event simulations #729

@FFroehlich

Description

@FFroehlich

Sorry, it took me a while to carve out some time to look at this again. It looks like the state sensitivities in post-event simulations are now always wrong, regardless of whether ClipStepSizeController is used (although the computed gradient values still appear to depend on it).

This is with the latest jax (0.9.0.1) and diffrax (0.7.1).

I understand that for this particular problem, we could simply use the ClipStepSizeController and not add the event to diffeqsolve, but this does not work for other settings we are looking at.

import jax

jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import optimistix as optx
import diffrax

# Time of the discontinuity in the right-hand side; also the value of
# `event_time` at which the checks at the bottom of the script are evaluated.
jump_time = 0.98
# Base step size controller with rtol = atol = 1e-6.
controller = diffrax.PIDController(rtol=1e-6, atol=1e-6)
# Wrap the base controller, passing `jump_ts=[jump_time]`. This is a separate
# reassignment so that it can be commented out to reproduce the
# "without `ClipStepSizeController`" number quoted next to the gradient print.
controller = diffrax.ClipStepSizeController(controller, jump_ts=[jump_time])

def solve(event_time):
    """Run a two-stage solve: up to the event, then on to ``t1=2``.

    The first ``diffeqsolve`` call starts at ``t0=0`` with ``y0=[0.0]`` and is
    given an ``Event`` whose condition function is ``event_time - t``, i.e. it
    has its root at ``t == event_time``. The second call restarts from the
    last saved time and state of the first call and integrates to ``t1=2``
    without any event. Both calls share the module-level ``controller``.

    Args:
        event_time: Time at which the right-hand side switches from 1.0 to
            0.0 and at which the event condition crosses zero.

    Returns:
        A tuple ``(sol_event, sol)`` holding the solution of the first
        (event) solve and of the second (post-event) solve, in that order.
    """
    term = diffrax.ODETerm(
        # Right-hand side: 1.0 while t < event_time, 0.0 otherwise. The
        # trailing `True` condition acts as the catch-all branch of `select`.
        lambda t, y, args: jnp.array([jnp.select(
            [jnp.less(t,event_time),True],
            [1.0, 0.0]
        )])
    )
    solver = diffrax.Heun()

    # Stage 1: integrate from t=0 with the event attached.
    # NOTE(review): this solve is expected to stop at t == event_time; the
    # script's own check on `sol_event.ts[-1]` relies on that.
    sol_event = diffrax.diffeqsolve(
        term,
        solver,
        t0=0,
        t1=2,
        dt0=None,
        y0=jnp.array([0.0]),
        stepsize_controller=controller,
        event=diffrax.Event(
            # Zero exactly at t == event_time.
            cond_fn=lambda t, y, args, **kw: event_time - t,
            root_finder=optx.Newton(atol=1e-4, rtol=1e-4),
        ),
        max_steps=100,
    )
    # Stage 2: continue from the final time/state of stage 1 up to t=2,
    # this time without an event.
    sol = diffrax.diffeqsolve(
        term,
        solver,
        t0=sol_event.ts[-1],
        t1=2,
        dt0=None,
        y0=sol_event.ys[-1],
        stepsize_controller=controller,
        max_steps=100,
    )
    return sol_event, sol

def compute_end_state(event_time):
    """Return the scalar state at the last saved point of the post-event solve.

    Calls ``solve(event_time)`` and reads component 0 of the final entry of
    the second (post-event) solution's ``ys``.
    """
    _, post_event_sol = solve(event_time)
    end_state = post_event_sol.ys[-1, 0]
    return end_state

def compute_event_time(event_time):
    """Return the last saved time of the event (first-stage) solve.

    Calls ``solve(event_time)`` and reads the final entry of the first
    solution's ``ts``.
    """
    pre_event_sol, _ = solve(event_time)
    final_time = pre_event_sol.ts[-1]
    return final_time

# Forward values: both the final state and the detected event time are
# expected to equal `jump_time` (the right-hand side is 1.0 before the event
# and 0.0 after it, starting from y0 = 0 at t0 = 0).
assert jnp.isclose(compute_end_state(jump_time),jump_time) # pass
assert jnp.isclose(compute_event_time(jump_time),jump_time) # pass

# Sensitivity of the detected event time with respect to `event_time`.
assert jnp.isclose(jax.grad(compute_event_time)(jump_time), 1.0) # pass
# Sensitivity of the post-event end state with respect to `event_time`.
# The reporter expects 1.0 here; the values observed instead are listed below.
print(jax.grad(compute_end_state)(jump_time))
# 0.62716916609091 without `ClipStepSizeController`
# 2.220446049250313e-16 with `ClipStepSizeController`
assert jnp.isclose(jax.grad(compute_end_state)(jump_time),1.0) # fail

Metadata

Metadata

Assignees

No one assigned

    Labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions