deps: Update dependency jax to >=0.9.0#663
Merged
renovate[bot] merged 1 commit intomainfrom Apr 1, 2026
Merged
Conversation
dee1119 to
6af4b1a
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR contains the following updates:
>=0.7.2→>=0.9.00.9.1Release Notes
jax-ml/jax (jax)
v0.9.0Compare Source
New features:
jax.thread_guard, a context manager that detects when devicesare used by multiple threads in multi-controller JAX.
Bug fixes:
magma_zgeqp3_gpu)in MAGMA 2.9.0 when using
use_magma=Trueandpivoting=True.({jax-issue}
#34145).Deprecations:
jax_collectives_common_channel_idwas removed.jax_pmap_no_rank_reductionconfig state has been removed. Theno-rank-reduction behavior is now the only supported behavior: a
jax.pmapped functionfsees inputs of the same rank as the input tojax.pmap(f). For example, ifjax.pmap(f)receives shape(8, 128)on8 devices, then
freceives shape(1, 128).jax_pmap_shmap_mergeconfig state is deprecated in JAX v0.9.0and will be removed in JAX v0.10.0.
jax.numpy.fixis deprecated, anticipating the deprecation of{func}
numpy.fixin NumPy v2.5.0. {func}jax.numpy.truncis a drop-inreplacement.
Changes:
jax.exportnow supports explicit sharding. This required a newexport serialization format version that includes the NamedSharding,
including the abstract mesh, and the partition spec. As part of this
change we have added a restriction in the use of exported modules: when
calling them the abstract mesh must match the one used at export time,
including the axis names. Previously, only the number of the devices
mattered.
v0.8.2Compare Source
Deprecations
jax.lax.pvaryhas been deprecated.Please use
jax.lax.pcast(..., to='varying')as the replacement.jax.numpy.arangenow result in adeprecation warning, because the output is poorly-defined.
jax.corea number of symbols are newly deprecated including:call_impl,get_aval,mapped_aval,subjaxprs,set_current_trace,take_current_trace,traverse_jaxpr_params,unmapped_aval,AbstractToken, andTraceTag.jax.interpreters.pxlaare deprecated. These areprimarily JAX internal APIs, and users should not rely on them.
Changes:
jax's
Tracerno longer inherits fromjax.Arrayat runtime. However,jax.Arraynow uses a custom metaclass suchisinstance(x, Array)is trueif an object
xrepresents a tracedArray. Only someTracers representArrays, so it is not correct forTracerto inherit fromArray.For the moment, during Python type checking, we continue to declare
Traceras a subclass of
Array, however we expect to remove this in a futurerelease.
jax.experimental.si_vjphas been deleted.jax.vjpsubsumes it's functionality.Configuration
📅 Schedule: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined).
🚦 Automerge: Enabled.
♻ Rebasing: Whenever PR is behind base branch, or you tick the rebase/retry checkbox.
🔕 Ignore: Close this PR and you won't be reminded about this update again.
This PR was generated by Mend Renovate. View the repository job log.