Description
I have been able to fit the innovationMLR model to count data for several dates, but the optimization fails for one particular date.
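The specific error (full traceback at the bottom of this issue) is raised deep inside JAX, in `_compute_newshape`, while numpyro is initializing SVI (`svi.init`), so it happens before any optimization step runs. I am guessing `math.prod(other_sizes)` is 0 for some reason.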
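To illustrate the guess above, here is a minimal pure-Python sketch of the divisibility check shown in the last traceback frame. The function name `divisibility_check` is hypothetical and this is not the actual JAX source. It only mimics the `np.size(a) % math.prod(other_sizes)` expression and assumes `other_sizes` holds every target dimension except the `-1` placeholder. If any leading dimension of the array passed to `sum_rightmost` has length 0, the product is 0 and the modulo fails:

```python
import math

def divisibility_check(size, newshape):
    """Hypothetical stand-in for the check in the last traceback frame."""
    # Assumption: other_sizes is every requested dimension except the -1 placeholder.
    other_sizes = [d for d in newshape if d != -1]
    return size % math.prod(other_sizes) != 0

# A well-formed case: 12 elements reshaped to (3, -1) passes the check.
ok = divisibility_check(12, (3, -1))

# An array of shape (0, 5) has 0 elements; sum_rightmost would ask for (0, -1).
raised = False
try:
    divisibility_check(0, (0, -1))
except ZeroDivisionError:
    raised = True
```

If this reading is right, some array reaching the guide has a zero-length dimension, rather than the values of the matrices being the problem.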
I am wondering whether this is caused by an unusual parent-child relationship graph among the variants circulating in this data window, which would produce an innovation_matrix that is not suitable for optimization. Since the same code works for data from other dates, this seemed the most likely explanation.
I don't know enough about the details of the optimization to guess further, but are there constraints on the innovation_matrix or on the parent-child graph?
I have included debugging output below with some details (column and row sums) of the input matrices passed to the model optimization.
col sums of seq_counts
[ 258. 257. 2257. 159. 112. 721. 1547. 190. 279. 477.
443. 757. 312. 169. 158. 603. 468. 392. 232. 150.
185. 169. 2929. 248. 1248. 9853. 2314. 847. 138. 528.
2963. 6672. 161. 253. 425. 236. 180. 605. 557. 1656.
184. 1519. 158. 204. 499. 158. 145. 721. 687. 135.
492. 1638. 13521. 651. 197. 3739. 238. 277. 197. 878.
672. 487. 331. 500. 520. 504. 647. 1755. 162. 666.
1658. 134. 1174. 949. 614. 464. 133. 288. 142. 370.
337. 169. 310. 228. 420. 526. 115. 214. 252. 161.
157. 170. 174. 137. 431. 144. 168. 452. 275. 639.
33854.]
row sums of seq_counts
[2074. 2040. 1718. 1216. 1230. 2428. 2267. 2201. 2003. 1705. 1313. 1258.
2479. 2265. 2095. 1883. 1630. 1256. 1045. 1131. 2031. 2490. 2536. 2174.
1420. 1291. 1467. 2550. 2367. 2453. 1880. 1249. 1190. 2298. 1856. 1792.
1896. 1645. 1024. 1091. 1685. 1778. 1781. 1720. 1374. 986. 1023. 1822.
1684. 1544. 1426. 1192. 821. 801. 1813. 1442. 1319. 1251. 1115. 803.
777. 1371. 1322. 1252. 1153. 957. 622. 602. 1259. 837. 870. 833.
722. 530. 536. 806. 807. 714. 666. 594. 317. 349. 624. 550.
435. 261. 277. 152. 128. 209.]
N
(90,)
col sums of N
119849.0
innovation_matrix
(101, 101)
col sums of innovation_matrix
[ 9 8 6 1 1 12 5 2 1 2 1 1 1 1 1 1 2 1 1 1 1 1 3 1
1 1 2 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 2 1 1 1 1 1
1 1 1 1 7 1 1 2 1 1 1 1 2 1 1 1 1 1 1 2 1 3 1 1
1 2 1 1 1 1 1 1 12 6 1 1 1 2 1 1 4 3 1 1 2 1 1 1
1 1 1 6 48]
row sums of innovation_matrix
[1 2 3 3 1 1 2 2 2 2 2 1 1 1 1 1 4 5 1 1 1 3 3 4 4 3 1 2 1 4 3 2 3 1 1 2 2
2 2 2 2 2 2 3 2 2 2 2 2 2 2 2 2 3 3 3 3 3 2 2 2 3 2 2 2 2 2 2 3 2 3 3 2 4
5 4 4 4 1 4 1 2 3 3 3 3 2 1 2 3 1 1 1 2 3 3 3 3 3 2 1]
trace: 101
integer division or modulo by zero
Traceback (most recent call last):
File ncov_forecasting_growth/script/run_innovation_mlr_predictive.py", line 370, in run_mlr_innovation_model
posterior = inference_method.fit(mlr, variant_freqs,name=location)
File .snakemake/conda/01d0d78011f1c000c46728873908ef0e_/lib/python3.9/site-packages/evofr/infer/InferSVI.py", line 84, in fit
self.handler.fit(model.model_fn, guide, input, self.iters)
File .snakemake/conda/01d0d78011f1c000c46728873908ef0e_/lib/python3.9/site-packages/evofr/infer/SVI_handler.py", line 54, in fit
self.init_svi(model, guide, data)
File .snakemake/conda/01d0d78011f1c000c46728873908ef0e_/lib/python3.9/site-packages/evofr/infer/SVI_handler.py", line 48, in init_svi
svi_state = self.svi.init(self.rng_key, **data)
File .snakemake/conda/01d0d78011f1c000c46728873908ef0e_/lib/python3.9/site-packages/numpyro/infer/svi.py", line 184, in init
guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
File .snakemake/conda/01d0d78011f1c000c46728873908ef0e_/lib/python3.9/site-packages/numpyro/handlers.py", line 171, in get_trace
self(*args, **kwargs)
File .snakemake/conda/01d0d78011f1c000c46728873908ef0e_/lib/python3.9/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File .snakemake/conda/01d0d78011f1c000c46728873908ef0e_/lib/python3.9/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
File .snakemake/conda/01d0d78011f1c000c46728873908ef0e_/lib/python3.9/site-packages/numpyro/infer/autoguide.py", line 589, in __call__
log_density = sum_rightmost(
File .snakemake/conda/01d0d78011f1c000c46728873908ef0e_/lib/python3.9/site-packages/numpyro/distributions/util.py", line 320, in sum_rightmost
x = jnp.reshape(jnp.expand_dims(x, -1), jnp.shape(x)[:out_dim] + (-1,))
File .snakemake/conda/01d0d78011f1c000c46728873908ef0e_/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py", line 794, in reshape
return a.reshape(newshape, order=order) # type: ignore[call-overload,union-attr]
File .snakemake/conda/01d0d78011f1c000c46728873908ef0e_/lib/python3.9/site-packages/jax/_src/numpy/array_methods.py", line 155, in _reshape
newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
File .snakemake/conda/01d0d78011f1c000c46728873908ef0e_/lib/python3.9/site-packages/jax/_src/numpy/array_methods.py", line 132, in _compute_newshape
np.size(a) % math.prod(other_sizes) != 0):
ZeroDivisionError: integer division or modulo by zero
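Since the error suggests a zero-length dimension somewhere, a quick check like the sketch below could be run on the inputs before calling `fit`. The helper `find_empty_shapes` is hypothetical (it is not part of evofr), and the shapes are the ones implied by the debug output above, where `seq_counts` appears to be (90, 101) based on the lengths of its row and column sums:

```python
def find_empty_shapes(shapes):
    """Return the names whose shape contains a zero-length dimension."""
    return [name for name, shape in shapes.items() if 0 in tuple(shape)]

# Shapes taken or inferred from the debug output in this issue.
shapes = {
    "seq_counts": (90, 101),        # inferred from 90 row sums and 101 column sums
    "N": (90,),
    "innovation_matrix": (101, 101),
}
empty = find_empty_shapes(shapes)

# Made-up example of what a problematic input would look like.
with_bad = find_empty_shapes({**shapes, "hypothetical_empty": (0, 101)})
```

None of the three inputs is empty here, so if a zero-length dimension is the cause, it is probably in something derived from these inputs inside the model or guide rather than in the inputs themselves.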