diff --git a/pyrenew/math.py b/pyrenew/math.py index 62b08308..a065c3c5 100755 --- a/pyrenew/math.py +++ b/pyrenew/math.py @@ -27,6 +27,7 @@ def _positive_ints_like(vec: ArrayLike) -> jnp.ndarray: jnp.ndarray The resulting array ``[1, ..., n]``. """ + vec = jnp.asarray(vec) return jnp.arange(1, jnp.size(vec) + 1) diff --git a/test/test_convolve.py b/test/test_convolve.py index a9211f85..08f1273e 100644 --- a/test/test_convolve.py +++ b/test/test_convolve.py @@ -82,7 +82,7 @@ def transform_a(x: any): [ jnp.ones(3), jnp.array(np.array([0.5, 0.3, 0.2] * 3)).reshape(3, 3), - jnp.ones((3, 3)), + jnp.array([[-0.25, 0.25, 0.25], [0, 1, 0], [0.5, -0.5, 0.5]]), t.ExpTransform(), ], ], @@ -95,14 +95,13 @@ def test_convolve_scanner_using_scan(arr, history, multipliers, transform): """ scanner = pc.new_convolve_scanner(arr, transform) - _, result = jax.lax.scan(f=scanner, init=history, xs=multipliers) + _, result = jax.lax.scan(f=scanner, init=jnp.array(history), xs=multipliers) result_not_scanned = [] for multiplier in multipliers: history, new_val = scanner(history, multiplier) result_not_scanned.append(new_val) - - assert jnp.array_equal(result, result_not_scanned) + assert_array_equal(result, result_not_scanned) @pytest.mark.parametrize( diff --git a/test/test_infectionwithfeedback_plate_compatibility.py b/test/test_infectionwithfeedback_plate_compatibility.py index 5cce03c9..d6b7a740 100644 --- a/test/test_infectionwithfeedback_plate_compatibility.py +++ b/test/test_infectionwithfeedback_plate_compatibility.py @@ -20,7 +20,9 @@ def test_infections_with_feedback_plate_compatibility(): Rt = jnp.ones((10, 5)) gen_int = jnp.array([0.4, 0.25, 0.25, 0.1]) - inf_feed_strength = DistributionalVariable("inf_feed_strength", dist.Beta(1, 1)) + inf_feed_strength = DistributionalVariable( + "inf_feed_strength", dist.LogNormal(0.0, 1.0) + ) inf_feedback_pmf = DeterministicPMF(name="inf_feedback_pmf", value=gen_int) # Test the InfectionsWithFeedback class diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index ec6c2af6..5aed3775 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -150,7 +150,7 @@ class TestCountsWithPriors: def test_with_stochastic_ascertainment(self, short_shedding_pmf): """Test with uncertain ascertainment rate parameter.""" delay = DeterministicPMF("delay", jnp.array([0.2, 0.5, 0.3])) - ascertainment = DistributionalVariable("ihr", dist.Beta(2, 100)) + ascertainment = DistributionalVariable("ihr", dist.LogNormal(-1, 100.0)) concentration = DeterministicVariable("conc", 10.0) process = Counts( diff --git a/uv.lock b/uv.lock index 013b481d..4c5f1d02 100644 --- a/uv.lock +++ b/uv.lock @@ -609,7 +609,7 @@ wheels = [ [[package]] name = "jax" -version = "0.9.0.1" +version = "0.9.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxlib" }, @@ -618,14 +618,14 @@ dependencies = [ { name = "opt-einsum" }, { name = "scipy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/52/40/f85d1feadd8f793fc1bfab726272523ef34b27302b55861ea872ec774019/jax-0.9.0.1.tar.gz", hash = "sha256:e395253449d74354fa813ff9e245acb6e42287431d8a01ff33d92e9ee57d36bd", size = 2534795, upload-time = "2026-02-05T18:47:33.088Z" } +sdist = { url = "https://files.pythonhosted.org/packages/25/4d/f45853fdc2b811e78b866d5f80b8a21a848278361f66c066706132f415cf/jax-0.9.1.tar.gz", hash = "sha256:ce1b82477ee192f0b1d9801b095aa0cf3839bc1fe0cbc071c961a24b3ff30361", size = 2625994, upload-time = "2026-03-02T11:24:18.382Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/57/1e/63ac22ec535e08129e16cb71b7eeeb8816c01d627ea1bc9105e925a71da0/jax-0.9.0.1-py3-none-any.whl", hash = "sha256:3baeaec6dc853394c272eb38a35ffba1972d67cf55d07a76bdb913bcd867e2ca", size = 2955477, upload-time = "2026-02-05T18:45:22.885Z" }, + { url = "https://files.pythonhosted.org/packages/80/e4/88778c6a23b65224e5088e68fd0924e5bde2196a26e76edb3ea3543fed6a/jax-0.9.1-py3-none-any.whl", hash = "sha256:d11cb53d362912253013e8c4d6926cb9f3a4b59ab5b25a7dc08123567067d088", size = 3062162, upload-time = "2026-03-02T11:22:05.089Z" }, ] [[package]] name = "jaxlib" -version = "0.9.0.1" +version = "0.9.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "ml-dtypes" }, @@ -633,20 +633,20 @@ dependencies = [ { name = "scipy" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/f0/8d/f5a78b4d2a08e2d358e01527a3617af2df67c70231029ce1bdbb814219ff/jaxlib-0.9.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e857cafdd12e18493d96d4a290ed31aa9d99a0dc3056b4b42974c0f342c9bb0c", size = 56103168, upload-time = "2026-02-05T18:46:46.481Z" }, - { url = "https://files.pythonhosted.org/packages/47/c3/fd3a9e2f02c1a04a1a00ff74adb6dd09e34040587bbb1b51b0176151dfa1/jaxlib-0.9.0.1-cp313-cp313-manylinux_2_27_aarch64.whl", hash = "sha256:b73b85f927d9b006f07622d5676092eab916645c4804fed6568da5fb4a541dfc", size = 74768692, upload-time = "2026-02-05T18:46:49.571Z" }, - { url = "https://files.pythonhosted.org/packages/d9/48/34923a6add7dda5fb8f30409a98b638f0dbd2d9571dbbf73db958eaec44a/jaxlib-0.9.0.1-cp313-cp313-manylinux_2_27_x86_64.whl", hash = "sha256:54dd2d34c6bec4f099f888a2f7895069a47c3ba86aaa77b0b78e9c3f9ef948f1", size = 80337646, upload-time = "2026-02-05T18:46:53.299Z" }, - { url = "https://files.pythonhosted.org/packages/a8/a9/629bed81406902653973d57de5af92842c7da63dfa8fcd84ee490c62ee94/jaxlib-0.9.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:27db7fbc49938f819f2a93fefef0bdc25bd523b499ab4d8a71ed8915c037c0b4", size = 60508306, upload-time = "2026-02-05T18:46:56.441Z" }, - { url = "https://files.pythonhosted.org/packages/45/e3/6943589aaa58d9934838e00c6149dd1fc81e0c8555e9fcc9f527648faf5c/jaxlib-0.9.0.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:9312fcfb4c5586802c08bc1b3b2419e48aa2a4cd1356251fe791ad71edc2da2a", size = 56210697, upload-time = "2026-02-05T18:46:59.642Z" }, - { url = "https://files.pythonhosted.org/packages/7e/ff/39479759b71f1d281b77050184759ac76dfd23a3ae75132ef92d168099c5/jaxlib-0.9.0.1-cp313-cp313t-manylinux_2_27_aarch64.whl", hash = "sha256:b536512cf84a0cb031196d6d5233f7093745e87eb416e45ad96fbb764b2befed", size = 74882879, upload-time = "2026-02-05T18:47:02.708Z" }, - { url = "https://files.pythonhosted.org/packages/87/0d/e41eeddd761110d733688d6493defe776440c8f3d114419a8ecaef55601f/jaxlib-0.9.0.1-cp313-cp313t-manylinux_2_27_x86_64.whl", hash = "sha256:c4dc8828bb236532033717061d132906075452556b12d1ff6ccc10e569435dfe", size = 80438424, upload-time = "2026-02-05T18:47:06.437Z" }, - { url = "https://files.pythonhosted.org/packages/fd/ec/54b1251cea5c74a2f0d22106f5d1c7dc9e7b6a000d6a81a88deffa34c6fe/jaxlib-0.9.0.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:43272e52e5c89dbc4f02c7ccb6ffa5d587a09ac8db5163cb0c43e125b7075129", size = 56101484, upload-time = "2026-02-05T18:47:09.46Z" }, - { url = "https://files.pythonhosted.org/packages/29/ce/91ba780439aa1e6bae964ea641169e8b9c9349c175fcb1a723b96ba54313/jaxlib-0.9.0.1-cp314-cp314-manylinux_2_27_aarch64.whl", hash = "sha256:82348cee1521d6123038c4c3beeafa2076c8f4ae29a233b8abff9d6dc8b44145", size = 74789558, upload-time = "2026-02-05T18:47:12.394Z" }, - { url = "https://files.pythonhosted.org/packages/ce/9b/3d7baca233c378b01fa445c9f63b260f592249ff69950baf893cea631b10/jaxlib-0.9.0.1-cp314-cp314-manylinux_2_27_x86_64.whl", hash = "sha256:e61e88032eeb31339c72ead9ed60c6153cd2222512624caadea67c350c78432e", size = 80343053, upload-time = "2026-02-05T18:47:16.042Z" }, - { url = "https://files.pythonhosted.org/packages/92/5d/80efe5295133d5114fb7b0f27bdf82bc7a2308356dde6ba77c2afbaa3a36/jaxlib-0.9.0.1-cp314-cp314-win_amd64.whl", hash = "sha256:abd9f127d23705105683448781914f17898b2b6591a051b259e6b947d4dcb93f", size = 62826248, upload-time = "2026-02-05T18:47:19.986Z" }, - { url = "https://files.pythonhosted.org/packages/f9/a9/f72578daa6af9bed9bda75b842c97581b31a577d7b2072daf8ba3d5a8156/jaxlib-0.9.0.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5b01a75fbac8098cc985f6f1690bfb62f98b0785c84199287e0baaae50fa4238", size = 56209722, upload-time = "2026-02-05T18:47:23.193Z" }, - { url = "https://files.pythonhosted.org/packages/95/ea/eefb118305dd5e1b0ad8d942f2bf43616c964d89fe491bec8628173da24d/jaxlib-0.9.0.1-cp314-cp314t-manylinux_2_27_aarch64.whl", hash = "sha256:76f23cbb109e673ea7a90781aca3e02a0c72464410c019fe14fba3c044f2b778", size = 74881382, upload-time = "2026-02-05T18:47:26.703Z" }, - { url = "https://files.pythonhosted.org/packages/0a/aa/a42fb912fd1f9c83e22dc2577cdfbf1a1b07d6660532cb44724db7a7c479/jaxlib-0.9.0.1-cp314-cp314t-manylinux_2_27_x86_64.whl", hash = "sha256:f80d30dedce96c73a7f5dcb79c4c827a1bde2304f502a56ce7e7f723df2a5398", size = 80438052, upload-time = "2026-02-05T18:47:30.039Z" }, + { url = "https://files.pythonhosted.org/packages/08/18/fee700125fe4367c75be1d0f300d13069f5ed119a635ea9199de4b4bc9dc/jaxlib-0.9.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e9915bcaa9ffefd40cd3fdb08a83b16b79f1f3c9ba187884f5b442ad2a47ffd1", size = 57982624, upload-time = "2026-03-02T11:23:31.412Z" }, + { url = "https://files.pythonhosted.org/packages/fd/5f/d4a79d6802f3cef02773852453d9528569dd0896964117d4401658828aba/jaxlib-0.9.1-cp313-cp313-manylinux_2_27_aarch64.whl", hash = "sha256:9e88c35248b37d5219423ff8ddca60c6a561e665ded5c4fcbc61f0763e03f1e3", size = 76828438, upload-time = "2026-03-02T11:23:34.793Z" }, + { url = "https://files.pythonhosted.org/packages/3e/2e/d84cafbd07e8cdc7701d9f840f4eea0cfcf3487a99ada14507702172da14/jaxlib-0.9.1-cp313-cp313-manylinux_2_27_x86_64.whl", hash = "sha256:da60d967b4ac2084a3e3535ad982392894dd6bdf79c9a56978aba08404a58c82", size = 82473711, upload-time = "2026-03-02T11:23:38.356Z" }, + { url = "https://files.pythonhosted.org/packages/45/e6/4d09ec33a5d096c541025272dc31a36aa9d9a5752b37e05193b23c125810/jaxlib-0.9.1-cp313-cp313-win_amd64.whl", hash = "sha256:7ec6e2f43be6e1ae9321efe9a98affcd8acbe0e1fe59aba1d307ba0462752988", size = 62164682, upload-time = "2026-03-02T11:23:41.761Z" }, + { url = "https://files.pythonhosted.org/packages/8a/be/7d810371aa3bdf30882df60965c15773b8990c90e350a650e366e6dedbaa/jaxlib-0.9.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:872e5917ad20cfde85ce6d50a6dffb205ce551d5c691532f0f07e30c34bbb6c3", size = 58092440, upload-time = "2026-03-02T11:23:46.233Z" }, + { url = "https://files.pythonhosted.org/packages/e9/63/0f5acacd3bd6906f2e1f730ceeafac4afc5cc612f43be4820785608cb951/jaxlib-0.9.1-cp313-cp313t-manylinux_2_27_aarch64.whl", hash = "sha256:469f08a30f6b541557e29c5de61ea6df16ac0ef9225879373bb2b332f1b27d14", size = 76949185, upload-time = "2026-03-02T11:23:49.378Z" }, + { url = "https://files.pythonhosted.org/packages/91/c5/a4dee13627d913c7bd0cf29b7f5c1d6a2605760d08a7cff952f9098ebb61/jaxlib-0.9.1-cp313-cp313t-manylinux_2_27_x86_64.whl", hash = "sha256:2e2225b80689610cbb472822dadf7cc200aa4bdac813112a3f6e074d96b1458c", size = 82584273, upload-time = "2026-03-02T11:23:52.762Z" }, + { url = "https://files.pythonhosted.org/packages/a4/b0/f2c9caa6f545d4ecc1eab528c68c9191e40087f1bc79a6da2e29c6416510/jaxlib-0.9.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:3071bf493f6f48207c56b1e9a5bf895e2acebc5bd40f6f35458e76eb8bf210c7", size = 57984052, upload-time = "2026-03-02T11:23:55.766Z" }, + { url = "https://files.pythonhosted.org/packages/0c/e7/237ec5f4cd07420ef50d79a048b769664dbe306e31bdb10f9dcb9accabe9/jaxlib-0.9.1-cp314-cp314-manylinux_2_27_aarch64.whl", hash = "sha256:531dff9fae7aea14449ee544cc1415880cc8a346a9287d347dbd1b2b51d8aabd", size = 76846925, upload-time = "2026-03-02T11:23:59.18Z" }, + { url = "https://files.pythonhosted.org/packages/76/fe/67d2c414b0860d42f4a20b1fadbe7aeffb1b3d885efebd7aedf22a4bc2a2/jaxlib-0.9.1-cp314-cp314-manylinux_2_27_x86_64.whl", hash = "sha256:2287a1c891b152c52eb9b73925f57cde01be35d2bab4dad9673d3c83c5982ca8", size = 82484342, upload-time = "2026-03-02T11:24:02.541Z" }, + { url = "https://files.pythonhosted.org/packages/54/0d/a8e27c1c434e489883c1182bd52de27775b8a78013de62e6eabf80991df5/jaxlib-0.9.1-cp314-cp314-win_amd64.whl", hash = "sha256:61160d686e6a4703ef30a6a3aa199c934e6359f42d0aa1c0f9c475d3953b9459", size = 64553355, upload-time = "2026-03-02T11:24:05.976Z" }, + { url = "https://files.pythonhosted.org/packages/fa/4a/e5cb3a32320da2e9496c66045a4e19e16597c92a6496dd493b630585c219/jaxlib-0.9.1-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5ac3db6b164a8a5b473c77ad9da4f43937d309a27f5cb2f38932930b26e42c68", size = 58096335, upload-time = "2026-03-02T11:24:09.01Z" }, + { url = "https://files.pythonhosted.org/packages/50/d2/35ecc2e92065ac035a954fcb4b752baa72747dcc3a3466525c42c4404958/jaxlib-0.9.1-cp314-cp314t-manylinux_2_27_aarch64.whl", hash = "sha256:30fe58e8e4e105dffe364a6f0dccca16d93433576d4a015babc83339ca7f1f38", size = 76948543, upload-time = "2026-03-02T11:24:12.026Z" }, + { url = "https://files.pythonhosted.org/packages/ba/cb/a8de776aee88f42937d07472953cf7980e45f5fb30aa9d5ee652b4acc771/jaxlib-0.9.1-cp314-cp314t-manylinux_2_27_x86_64.whl", hash = "sha256:6b6654a20d54e7cc77d1d54c33f1db851ef9d70bb112b627776178221036e720", size = 82585090, upload-time = "2026-03-02T11:24:15.783Z" }, ] [[package]] @@ -1602,7 +1602,7 @@ wheels = [ [[package]] name = "pyrenew" -version = "0.1.6" +version = "0.1.7" source = { editable = "." } dependencies = [ { name = "jax" },