diff --git a/learned_optimization/outer_trainers/gradient_learner.py b/learned_optimization/outer_trainers/gradient_learner.py index a259aa3c..d6fb1749 100644 --- a/learned_optimization/outer_trainers/gradient_learner.py +++ b/learned_optimization/outer_trainers/gradient_learner.py @@ -428,6 +428,8 @@ def extract_one(idx, x): onp.asarray, estimator_out.unroll_info.task_param) iteration = estimator_out.unroll_info.iteration[ idx] if estimator_out.unroll_info.iteration is not None else None + if worker_weights.outer_state is None: + raise ValueError("worker_weights.outer_state is None") event_info.append({ "loss": estimator_out.unroll_info.loss[idx, :], "task_param": jax.tree_util.tree_map(fn, onp_task_params),