Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions pyrenew/deterministic/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,10 @@ def sample(
**kwargs: object,
) -> ArrayLike:
"""
Retrieve the value of the deterministic Rv
Retrieve the value of the deterministic RV

Parameters
----------
record
Whether to record the value of the deterministic
RandomVariable. Defaults to False.
**kwargs
Additional keyword arguments passed through to internal
sample calls, should there be any.
Expand All @@ -87,6 +84,5 @@ def sample(
-------
ArrayLike
"""
if record:
numpyro.deterministic(self.name, self)
numpyro.deterministic(self.name, self.value)
return self.value
18 changes: 18 additions & 0 deletions pyrenew/metaclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import jax.random as jr
import numpy as np
import numpyro
from jax.typing import ArrayLike
from numpyro.infer import MCMC, NUTS, Predictive, init_to_sample

Expand Down Expand Up @@ -101,6 +102,23 @@ def validate(**kwargs: object) -> None:
"""
pass

def scope(self) -> numpyro.handlers.scope:
"""
Standardized [`numpyro.handlers.scope`][] context for
PyRenew [`RandomVariable`][]s. This can be used to
naming of any internal sampling sites within the
[`RandomVariable`][]'s [`self.sample()`][] method.

The scope prefix is always the [`name`][self.name] of the `RandomVariable`
and the divider is always `::`.

Returns
-------
numpyro.handlers.scope
A properly configured scope handler.
"""
return numpyro.handlers.scope(prefix=self.name, divider="::")

def __call__(self, **kwargs: object) -> tuple:
"""
Alias for `sample`.
Expand Down
13 changes: 7 additions & 6 deletions pyrenew/process/ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,13 @@ def transition(
recent_vals: ArrayLike, _: ArrayLike
) -> tuple[ArrayLike, ArrayLike]: # numpydoc ignore=GL08
with numpyro.handlers.reparam(config={noise_name: LocScaleReparam(0)}):
next_noise = numpyro.sample(
noise_name,
numpyro.distributions.Normal(
loc=jnp.zeros(noise_shape), scale=noise_sd
),
)
with self.scope():
next_noise = numpyro.sample(
noise_name,
numpyro.distributions.Normal(
loc=jnp.zeros(noise_shape), scale=noise_sd
),
)

dot_prod = jnp.einsum("i...,i...->...", autoreg, recent_vals)
new_term = dot_prod + next_noise
Expand Down
21 changes: 11 additions & 10 deletions pyrenew/process/differencedprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,16 @@ def sample(
)
n_diffs = n - self.differencing_order

if n_diffs > 0:
diff_samp = self.fundamental_process.sample(
*args,
n=n_diffs,
init_vals=fundamental_process_init_vals,
**kwargs,
)
diffs = diff_samp
else:
diffs = jnp.array([])
with self.scope():
if n_diffs > 0:
diff_samp = self.fundamental_process.sample(
*args,
n=n_diffs,
init_vals=fundamental_process_init_vals,
**kwargs,
)
diffs = diff_samp
else:
diffs = jnp.array([])
integrated_ts = integrate_discrete(init_vals, diffs)[:n]
return integrated_ts
36 changes: 18 additions & 18 deletions pyrenew/process/iidrandomsequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def __init__(
-------
None
"""
super().__init__(name=name, **kwargs)
self.element_rv = element_rv
super().__init__(name=name, **kwargs)

def sample(
self, n: int, *args: object, vectorize: bool = False, **kwargs: object
Expand Down Expand Up @@ -76,22 +76,22 @@ def sample(
`n` samples from self.distribution`.
"""

if vectorize and hasattr(self.element_rv, "expand_by"):
result = self.element_rv.expand_by((n,)).sample(*args, **kwargs)
else:

def transition(_carry: None, _x: None) -> tuple[None, ArrayLike]:
# numpydoc ignore=GL08
el = self.element_rv.sample(*args, **kwargs)
return None, el

_, result = scan(
transition,
xs=None,
init=None,
length=n,
)

with self.scope():
if vectorize and hasattr(self.element_rv, "expand_by"):
result = self.element_rv.expand_by((n,)).sample(*args, **kwargs)
else:

def transition(_carry: None, _x: None) -> tuple[None, ArrayLike]:
# numpydoc ignore=GL08
el = self.element_rv.sample(*args, **kwargs)
return None, el

_, result = scan(
transition,
xs=None,
init=None,
length=n,
)
return result

@staticmethod
Expand Down Expand Up @@ -138,6 +138,6 @@ def __init__(
super().__init__(
name=name,
element_rv=DistributionalVariable(
name=f"{name}_element", distribution=dist.Normal(0, 1)
name="element", distribution=dist.Normal(0, 1)
).expand_by(element_shape),
)
6 changes: 2 additions & 4 deletions pyrenew/process/randomwalk.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ class constructor.
"""
super().__init__(
name=name,
fundamental_process=IIDRandomSequence(
name=f"{name}_iid_seq", element_rv=step_rv
),
fundamental_process=IIDRandomSequence(name="iid_seq", element_rv=step_rv),
differencing_order=1,
**kwargs,
)
Expand Down Expand Up @@ -85,7 +83,7 @@ def __init__(
super().__init__(
name=name,
step_rv=DistributionalVariable(
name=f"{name}_step", distribution=dist.Normal(0.0, 1.0)
name="step", distribution=dist.Normal(0.0, 1.0)
),
**kwargs,
)
17 changes: 6 additions & 11 deletions pyrenew/randomvariable/transformedvariable.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,14 @@ def __init__(
self.transforms = transforms
self.validate()

def sample(self, record: bool = False, **kwargs: object) -> tuple:
def sample(self, **kwargs: object) -> tuple:
"""
Sample method. Call self.base_rv.sample()
and then apply the transforms specified
in self.transforms.

Parameters
----------
record
Whether to record the value of the deterministic
RandomVariable. Defaults to False.
**kwargs
Keyword arguments passed to self.base_rv.sample()

Expand All @@ -79,10 +76,11 @@ def sample(self, record: bool = False, **kwargs: object) -> tuple:
t(uv) for t, uv in zip(self.transforms, untransformed_values)
)

if record:
if len(untransformed_values) == 1:
numpyro.deterministic(self.name, transformed_values)
else:
if len(transformed_values) == 1:
transformed_values = transformed_values[0]
numpyro.deterministic(self.name, transformed_values)
else:
with self.scope():
suffixes = (
untransformed_values._fields
if hasattr(untransformed_values, "_fields")
Expand All @@ -91,9 +89,6 @@ def sample(self, record: bool = False, **kwargs: object) -> tuple:
for suffix, tv in zip(suffixes, transformed_values):
numpyro.deterministic(f"{self.name}_{suffix}", tv)

if len(transformed_values) == 1:
transformed_values = transformed_values[0]

return transformed_values

def sample_length(self) -> int:
Expand Down