diff --git a/docs/tutorials/notebooks b/docs/tutorials/notebooks index 02859d31a..49ad4ef72 160000 --- a/docs/tutorials/notebooks +++ b/docs/tutorials/notebooks @@ -1 +1 @@ -Subproject commit 02859d31a0df0245d36af905d3eb3068a9965445 +Subproject commit 49ad4ef723f65ed56886c14bea242332f9244d8a diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index 719472612..bf70b3cc4 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -4,13 +4,16 @@ import logging import os.path import re +import sys import tempfile +import traceback import warnings from collections.abc import Generator, Mapping, Sequence from contextlib import contextmanager +from enum import Enum from functools import singledispatch from pathlib import Path -from typing import Any +from typing import Any, Literal import zarr from anndata import AnnData @@ -383,3 +386,59 @@ def save_transformations(sdata: SpatialData) -> None: stacklevel=2, ) sdata.write_transformations() + + +class BadFileHandleMethod(Enum): + ERROR = "error" + WARN = "warn" + + +@contextmanager +def handle_read_errors( + on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN], + location: str, + exc_types: tuple[type[Exception], ...], +) -> Generator[None, None, None]: + """ + Handle read errors according to parameter `on_bad_files`. + + Parameters + ---------- + on_bad_files + Specifies what to do upon encountering an exception. + Allowed values are: + + - 'error', let the exception be raised. + - 'warn', convert the exception into a warning if it is one of the expected exception types. + location + String identifying the function call where the exception happened. + exc_types + A tuple of expected exception classes that should be converted into warnings. + + Raises + ------ + If `on_bad_files="error"`, all encountered exceptions are raised. + If `on_bad_files="warn"`, any encountered exceptions not matching the `exc_types` are raised.
+ """ + on_bad_files = BadFileHandleMethod(on_bad_files) # str to enum + if on_bad_files == BadFileHandleMethod.WARN: + try: + yield + except exc_types as e: + # Extract the original filename and line number from the exception and + # create a warning from it. + exc_traceback = sys.exc_info()[-1] + last_frame, lineno = list(traceback.walk_tb(exc_traceback))[-1] + filename = last_frame.f_code.co_filename + # Include the location (element path) in the warning message. + message = f"{location}: {e.__class__.__name__}: {e.args[0]}" + warnings.warn_explicit( + message=message, + category=UserWarning, + filename=filename, + lineno=lineno, + ) + # continue + else: # on_bad_files == BadFileHandleMethod.ERROR + # Let it raise exceptions + yield diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 55136da12..023129d50 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -1,6 +1,8 @@ from __future__ import annotations import os +from json import JSONDecodeError +from typing import Literal import numpy as np import zarr @@ -8,14 +10,20 @@ from anndata import read_zarr as read_anndata_zarr from anndata._io.specs import write_elem as write_adata from ome_zarr.format import Format +from zarr.errors import ArrayNotFoundError +from spatialdata._io._utils import BadFileHandleMethod, handle_read_errors from spatialdata._io.format import CurrentTablesFormat, TablesFormats, _parse_version from spatialdata._logging import logger from spatialdata.models import TableModel def _read_table( - zarr_store_path: str, group: zarr.Group, subgroup: zarr.Group, tables: dict[str, AnnData] + zarr_store_path: str, + group: zarr.Group, + subgroup: zarr.Group, + tables: dict[str, AnnData], + on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, ) -> dict[str, AnnData]: """ Read in tables in the tables Zarr.group of a SpatialData Zarr store. 
@@ -30,6 +38,8 @@ def _read_table( The subgroup containing the tables. tables A dictionary of tables. + on_bad_files + Specifies what to do upon encountering a bad file, e.g. corrupted, invalid or missing files. Returns ------- @@ -40,33 +50,38 @@ def _read_table( f_elem = subgroup[table_name] f_elem_store = os.path.join(zarr_store_path, f_elem.path) - tables[table_name] = read_anndata_zarr(f_elem_store) + with handle_read_errors( + on_bad_files=on_bad_files, + location=f"{subgroup.path}/{table_name}", + exc_types=(JSONDecodeError, KeyError, ValueError, ArrayNotFoundError), + ): + tables[table_name] = read_anndata_zarr(f_elem_store) - f = zarr.open(f_elem_store, mode="r") - version = _parse_version(f, expect_attrs_key=False) - assert version is not None - # since have just one table format, we currently read it but do not use it; if we ever change the format - # we can rename the two _ to format and implement the per-format read logic (as we do for shapes) - _ = TablesFormats[version] - f.store.close() + f = zarr.open(f_elem_store, mode="r") + version = _parse_version(f, expect_attrs_key=False) + assert version is not None + # since have just one table format, we currently read it but do not use it; if we ever change the format + # we can rename the two _ to format and implement the per-format read logic (as we do for shapes) + _ = TablesFormats[version] + f.store.close() - # # replace with format from above - # version = "0.1" - # format = TablesFormats[version] - if TableModel.ATTRS_KEY in tables[table_name].uns: - # fill out eventual missing attributes that has been omitted because their value was None - attrs = tables[table_name].uns[TableModel.ATTRS_KEY] - if "region" not in attrs: - attrs["region"] = None - if "region_key" not in attrs: - attrs["region_key"] = None - if "instance_key" not in attrs: - attrs["instance_key"] = None - # fix type for region - if "region" in attrs and isinstance(attrs["region"], np.ndarray): - attrs["region"] = 
attrs["region"].tolist() + # # replace with format from above + # version = "0.1" + # format = TablesFormats[version] + if TableModel.ATTRS_KEY in tables[table_name].uns: + # fill in any missing attributes that have been omitted because their value was None + attrs = tables[table_name].uns[TableModel.ATTRS_KEY] + if "region" not in attrs: + attrs["region"] = None + if "region_key" not in attrs: + attrs["region_key"] = None + if "instance_key" not in attrs: + attrs["instance_key"] = None + # fix type for region + if "region" in attrs and isinstance(attrs["region"], np.ndarray): + attrs["region"] = attrs["region"].tolist() - count += 1 + count += 1 logger.debug(f"Found {count} elements in {subgroup}") return tables diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 0be7c8f4f..224ef1129 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -1,13 +1,17 @@ import logging import os import warnings +from json import JSONDecodeError from pathlib import Path +from typing import Literal import zarr from anndata import AnnData +from pyarrow import ArrowInvalid +from zarr.errors import ArrayNotFoundError, MetadataError from spatialdata._core.spatialdata import SpatialData -from spatialdata._io._utils import ome_zarr_logger +from spatialdata._io._utils import BadFileHandleMethod, handle_read_errors, ome_zarr_logger from spatialdata._io.io_points import _read_points from spatialdata._io.io_raster import _read_multiscale from spatialdata._io.io_shapes import _read_shapes @@ -36,7 +40,11 @@ def _open_zarr_store(store: str | Path | zarr.Group) -> tuple[zarr.Group, str]: return f, f_store_path -def read_zarr(store: str | Path | zarr.Group, selection: None | tuple[str] = None) -> SpatialData: +def read_zarr( + store: str | Path | zarr.Group, + selection: None | tuple[str] = None, + on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, +) -> SpatialData: """ Read a
SpatialData dataset from a zarr store (on-disk or remote). @@ -49,6 +57,16 @@ def read_zarr(store: str | Path | zarr.Group, selection: None | tuple[str] = Non List of elements to read from the zarr store (images, labels, points, shapes, table). If None, all elements are read. + on_bad_files + Specifies what to do upon encountering a bad file, e.g. corrupted, invalid or missing files. + Allowed values are : + + - 'error', raise an exception when a bad file is encountered. Reading aborts immediately + with an error. + - 'warn', raise a warning when a bad file is encountered and skip that file. A SpatialData + object is returned containing only elements that could be read. Failures can only be + determined from the warnings. + Returns ------- A SpatialData object. @@ -67,23 +85,12 @@ def read_zarr(store: str | Path | zarr.Group, selection: None | tuple[str] = Non # read multiscale images if "images" in selector and "images" in f: - group = f["images"] - count = 0 - for subgroup_name in group: - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - f_elem = group[subgroup_name] - f_elem_store = os.path.join(f_store_path, f_elem.path) - element = _read_multiscale(f_elem_store, raster_type="image") - images[subgroup_name] = element - count += 1 - logger.debug(f"Found {count} elements in {group}") - - # read multiscale labels - with ome_zarr_logger(logging.ERROR): - if "labels" in selector and "labels" in f: - group = f["labels"] + with handle_read_errors( + on_bad_files, + location="images", + exc_types=(JSONDecodeError, MetadataError), + ): + group = f["images"] count = 0 for subgroup_name in group: if Path(subgroup_name).name.startswith("."): @@ -91,39 +98,106 @@ def read_zarr(store: str | Path | zarr.Group, selection: None | tuple[str] = Non continue f_elem = group[subgroup_name] f_elem_store = os.path.join(f_store_path, f_elem.path) - labels[subgroup_name] = _read_multiscale(f_elem_store, raster_type="labels") - 
count += 1 + with handle_read_errors( + on_bad_files, + location=f"{group.path}/{subgroup_name}", + exc_types=( + JSONDecodeError, # JSON parse error + ValueError, # ome_zarr: Unable to read the NGFF file + KeyError, # Missing JSON key + ArrayNotFoundError, # Image chunks missing + TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 + ), + ): + element = _read_multiscale(f_elem_store, raster_type="image") + images[subgroup_name] = element + count += 1 logger.debug(f"Found {count} elements in {group}") + # read multiscale labels + with ome_zarr_logger(logging.ERROR): + if "labels" in selector and "labels" in f: + with handle_read_errors( + on_bad_files, + location="labels", + exc_types=(JSONDecodeError, MetadataError), + ): + group = f["labels"] + count = 0 + for subgroup_name in group: + if Path(subgroup_name).name.startswith("."): + # skip hidden files like .zgroup or .zmetadata + continue + f_elem = group[subgroup_name] + f_elem_store = os.path.join(f_store_path, f_elem.path) + with handle_read_errors( + on_bad_files, + location=f"{group.path}/{subgroup_name}", + exc_types=(JSONDecodeError, KeyError, ValueError, ArrayNotFoundError, TypeError), + ): + labels[subgroup_name] = _read_multiscale(f_elem_store, raster_type="labels") + count += 1 + logger.debug(f"Found {count} elements in {group}") + # now read rest of the data if "points" in selector and "points" in f: - group = f["points"] - count = 0 - for subgroup_name in group: - f_elem = group[subgroup_name] - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - f_elem_store = os.path.join(f_store_path, f_elem.path) - points[subgroup_name] = _read_points(f_elem_store) - count += 1 - logger.debug(f"Found {count} elements in {group}") + with handle_read_errors( + on_bad_files, + location="points", + exc_types=(JSONDecodeError, MetadataError), + ): + group = f["points"] + count = 0 + for subgroup_name in group: + f_elem = 
group[subgroup_name] + if Path(subgroup_name).name.startswith("."): + # skip hidden files like .zgroup or .zmetadata + continue + f_elem_store = os.path.join(f_store_path, f_elem.path) + with handle_read_errors( + on_bad_files, + location=f"{group.path}/{subgroup_name}", + exc_types=(JSONDecodeError, KeyError, ArrowInvalid), + ): + points[subgroup_name] = _read_points(f_elem_store) + count += 1 + logger.debug(f"Found {count} elements in {group}") if "shapes" in selector and "shapes" in f: - group = f["shapes"] - count = 0 - for subgroup_name in group: - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - f_elem = group[subgroup_name] - f_elem_store = os.path.join(f_store_path, f_elem.path) - shapes[subgroup_name] = _read_shapes(f_elem_store) - count += 1 - logger.debug(f"Found {count} elements in {group}") + with handle_read_errors( + on_bad_files, + location="shapes", + exc_types=(JSONDecodeError, MetadataError), + ): + group = f["shapes"] + count = 0 + for subgroup_name in group: + if Path(subgroup_name).name.startswith("."): + # skip hidden files like .zgroup or .zmetadata + continue + f_elem = group[subgroup_name] + f_elem_store = os.path.join(f_store_path, f_elem.path) + with handle_read_errors( + on_bad_files, + location=f"{group.path}/{subgroup_name}", + exc_types=( + JSONDecodeError, + ValueError, + KeyError, + ArrayNotFoundError, + ), + ): + shapes[subgroup_name] = _read_shapes(f_elem_store) + count += 1 + logger.debug(f"Found {count} elements in {group}") if "tables" in selector and "tables" in f: - group = f["tables"] - tables = _read_table(f_store_path, f, group, tables) + with handle_read_errors( + on_bad_files, + location="tables", + exc_types=(JSONDecodeError, MetadataError), + ): + group = f["tables"] + tables = _read_table(f_store_path, f, group, tables, on_bad_files=on_bad_files) if "table" in selector and "table" in f: warnings.warn( @@ -133,10 +207,15 @@ def read_zarr(store: str | Path | 
zarr.Group, selection: None | tuple[str] = Non stacklevel=2, ) subgroup_name = "table" - group = f[subgroup_name] - tables = _read_table(f_store_path, f, group, tables) + with handle_read_errors( + on_bad_files, + location=subgroup_name, + exc_types=(JSONDecodeError, MetadataError), + ): + group = f[subgroup_name] + tables = _read_table(f_store_path, f, group, tables, on_bad_files=on_bad_files) - logger.debug(f"Found {count} elements in {group}") + logger.debug(f"Found {count} elements in {group}") # read attrs metadata attrs = f.attrs.asdict() diff --git a/tests/io/test_partial_read.py b/tests/io/test_partial_read.py new file mode 100644 index 000000000..7c7cdbfa2 --- /dev/null +++ b/tests/io/test_partial_read.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +import json +import os +import re +import tempfile +from collections.abc import Generator, Iterable +from contextlib import contextmanager +from dataclasses import dataclass +from json import JSONDecodeError +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np +import py +import pytest +import zarr +from pyarrow import ArrowInvalid +from zarr.errors import ArrayNotFoundError, MetadataError + +from spatialdata import SpatialData, read_zarr +from spatialdata.datasets import blobs + +if TYPE_CHECKING: + import _pytest.fixtures + + +@contextmanager +def pytest_warns_multiple( + expected_warning: type[Warning] | tuple[type[Warning], ...] = Warning, matches: Iterable[str] = () +) -> Generator[None, None, None]: + """ + Assert that code raises warnings matching particular patterns. + + Like `pytest.warns`, but with multiple patterns, each of which must match a warning. + + Parameters + ---------- + expected_warning + A warning class or a tuple of warning classes for which at least one matching warning must be found. + matches + Regular expression patterns, each of which must be found in at least one warning message.
+ """ + if not matches: + yield + else: + with ( + pytest.warns(expected_warning, match=matches[0]), + pytest_warns_multiple(expected_warning, matches=matches[1:]), + ): + yield + + +@pytest.fixture(scope="module") +def test_case(request: _pytest.fixtures.SubRequest): + """ + Fixture that helps to use fixtures as arguments in parametrize. + + The fixture `test_case` takes up values from other fixture functions used as parameters. + """ + fixture_function = request.param + fixture_name = fixture_function.__name__ + return request.getfixturevalue(fixture_name) + + +@dataclass +class PartialReadTestCase: + path: Path + expected_elements: list[str] + expected_exceptions: type[Exception] | tuple[type[Exception], ...] + warnings_patterns: list[str] + + +@pytest.fixture(scope="session") +def session_tmp_path(request: _pytest.fixtures.SubRequest) -> Path: + """ + Create a temporary directory as a fixture with session scope and deletes it afterward. + + The default tmp_path fixture has function scope and cannot be used as input to session-scoped + fixtures. + """ + directory = py.path.local(tempfile.mkdtemp()) + request.addfinalizer(lambda: directory.remove(rec=1)) + return Path(directory) + + +@pytest.fixture(scope="module") +def sdata_with_corrupted_elem_type_zgroup(session_tmp_path: Path) -> PartialReadTestCase: + # .zattrs is a zero-byte file, aborted during write, or contains invalid JSON syntax + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_corrupted_top_level_zgroup.zarr" + sdata.write(sdata_path) + + (sdata_path / "images" / ".zgroup").unlink() # missing, not detected by reader. 
So it doesn't raise an exception, + # but it will not be found in the read SpatialData object + (sdata_path / "labels" / ".zgroup").write_text("") # corrupted + (sdata_path / "points" / ".zgroup").write_text("{}") # invalid + not_corrupted = [name for t, name, _ in sdata.gen_elements() if t not in ("images", "labels", "points")] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=(JSONDecodeError, MetadataError), + warnings_patterns=["labels: JSONDecodeError", "points: MetadataError"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_corrupted_zattrs(session_tmp_path: Path) -> PartialReadTestCase: + # .zattrs is a zero-byte file, aborted during write, or contains invalid JSON syntax + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_corrupted_zattrs.zarr" + sdata.write(sdata_path) + + corrupted_elements = ["blobs_image", "blobs_labels", "blobs_points", "blobs_polygons", "table"] + warnings_patterns = [] + for corrupted_element in corrupted_elements: + elem_path = sdata.locate_element(sdata[corrupted_element])[0] + (sdata_path / elem_path / ".zattrs").write_bytes(b"") + warnings_patterns.append(f"{elem_path}: JSONDecodeError") + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name not in corrupted_elements] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=JSONDecodeError, + warnings_patterns=warnings_patterns, + ) + + +@pytest.fixture(scope="module") +def sdata_with_corrupted_image_chunks(session_tmp_path: Path) -> PartialReadTestCase: + # images/blobs_image/0 is a zero-byte file or aborted during write + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_corrupted_image_chunks.zarr" + sdata.write(sdata_path) + + corrupted = "blobs_image" + os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray") # it will hide the "0" array from the Zarr reader + os.rename(sdata_path / "images" / corrupted / "0", 
sdata_path / "images" / corrupted / "0_corrupted") + (sdata_path / "images" / corrupted / "0").touch() + + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=( + ArrayNotFoundError, + TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 + ), + warnings_patterns=[rf"images/{corrupted}: (ArrayNotFoundError|TypeError)"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_corrupted_parquet(session_tmp_path: Path) -> PartialReadTestCase: + # points/blobs_points/0 is a zero-byte file or aborted during write + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_corrupted_parquet.zarr" + sdata.write(sdata_path) + + corrupted = "blobs_points" + os.rename( + sdata_path / "points" / corrupted / "points.parquet", + sdata_path / "points" / corrupted / "points_corrupted.parquet", + ) + (sdata_path / "points" / corrupted / "points.parquet").touch() + + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=ArrowInvalid, + warnings_patterns=[rf"points/{corrupted}: ArrowInvalid"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_missing_zattrs(session_tmp_path: Path) -> PartialReadTestCase: + # .zattrs is missing + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_missing_zattrs.zarr" + sdata.write(sdata_path) + + corrupted = "blobs_image" + (sdata_path / "images" / corrupted / ".zattrs").unlink() + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=ValueError, + warnings_patterns=[rf"images/{corrupted}: .* Unable to read the NGFF file"], + ) + + +@pytest.fixture(scope="module") +def 
sdata_with_missing_image_chunks( + session_tmp_path: Path, +) -> PartialReadTestCase: + # .zattrs exists, but refers to binary array chunks that do not exist + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_missing_image_chunks.zarr" + sdata.write(sdata_path) + + corrupted = "blobs_image" + os.unlink(sdata_path / "images" / corrupted / "0" / ".zarray") + os.rename(sdata_path / "images" / corrupted / "0", sdata_path / "images" / corrupted / "0_corrupted") + + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=( + ArrayNotFoundError, + TypeError, # instead of ArrayNotFoundError, with dask>=2024.10.0 zarr<=2.18.3 + ), + warnings_patterns=[rf"images/{corrupted}: (ArrayNotFoundError|TypeError)"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_invalid_zattrs_violating_spec(session_tmp_path: Path) -> PartialReadTestCase: + # .zattrs contains readable JSON which is not valid for SpatialData/NGFF specs + # for example due to a missing/misspelled/renamed key + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_invalid_zattrs_violating_spec.zarr" + sdata.write(sdata_path) + + corrupted = "blobs_image" + json_dict = json.loads((sdata_path / "images" / corrupted / ".zattrs").read_text()) + del json_dict["multiscales"][0]["coordinateTransformations"] + (sdata_path / "images" / corrupted / ".zattrs").write_text(json.dumps(json_dict, indent=4)) + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=KeyError, + warnings_patterns=[rf"images/{corrupted}: KeyError: coordinateTransformations"], + ) + + +@pytest.fixture(scope="module") +def sdata_with_invalid_zattrs_table_region_not_found(session_tmp_path: Path) -> PartialReadTestCase: + # table/table/.zarr referring to a region 
that is not found + # This has been emitting just a warning, but does not fail reading the table element. + sdata = blobs() + sdata_path = session_tmp_path / "sdata_with_invalid_zattrs_table_region_not_found.zarr" + sdata.write(sdata_path) + + corrupted = "blobs_labels" + # The element data is missing + os.unlink(sdata_path / "labels" / corrupted / ".zgroup") + os.rename(sdata_path / "labels" / corrupted, sdata_path / "labels" / f"{corrupted}_corrupted") + # But the labels element is referenced as a region in a table + regions = zarr.open_group(sdata_path / "tables" / "table" / "obs" / "region", mode="r") + assert corrupted in np.asarray(regions.categories)[regions.codes] + not_corrupted = [name for _, name, _ in sdata.gen_elements() if name != corrupted] + + return PartialReadTestCase( + path=sdata_path, + expected_elements=not_corrupted, + expected_exceptions=(), + warnings_patterns=[ + rf"The table is annotating '{re.escape(corrupted)}', which is not present in the SpatialData object" + ], + ) + + +@pytest.mark.parametrize( + "test_case", + [ + sdata_with_corrupted_zattrs, + sdata_with_corrupted_image_chunks, + sdata_with_corrupted_parquet, + sdata_with_missing_zattrs, + sdata_with_missing_image_chunks, + sdata_with_invalid_zattrs_violating_spec, + sdata_with_invalid_zattrs_table_region_not_found, + sdata_with_corrupted_elem_type_zgroup, + ], + indirect=True, +) +def test_read_zarr_with_error(test_case: PartialReadTestCase): + # The specific type of exception depends on the read function for the SpatialData element + if test_case.expected_exceptions: + with pytest.raises(test_case.expected_exceptions): + read_zarr(test_case.path, on_bad_files="error") + else: + read_zarr(test_case.path, on_bad_files="error") + + +@pytest.mark.parametrize( + "test_case", + [ + sdata_with_corrupted_zattrs, + sdata_with_corrupted_image_chunks, + sdata_with_corrupted_parquet, + sdata_with_missing_zattrs, + sdata_with_missing_image_chunks, + sdata_with_invalid_zattrs_violating_spec, 
+ sdata_with_invalid_zattrs_table_region_not_found, + sdata_with_corrupted_elem_type_zgroup, + ], + indirect=True, +) +def test_read_zarr_with_warnings(test_case: PartialReadTestCase): + with pytest_warns_multiple(UserWarning, matches=test_case.warnings_patterns): + actual: SpatialData = read_zarr(test_case.path, on_bad_files="warn") + + actual_elements = {name for _, name, _ in actual.gen_elements()} + assert set(test_case.expected_elements) == actual_elements diff --git a/tests/io/test_utils.py b/tests/io/test_utils.py index 31346528e..f9778f5c7 100644 --- a/tests/io/test_utils.py +++ b/tests/io/test_utils.py @@ -1,10 +1,12 @@ import os import tempfile +from contextlib import nullcontext import dask.dataframe as dd +import pytest from spatialdata import read_zarr -from spatialdata._io._utils import get_dask_backing_files +from spatialdata._io._utils import get_dask_backing_files, handle_read_errors def test_backing_files_points(points): @@ -118,3 +120,20 @@ def test_backing_files_combining_points_and_images(points, images): os.path.realpath(os.path.join(f1, "images/image2d")), ] assert set(files) == set(expected_zarr_locations_old) or set(files) == set(expected_zarr_locations_new) + + +@pytest.mark.parametrize( + ("on_bad_files", "actual_error", "expectation"), + [ + ("error", None, nullcontext()), + ("error", KeyError("key"), pytest.raises(KeyError)), + ("warn", None, nullcontext()), + ("warn", KeyError("key"), pytest.warns(UserWarning, match="location: KeyError")), + ("warn", RuntimeError("unhandled"), pytest.raises(RuntimeError)), + ], +) +def test_handle_read_errors(on_bad_files: str, actual_error: Exception, expectation): + with expectation: # noqa: SIM117 + with handle_read_errors(on_bad_files=on_bad_files, location="location", exc_types=KeyError): + if actual_error is not None: + raise actual_error