diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 01ae19f..c767c48 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -89,7 +89,7 @@ repos: # - id: cmake-format - repo: https://github.com/abravalheri/validate-pyproject - rev: "v0.23" + rev: "v0.25" hooks: - id: validate-pyproject additional_dependencies: ["validate-pyproject-schema-store[all]"] diff --git a/src/pscpy/psc.py b/src/pscpy/psc.py index fd3859c..a8d5ba3 100644 --- a/src/pscpy/psc.py +++ b/src/pscpy/psc.py @@ -37,6 +37,7 @@ def __init__( self.x = self._get_coord(0) self.y = self._get_coord(1) self.z = self._get_coord(2) + self.t = float(ds.attrs["time"]) def _get_coord(self, coord_idx: int) -> NDArray[Any]: return np.linspace( @@ -68,17 +69,17 @@ def decode_psc( length: ArrayLike | None = None, corner: ArrayLike | None = None, ) -> xr.Dataset: - da = ds[next(iter(ds))] # first dataset - if da.dims[0] == "dim_0_1": + dims = list(ds.dims) + if "dim_0_1" in dims: # for compatibility, if dimensions weren't saved as attribute in the .bp file, # fix them up here ds = ds.rename_dims( { - da.dims[0]: "step", - # dims[1] is the "component" dimension, which gets removed later - da.dims[2]: "z", - da.dims[3]: "y", - da.dims[4]: "x", + dims[0]: "step", + dims[1]: "component", + dims[2]: "z", + dims[3]: "y", + dims[4]: "x", } ) ds = ds.squeeze("step") @@ -86,7 +87,7 @@ def decode_psc( for var_name in ds: components = list(iter_components(var_name, species_names)) for component_idx, component in enumerate(components): - ds = ds.assign({component: ds[var_name][component_idx, :, :, :]}) + ds = ds.assign({component: ds[var_name].isel(component=component_idx)}) if var_name not in components: ds = ds.drop_vars([var_name]) @@ -95,6 +96,7 @@ def decode_psc( "x": ("x", run_info.x), "y": ("y", run_info.y), "z": ("z", run_info.z), + "t": run_info.t, } ds = ds.assign_coords(coords) diff --git a/tests/test_xarray_adios2.py b/tests/test_xarray_adios2.py index aafbe89..a5acce6 100644 --- a/tests/test_xarray_adios2.py +++ b/tests/test_xarray_adios2.py @@ -98,7 +98,7 @@ def ds_pfd_moments_decoded(ds_pfd_moments_raw) -> xr.Dataset: def test_open_dataset(ds_pfd_decoded): assert "jx_ec" in ds_pfd_decoded - assert ds_pfd_decoded.coords.keys() == set({"x", "y", "z"}) + assert ds_pfd_decoded.coords.keys() == set({"x", "y", "z", "t"}) assert ds_pfd_decoded.jx_ec.sizes == dict(x=1, y=128, z=512) # noqa: C408 assert np.allclose( ds_pfd_decoded.jx_ec.z.data, np.linspace(-25.6, 25.6, 512, endpoint=False).data