diff --git a/adit-client/adit_client/client.py b/adit-client/adit_client/client.py index 6ed2d9fdb..0414cd3dd 100644 --- a/adit-client/adit_client/client.py +++ b/adit-client/adit_client/client.py @@ -1,5 +1,7 @@ import importlib.metadata -from typing import Iterator +import os +from io import BytesIO +from typing import Iterator, Union from dicomweb_client import DICOMwebClient, session_utils from pydicom import Dataset @@ -189,6 +191,154 @@ def store_images(self, ae_title: str, images: list[Dataset]) -> Dataset: """Store images.""" return self._create_dicom_web_client(ae_title).store_instances(images) + def retrieve_nifti_study(self, ae_title: str, study_uid: str) -> list[tuple[str, BytesIO]]: + """Retrieve NIfTI files for a study.""" + url = f"{self.server_url}/api/dicom-web/{ae_title}/wadors/studies/{study_uid}/nifti" + dicomweb_client = self._create_dicom_web_client(ae_title) + response = dicomweb_client._http_get( + url, + headers={"Accept": "multipart/related; type=application/octet-stream"}, + stream=True, + ) + return list(self._iter_multipart_response(response, stream=False)) + + def iter_nifti_study(self, ae_title: str, study_uid: str) -> Iterator[tuple[str, BytesIO]]: + """Iterate over NIfTI files for a study.""" + url = f"{self.server_url}/api/dicom-web/{ae_title}/wadors/studies/{study_uid}/nifti" + dicomweb_client = self._create_dicom_web_client(ae_title) + response = dicomweb_client._http_get( + url, + headers={"Accept": "multipart/related; type=application/octet-stream"}, + stream=True, + ) + yield from self._iter_multipart_response(response, stream=True) + + def retrieve_nifti_series( + self, ae_title: str, study_uid: str, series_uid: str + ) -> list[tuple[str, BytesIO]]: + """Retrieve NIfTI files for a series.""" + url = ( + f"{self.server_url}/api/dicom-web/{ae_title}/wadors/studies/{study_uid}/" + f"series/{series_uid}/nifti" + ) + dicomweb_client = self._create_dicom_web_client(ae_title) + response = dicomweb_client._http_get( + url, + 
headers={"Accept": "multipart/related; type=application/octet-stream"}, + stream=True, + ) + return list(self._iter_multipart_response(response, stream=False)) + + def iter_nifti_series( + self, ae_title: str, study_uid: str, series_uid: str + ) -> Iterator[tuple[str, BytesIO]]: + """Iterate over NIfTI files for a series.""" + url = ( + f"{self.server_url}/api/dicom-web/{ae_title}/wadors/studies/{study_uid}/" + f"series/{series_uid}/nifti" + ) + dicomweb_client = self._create_dicom_web_client(ae_title) + response = dicomweb_client._http_get( + url, + headers={"Accept": "multipart/related; type=application/octet-stream"}, + stream=True, + ) + yield from self._iter_multipart_response(response, stream=True) + + def retrieve_nifti_image( + self, ae_title: str, study_uid: str, series_uid: str, image_uid: str + ) -> list[tuple[str, BytesIO]]: + """Retrieve NIfTI files for a single image.""" + url = ( + f"{self.server_url}/api/dicom-web/{ae_title}/wadors/studies/{study_uid}/" + f"series/{series_uid}/instances/{image_uid}/nifti" + ) + dicomweb_client = self._create_dicom_web_client(ae_title) + response = dicomweb_client._http_get( + url, + headers={"Accept": "multipart/related; type=application/octet-stream"}, + stream=True, + ) + return list(self._iter_multipart_response(response, stream=False)) + + def iter_nifti_image( + self, ae_title: str, study_uid: str, series_uid: str, image_uid: str + ) -> Iterator[tuple[str, BytesIO]]: + """Iterate over NIfTI files for a single image.""" + url = ( + f"{self.server_url}/api/dicom-web/{ae_title}/wadors/studies/{study_uid}/" + f"series/{series_uid}/instances/{image_uid}/nifti" + ) + dicomweb_client = self._create_dicom_web_client(ae_title) + response = dicomweb_client._http_get( + url, + headers={"Accept": "multipart/related; type=application/octet-stream"}, + stream=True, + ) + yield from self._iter_multipart_response(response, stream=True) + + def _extract_filename(self, content_disposition: str | None) -> str: + """Extract 
filename from Content-Disposition header.""" + if not content_disposition or "filename=" not in content_disposition: + raise ValueError("No filename found in Content-Disposition header") + filename = content_disposition.split("filename=")[1].strip('"') + filename = os.path.basename(filename) + if not filename: + raise ValueError("Content-Disposition filename resolved to empty string") + return filename + + def _extract_part_content_with_headers(self, part: bytes) -> Union[bytes, None]: + """Extract content from a multipart part, keeping headers intact. + + Used to patch DICOMwebClient's _extract_part_content to allow access + to per-part headers (especially Content-Disposition for filenames). + """ + if part in (b"", b"--", b"\r\n") or part.startswith(b"--\r\n"): + return None + return part + + # NOTE: This method monkey-patches DICOMwebClient._extract_part_content and + # uses DICOMwebClient._decode_multipart_message — both private APIs. + # These methods are not part of the public interface and may change without + # warning. If dicomweb-client is upgraded, verify that these private methods + # still exist and behave the same way. 
+ def _iter_multipart_response(self, response, stream=False) -> Iterator[tuple[str, BytesIO]]: + """Parse a multipart response, yielding (filename, content) tuples.""" + dicomweb_client = self._create_dicom_web_client("") + original_extract_method = dicomweb_client._extract_part_content + + try: + dicomweb_client._extract_part_content = self._extract_part_content_with_headers + + for part in dicomweb_client._decode_multipart_message(response, stream=stream): + headers = {} + content = part + + idx = part.find(b"\r\n\r\n") + if idx > -1: + headers_bytes = part[:idx] + content = part[idx + 4 :] + + for header_line in headers_bytes.split(b"\r\n"): + if header_line and b":" in header_line: + name, value = header_line.split(b":", 1) + headers[name.decode("utf-8").strip()] = value.decode("utf-8").strip() + + content_disposition = headers.get("Content-Disposition") + if content_disposition: + filename = self._extract_filename(content_disposition) + else: + for header, value in response.headers.items(): + if header.lower() == "content-disposition": + filename = self._extract_filename(value) + break + else: + raise ValueError("No Content-Disposition header found in response") + + yield (filename, BytesIO(content)) + finally: + dicomweb_client._extract_part_content = original_extract_method + def _create_dicom_web_client(self, ae_title: str) -> DICOMwebClient: session = session_utils.create_session() diff --git a/adit-client/tests/test_nifti_client.py b/adit-client/tests/test_nifti_client.py new file mode 100644 index 000000000..75f115c8b --- /dev/null +++ b/adit-client/tests/test_nifti_client.py @@ -0,0 +1,122 @@ +from unittest.mock import MagicMock, patch + +import pytest +from adit_client.client import AditClient + + +class TestExtractFilename: + def test_valid_filename(self): + client = AditClient("http://localhost", "token") + result = client._extract_filename('attachment; filename="scan.nii.gz"') + assert result == "scan.nii.gz" + + def test_filename_with_path(self): + 
client = AditClient("http://localhost", "token") + result = client._extract_filename('attachment; filename="path/to/scan.nii.gz"') + assert result == "scan.nii.gz" + + def test_missing_header(self): + client = AditClient("http://localhost", "token") + with pytest.raises(ValueError, match="No filename found"): + client._extract_filename(None) + + def test_no_filename_field(self): + client = AditClient("http://localhost", "token") + with pytest.raises(ValueError, match="No filename found"): + client._extract_filename("attachment") + + +class TestExtractPartContentWithHeaders: + def test_empty_bytes(self): + client = AditClient("http://localhost", "token") + assert client._extract_part_content_with_headers(b"") is None + + def test_boundary_marker(self): + client = AditClient("http://localhost", "token") + assert client._extract_part_content_with_headers(b"--") is None + + def test_crlf(self): + client = AditClient("http://localhost", "token") + assert client._extract_part_content_with_headers(b"\r\n") is None + + def test_boundary_with_crlf(self): + client = AditClient("http://localhost", "token") + assert client._extract_part_content_with_headers(b"--\r\n") is None + + def test_normal_content(self): + client = AditClient("http://localhost", "token") + part = b"Content-Type: application/octet-stream\r\n\r\ndata" + assert client._extract_part_content_with_headers(part) == part + + +class TestIterMultipartResponse: + def test_parses_parts_with_content_disposition(self): + client = AditClient("http://localhost", "token") + + # Create a fake part with headers + content separated by \r\n\r\n + part = ( + b"Content-Type: application/octet-stream\r\n" + b'Content-Disposition: attachment; filename="scan.nii.gz"\r\n' + b"\r\n" + b"nifti content" + ) + + fake_dicomweb_client = MagicMock() + fake_dicomweb_client._decode_multipart_message.return_value = [part] + # Let _extract_part_content return part as-is (our patched method) + fake_dicomweb_client._extract_part_content = 
client._extract_part_content_with_headers + + response = MagicMock() + response.headers = {} + + with patch.object(client, "_create_dicom_web_client", return_value=fake_dicomweb_client): + results = list(client._iter_multipart_response(response, stream=False)) + + assert len(results) == 1 + assert results[0][0] == "scan.nii.gz" + assert results[0][1].read() == b"nifti content" + + def test_falls_back_to_response_headers(self): + client = AditClient("http://localhost", "token") + + # Part without Content-Disposition (no \r\n\r\n separator means no headers parsed) + part = b"just raw content without headers" + + fake_dicomweb_client = MagicMock() + fake_dicomweb_client._decode_multipart_message.return_value = [part] + fake_dicomweb_client._extract_part_content = client._extract_part_content_with_headers + + fake_headers = MagicMock() + fake_headers.items.return_value = [ + ("content-disposition", 'attachment; filename="fallback.nii.gz"') + ] + fake_headers.get.return_value = None + + response = MagicMock() + response.headers = fake_headers + + with patch.object(client, "_create_dicom_web_client", return_value=fake_dicomweb_client): + results = list(client._iter_multipart_response(response, stream=False)) + + assert len(results) == 1 + assert results[0][0] == "fallback.nii.gz" + + def test_no_disposition_anywhere_raises(self): + client = AditClient("http://localhost", "token") + + part = b"content without any disposition" + + fake_dicomweb_client = MagicMock() + fake_dicomweb_client._decode_multipart_message.return_value = [part] + fake_dicomweb_client._extract_part_content = client._extract_part_content_with_headers + + fake_headers = MagicMock() + fake_headers.items.return_value = [] + fake_headers.get.return_value = None + + response = MagicMock() + response.headers = fake_headers + + with patch.object(client, "_create_dicom_web_client", return_value=fake_dicomweb_client): + with pytest.raises(ValueError, match="No Content-Disposition"): + 
list(client._iter_multipart_response(response, stream=False)) diff --git a/adit/core/errors.py b/adit/core/errors.py index 8d0c4fa79..2508cad68 100644 --- a/adit/core/errors.py +++ b/adit/core/errors.py @@ -49,6 +49,48 @@ def is_retriable_http_status(status_code: int) -> bool: return status_code in retriable_status_codes +class DcmToNiftiConversionError(Exception): + """Base exception for DICOM to NIfTI conversion errors.""" + + pass + + +class NoValidDicomError(DcmToNiftiConversionError): + """Exception raised when no valid DICOM files are found.""" + + pass + + +class InvalidDicomError(DcmToNiftiConversionError): + """Exception raised when DICOM files are invalid or corrupt.""" + + pass + + +class OutputDirectoryError(DcmToNiftiConversionError): + """Exception raised when there are issues with the output directory.""" + + pass + + +class InputDirectoryError(DcmToNiftiConversionError): + """Exception raised when there are issues with the input directory.""" + + pass + + +class ExternalToolError(DcmToNiftiConversionError): + """Exception raised when there are issues with the external dcm2niix tool.""" + + pass + + +class NoSpatialDataError(DcmToNiftiConversionError): + """Exception raised when DICOM data doesn't have spatial attributes.""" + + pass + + class BatchFileSizeError(Exception): def __init__(self, batch_tasks_count: int, max_batch_size: int) -> None: super().__init__("Too many batch tasks.") diff --git a/adit/core/tests/utils/test_dicom_to_nifti_converter.py b/adit/core/tests/utils/test_dicom_to_nifti_converter.py new file mode 100644 index 000000000..0149ab838 --- /dev/null +++ b/adit/core/tests/utils/test_dicom_to_nifti_converter.py @@ -0,0 +1,199 @@ +import logging +import subprocess +from unittest.mock import MagicMock + +import pytest + +from adit.core.errors import ( + DcmToNiftiConversionError, + ExternalToolError, + InputDirectoryError, + InvalidDicomError, + NoValidDicomError, + OutputDirectoryError, +) +from 
adit.core.utils.dicom_to_nifti_converter import DicomToNiftiConverter + +CONVERTER_LOGGER = "adit.core.utils.dicom_to_nifti_converter" + + +@pytest.fixture(autouse=True) +def _enable_log_propagation(): + """Enable propagation on the adit logger so caplog can capture log messages.""" + adit_logger = logging.getLogger("adit") + original = adit_logger.propagate + adit_logger.propagate = True + yield + adit_logger.propagate = original + + +def _make_completed_process(returncode: int, stdout: str = "", stderr: str = ""): + mock = MagicMock(spec=subprocess.CompletedProcess) + mock.returncode = returncode + mock.stdout = stdout.encode("utf-8") + mock.stderr = stderr.encode("utf-8") + return mock + + +class TestDicomToNiftiConverter: + def test_convert_success(self, tmp_path, monkeypatch): + dicom_folder = tmp_path / "dicom" + dicom_folder.mkdir() + output_folder = tmp_path / "output" + + def fake_run(*args, **kwargs): + output_folder.mkdir(parents=True, exist_ok=True) + (output_folder / "output.nii").touch() + return _make_completed_process(0) + + monkeypatch.setattr(subprocess, "run", fake_run) + + converter = DicomToNiftiConverter() + converter.convert(dicom_folder, output_folder) + + def test_convert_no_dicom_found(self, tmp_path, monkeypatch): + dicom_folder = tmp_path / "dicom" + dicom_folder.mkdir() + output_folder = tmp_path / "output" + + monkeypatch.setattr(subprocess, "run", lambda *args, **kwargs: _make_completed_process(2)) + + converter = DicomToNiftiConverter() + with pytest.raises(NoValidDicomError, match="No DICOM images found"): + converter.convert(dicom_folder, output_folder) + + def test_convert_corrupt_dicom(self, tmp_path, monkeypatch): + dicom_folder = tmp_path / "dicom" + dicom_folder.mkdir() + output_folder = tmp_path / "output" + + monkeypatch.setattr(subprocess, "run", lambda *args, **kwargs: _make_completed_process(4)) + + converter = DicomToNiftiConverter() + with pytest.raises(InvalidDicomError, match="Corrupt DICOM"): + 
converter.convert(dicom_folder, output_folder) + + def test_convert_invalid_input_folder(self, tmp_path, monkeypatch): + dicom_folder = tmp_path / "dicom" + dicom_folder.mkdir() + output_folder = tmp_path / "output" + + monkeypatch.setattr(subprocess, "run", lambda *args, **kwargs: _make_completed_process(5)) + + converter = DicomToNiftiConverter() + with pytest.raises(InputDirectoryError, match="Input folder invalid"): + converter.convert(dicom_folder, output_folder) + + def test_convert_invalid_output_folder(self, tmp_path, monkeypatch): + dicom_folder = tmp_path / "dicom" + dicom_folder.mkdir() + output_folder = tmp_path / "output" + + monkeypatch.setattr(subprocess, "run", lambda *args, **kwargs: _make_completed_process(6)) + + converter = DicomToNiftiConverter() + with pytest.raises(OutputDirectoryError, match="Output folder invalid"): + converter.convert(dicom_folder, output_folder) + + def test_convert_write_permission_error(self, tmp_path, monkeypatch): + dicom_folder = tmp_path / "dicom" + dicom_folder.mkdir() + output_folder = tmp_path / "output" + + monkeypatch.setattr(subprocess, "run", lambda *args, **kwargs: _make_completed_process(7)) + + converter = DicomToNiftiConverter() + with pytest.raises(OutputDirectoryError, match="Unable to write"): + converter.convert(dicom_folder, output_folder) + + def test_convert_partial_conversion(self, tmp_path, monkeypatch, caplog): + dicom_folder = tmp_path / "dicom" + dicom_folder.mkdir() + output_folder = tmp_path / "output" + + monkeypatch.setattr(subprocess, "run", lambda *args, **kwargs: _make_completed_process(8)) + + converter = DicomToNiftiConverter() + with caplog.at_level(logging.WARNING, logger=CONVERTER_LOGGER): + converter.convert(dicom_folder, output_folder) + + assert any("Converted some but not all" in msg for msg in caplog.messages) + + def test_convert_rename_error(self, tmp_path, monkeypatch): + dicom_folder = tmp_path / "dicom" + dicom_folder.mkdir() + output_folder = tmp_path / "output" + + 
monkeypatch.setattr(subprocess, "run", lambda *args, **kwargs: _make_completed_process(9)) + + converter = DicomToNiftiConverter() + with pytest.raises(DcmToNiftiConversionError, match="Unable to rename"): + converter.convert(dicom_folder, output_folder) + + def test_convert_unspecified_error(self, tmp_path, monkeypatch): + dicom_folder = tmp_path / "dicom" + dicom_folder.mkdir() + output_folder = tmp_path / "output" + + monkeypatch.setattr(subprocess, "run", lambda *args, **kwargs: _make_completed_process(1)) + + converter = DicomToNiftiConverter() + with pytest.raises(DcmToNiftiConversionError, match="Unspecified error"): + converter.convert(dicom_folder, output_folder) + + def test_convert_subprocess_error(self, tmp_path, monkeypatch): + dicom_folder = tmp_path / "dicom" + dicom_folder.mkdir() + output_folder = tmp_path / "output" + + def raise_subprocess_error(*args, **kwargs): + raise subprocess.SubprocessError("dcm2niix not found") + + monkeypatch.setattr(subprocess, "run", raise_subprocess_error) + + converter = DicomToNiftiConverter() + with pytest.raises(ExternalToolError, match="Failed to execute dcm2niix"): + converter.convert(dicom_folder, output_folder) + + def test_convert_nonexistent_dicom_folder(self, tmp_path): + dicom_folder = tmp_path / "nonexistent" + output_folder = tmp_path / "output" + + converter = DicomToNiftiConverter() + with pytest.raises(ValueError, match="does not exist"): + converter.convert(dicom_folder, output_folder) + + def test_convert_creates_output_folder_if_missing(self, tmp_path, monkeypatch): + dicom_folder = tmp_path / "dicom" + dicom_folder.mkdir() + output_folder = tmp_path / "output" / "nested" + + def fake_run(*args, **kwargs): + output_folder.mkdir(parents=True, exist_ok=True) + (output_folder / "output.nii").touch() + return _make_completed_process(0) + + monkeypatch.setattr(subprocess, "run", fake_run) + + converter = DicomToNiftiConverter() + converter.convert(dicom_folder, output_folder) + + assert 
output_folder.exists() + + def test_convert_logs_warning_on_dcm2niix_warnings(self, tmp_path, monkeypatch, caplog): + dicom_folder = tmp_path / "dicom" + dicom_folder.mkdir() + output_folder = tmp_path / "output" + + def fake_run(*args, **kwargs): + output_folder.mkdir(parents=True, exist_ok=True) + (output_folder / "output.nii").touch() + return _make_completed_process(0, stderr="Warning: some issue detected") + + monkeypatch.setattr(subprocess, "run", fake_run) + + converter = DicomToNiftiConverter() + with caplog.at_level(logging.WARNING, logger=CONVERTER_LOGGER): + converter.convert(dicom_folder, output_folder) + + assert any("Warnings during conversion" in msg for msg in caplog.messages) diff --git a/adit/core/utils/dicom_to_nifti_converter.py b/adit/core/utils/dicom_to_nifti_converter.py index d383b9fdc..7cf446e92 100644 --- a/adit/core/utils/dicom_to_nifti_converter.py +++ b/adit/core/utils/dicom_to_nifti_converter.py @@ -1,10 +1,36 @@ import logging import subprocess +from enum import IntEnum from pathlib import Path +from adit.core.errors import ( + DcmToNiftiConversionError, + ExternalToolError, + InputDirectoryError, + InvalidDicomError, + NoSpatialDataError, + NoValidDicomError, + OutputDirectoryError, +) + logger = logging.getLogger(__name__) +class DcmToNiftiExitCode(IntEnum): + """Exit codes for dcm2niix as documented in https://github.com/rordenlab/dcm2niix""" + + SUCCESS = 0 + UNSPECIFIED_ERROR = 1 + NO_DICOM_FOUND = 2 + VERSION_REPORT = 3 + CORRUPT_DICOM = 4 + INVALID_INPUT_FOLDER = 5 + INVALID_OUTPUT_FOLDER = 6 + WRITE_PERMISSION_ERROR = 7 + PARTIAL_CONVERSION = 8 + RENAME_ERROR = 9 + + class DicomToNiftiConverter: def __init__(self, dcm2niix_path: str = "dcm2niix"): """Initialize the converter with the path to the dcm2niix executable. @@ -22,7 +48,14 @@ def convert(self, dicom_folder: str | Path, output_folder: str | Path) -> None: dicom_folder: Path to the folder containing DICOM files. 
output_folder: Path to the folder where NIfTI files will be saved. Raises: - RuntimeError: If the conversion fails. + ValueError: If the dicom_folder doesn't exist. + NoValidDicomError: If no valid DICOM files are found. + NoSpatialDataError: If conversion succeeds but produces no NIfTI output. + InvalidDicomError: If DICOM files are invalid or corrupt. + OutputDirectoryError: If there are issues with the output directory. + InputDirectoryError: If there are issues with the input directory. + ExternalToolError: If there are issues with the dcm2niix tool. + DcmToNiftiConversionError: For other conversion errors. """ dicom_folder = Path(dicom_folder) output_folder = Path(output_folder) @@ -45,9 +78,49 @@ def convert(self, dicom_folder: str | Path, output_folder: str | Path) -> None: ] try: - subprocess.run(cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - except subprocess.CalledProcessError as e: - raise RuntimeError(f"Failed to convert DICOM to NIfTI: {e.stderr.decode('utf-8')}") + result = subprocess.run( + cmd, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + stderr = result.stderr.decode("utf-8") + stdout = result.stdout.decode("utf-8") + + if "Warning:" in stderr or "Warning:" in stdout: + logger.warning(f"Warnings during conversion: {stderr}\n{stdout}") + + exit_code = result.returncode + error_msg = f"{stderr}\n{stdout}".strip() + + if exit_code == DcmToNiftiExitCode.SUCCESS: + if not any(output_folder.glob("*.nii*")): + raise NoSpatialDataError( + "Conversion succeeded but produced no NIfTI files. " + "DICOM data may lack spatial attributes." 
+ ) + elif exit_code == DcmToNiftiExitCode.NO_DICOM_FOUND: + raise NoValidDicomError(f"No DICOM images found in input folder: {error_msg}") + elif exit_code == DcmToNiftiExitCode.VERSION_REPORT: + logger.info(f"dcm2niix version report: {error_msg}") + elif exit_code == DcmToNiftiExitCode.CORRUPT_DICOM: + raise InvalidDicomError(f"Corrupt DICOM file: {error_msg}") + elif exit_code == DcmToNiftiExitCode.INVALID_INPUT_FOLDER: + raise InputDirectoryError(f"Input folder invalid: {error_msg}") + elif exit_code == DcmToNiftiExitCode.INVALID_OUTPUT_FOLDER: + raise OutputDirectoryError(f"Output folder invalid: {error_msg}") + elif exit_code == DcmToNiftiExitCode.WRITE_PERMISSION_ERROR: + raise OutputDirectoryError( + f"Unable to write to output folder (check permissions): {error_msg}" + ) + elif exit_code == DcmToNiftiExitCode.PARTIAL_CONVERSION: + logger.warning(f"Converted some but not all input DICOMs: {error_msg}") + elif exit_code == DcmToNiftiExitCode.RENAME_ERROR: + raise DcmToNiftiConversionError(f"Unable to rename files: {error_msg}") + else: + raise DcmToNiftiConversionError( + f"Unspecified error (exit code {exit_code}): {error_msg}" + ) + + except subprocess.SubprocessError as e: + raise ExternalToolError(f"Failed to execute dcm2niix: {e}") logger.debug( f"DICOM files in {dicom_folder} successfully converted to NIfTI format " diff --git a/adit/dicom_web/renderers.py b/adit/dicom_web/renderers.py index 305d87c39..aece41d71 100644 --- a/adit/dicom_web/renderers.py +++ b/adit/dicom_web/renderers.py @@ -87,6 +87,39 @@ def _end_stream(self) -> bytes: return stream.getvalue() +class WadoMultipartApplicationNiftiRenderer(DicomWebWadoRenderer): + media_type = "multipart/related; type=application/octet-stream" + format = "multipart" + subtype: str = "application/octet-stream" + boundary: str = "nifti-boundary" + charset: str = "utf-8" + + @property + def content_type(self) -> str: + return f"{self.media_type}; boundary={self.boundary}" + + def render(self, images: 
AsyncIterator[tuple[str, BytesIO]]) -> AsyncIterator[bytes]: + async def streaming_content(): + first_part = True + async for filename, file_content in images: + if first_part: + yield f"--{self.boundary}\r\n".encode() + first_part = False + else: + yield f"\r\n--{self.boundary}\r\n".encode() + + yield "Content-Type: application/octet-stream\r\n".encode() + safe_filename = "".join(c for c in filename if c.isprintable() and c != '"')[:255] + disposition = f'Content-Disposition: attachment; filename="{safe_filename}"' + yield f"{disposition}\r\n\r\n".encode() + + yield file_content.getvalue() + + yield f"\r\n--{self.boundary}--\r\n".encode() + + return streaming_content() + + class WadoApplicationDicomJsonRenderer(DicomWebWadoRenderer): media_type = "application/dicom+json" format = "json" diff --git a/adit/dicom_web/tests/acceptance/test_wadors.py b/adit/dicom_web/tests/acceptance/test_wadors.py index d2bdea059..525e36310 100644 --- a/adit/dicom_web/tests/acceptance/test_wadors.py +++ b/adit/dicom_web/tests/acceptance/test_wadors.py @@ -1,5 +1,8 @@ from http import HTTPStatus +from io import BytesIO +from typing import cast +import pandas as pd import pydicom import pytest from adit_client import AditClient @@ -579,3 +582,178 @@ def test_retrieve_image_metadata_with_invalid_pseudonym(channels_live_server: Ch assert response.status_code == HTTPStatus.BAD_REQUEST error = response.json() assert "pseudonym" in error or "invalid" in str(error).lower(), f"Unexpected error: {error}" + + +@pytest.mark.acceptance +@pytest.mark.order("last") +@pytest.mark.django_db(transaction=True) +def test_retrieve_nifti_study(channels_live_server: ChannelsLiveServer): + setup_dimse_orthancs() + + _, group, token = create_user_with_dicom_web_group_and_token() + server = DicomServer.objects.get(ae_title="ORTHANC1") + grant_access(group, server, source=True) + adit_client = AditClient(server_url=channels_live_server.url, auth_token=token) + + metadata = load_sample_dicoms_metadata("1001") + 
study_uid: str = metadata["StudyInstanceUID"].iloc[0] + + results = adit_client.retrieve_nifti_study(server.ae_title, study_uid) + + assert len(results) > 0, "Expected at least one NIfTI file" + + filenames = [filename for filename, _ in results] + nifti_files = [f for f in filenames if f.endswith(".nii.gz") or f.endswith(".nii")] + json_files = [f for f in filenames if f.endswith(".json")] + + assert len(nifti_files) > 0, "Expected at least one .nii.gz or .nii file" + assert len(json_files) > 0, "Expected at least one .json sidecar file" + + for filename, content in results: + assert isinstance(content, BytesIO) + data = content.read() + assert len(data) > 0, f"File {filename} should not be empty" + + +@pytest.mark.acceptance +@pytest.mark.order("last") +@pytest.mark.django_db(transaction=True) +def test_retrieve_nifti_series(channels_live_server: ChannelsLiveServer): + setup_dimse_orthancs() + + _, group, token = create_user_with_dicom_web_group_and_token() + server = DicomServer.objects.get(ae_title="ORTHANC1") + grant_access(group, server, source=True) + adit_client = AditClient(server_url=channels_live_server.url, auth_token=token) + + metadata = load_sample_dicoms_metadata("1001") + image_metadata = cast(pd.DataFrame, metadata[~metadata["Modality"].isin(["SR", "KO", "PR"])]) + study_uid: str = image_metadata["StudyInstanceUID"].iloc[0] + series_uid: str = image_metadata["SeriesInstanceUID"].iloc[0] + + results = adit_client.retrieve_nifti_series(server.ae_title, study_uid, series_uid) + + assert len(results) > 0, "Expected at least one NIfTI file" + + filenames = [filename for filename, _ in results] + nifti_files = [f for f in filenames if f.endswith(".nii.gz") or f.endswith(".nii")] + json_files = [f for f in filenames if f.endswith(".json")] + + assert len(nifti_files) > 0, "Expected at least one .nii.gz or .nii file" + assert len(json_files) > 0, "Expected at least one .json sidecar file" + + for filename, content in results: + assert isinstance(content, 
BytesIO) + data = content.read() + assert len(data) > 0, f"File {filename} should not be empty" + + +@pytest.mark.acceptance +@pytest.mark.order("last") +@pytest.mark.django_db(transaction=True) +def test_iter_nifti_study(channels_live_server: ChannelsLiveServer): + setup_dimse_orthancs() + + _, group, token = create_user_with_dicom_web_group_and_token() + server = DicomServer.objects.get(ae_title="ORTHANC1") + grant_access(group, server, source=True) + adit_client = AditClient(server_url=channels_live_server.url, auth_token=token) + + metadata = load_sample_dicoms_metadata("1001") + study_uid: str = metadata["StudyInstanceUID"].iloc[0] + + retrieved = adit_client.retrieve_nifti_study(server.ae_title, study_uid) + iterated = list(adit_client.iter_nifti_study(server.ae_title, study_uid)) + + assert len(iterated) == len(retrieved), ( + "iter and retrieve should return the same number of files" + ) + for (r_name, _), (i_name, _) in zip(retrieved, iterated): + assert r_name == i_name, f"Filenames should match: {r_name} != {i_name}" + + +@pytest.mark.acceptance +@pytest.mark.order("last") +@pytest.mark.django_db(transaction=True) +def test_iter_nifti_series(channels_live_server: ChannelsLiveServer): + setup_dimse_orthancs() + + _, group, token = create_user_with_dicom_web_group_and_token() + server = DicomServer.objects.get(ae_title="ORTHANC1") + grant_access(group, server, source=True) + adit_client = AditClient(server_url=channels_live_server.url, auth_token=token) + + metadata = load_sample_dicoms_metadata("1001") + image_metadata = cast(pd.DataFrame, metadata[~metadata["Modality"].isin(["SR", "KO", "PR"])]) + study_uid: str = image_metadata["StudyInstanceUID"].iloc[0] + series_uid: str = image_metadata["SeriesInstanceUID"].iloc[0] + + retrieved = adit_client.retrieve_nifti_series(server.ae_title, study_uid, series_uid) + iterated = list(adit_client.iter_nifti_series(server.ae_title, study_uid, series_uid)) + + assert len(iterated) == len(retrieved), ( + "iter and 
retrieve should return the same number of files" + ) + for (r_name, _), (i_name, _) in zip(retrieved, iterated): + assert r_name == i_name, f"Filenames should match: {r_name} != {i_name}" + + +@pytest.mark.acceptance +@pytest.mark.order("last") +@pytest.mark.django_db(transaction=True) +def test_retrieve_nifti_image(channels_live_server: ChannelsLiveServer): + setup_dimse_orthancs() + + _, group, token = create_user_with_dicom_web_group_and_token() + server = DicomServer.objects.get(ae_title="ORTHANC1") + grant_access(group, server, source=True) + adit_client = AditClient(server_url=channels_live_server.url, auth_token=token) + + metadata = load_sample_dicoms_metadata("1001") + image_metadata = cast(pd.DataFrame, metadata[~metadata["Modality"].isin(["SR", "KO", "PR"])]) + study_uid: str = image_metadata["StudyInstanceUID"].iloc[0] + series_uid: str = image_metadata["SeriesInstanceUID"].iloc[0] + image_uid: str = image_metadata["SOPInstanceUID"].iloc[0] + + results = adit_client.retrieve_nifti_image(server.ae_title, study_uid, series_uid, image_uid) + + assert len(results) > 0, "Expected at least one NIfTI file" + + filenames = [filename for filename, _ in results] + nifti_files = [f for f in filenames if f.endswith(".nii.gz") or f.endswith(".nii")] + json_files = [f for f in filenames if f.endswith(".json")] + + assert len(nifti_files) > 0, "Expected at least one .nii.gz or .nii file" + assert len(json_files) > 0, "Expected at least one .json sidecar file" + + for filename, content in results: + assert isinstance(content, BytesIO) + data = content.read() + assert len(data) > 0, f"File {filename} should not be empty" + + +@pytest.mark.acceptance +@pytest.mark.order("last") +@pytest.mark.django_db(transaction=True) +def test_iter_nifti_image(channels_live_server: ChannelsLiveServer): + setup_dimse_orthancs() + + _, group, token = create_user_with_dicom_web_group_and_token() + server = DicomServer.objects.get(ae_title="ORTHANC1") + grant_access(group, server, 
source=True) + adit_client = AditClient(server_url=channels_live_server.url, auth_token=token) + + metadata = load_sample_dicoms_metadata("1001") + image_metadata = cast(pd.DataFrame, metadata[~metadata["Modality"].isin(["SR", "KO", "PR"])]) + study_uid: str = image_metadata["StudyInstanceUID"].iloc[0] + series_uid: str = image_metadata["SeriesInstanceUID"].iloc[0] + image_uid: str = image_metadata["SOPInstanceUID"].iloc[0] + + retrieved = adit_client.retrieve_nifti_image(server.ae_title, study_uid, series_uid, image_uid) + iterated = list(adit_client.iter_nifti_image(server.ae_title, study_uid, series_uid, image_uid)) + + assert len(iterated) == len(retrieved), ( + "iter and retrieve should return the same number of files" + ) + for (r_name, _), (i_name, _) in zip(retrieved, iterated): + assert r_name == i_name, f"Filenames should match: {r_name} != {i_name}" diff --git a/adit/dicom_web/tests/test_renderers.py b/adit/dicom_web/tests/test_renderers.py new file mode 100644 index 000000000..c212ff841 --- /dev/null +++ b/adit/dicom_web/tests/test_renderers.py @@ -0,0 +1,80 @@ +from io import BytesIO + +import pytest + +from adit.dicom_web.renderers import WadoMultipartApplicationNiftiRenderer + + +async def _collect_rendered_output(renderer, images): + chunks = [] + async for chunk in renderer.render(images): + chunks.append(chunk) + return b"".join(chunks) + + +async def _async_iter(items): + for item in items: + yield item + + +class TestWadoMultipartApplicationNiftiRenderer: + @pytest.mark.asyncio + async def test_render_single_file(self): + renderer = WadoMultipartApplicationNiftiRenderer() + content = b"fake nifti data" + files = [("scan.nii.gz", BytesIO(content))] + + output = await _collect_rendered_output(renderer, _async_iter(files)) + + assert b"--nifti-boundary\r\n" in output + assert b"Content-Type: application/octet-stream\r\n" in output + assert b'Content-Disposition: attachment; filename="scan.nii.gz"' in output + assert content in output + assert 
output.endswith(b"\r\n--nifti-boundary--\r\n") + + @pytest.mark.asyncio + async def test_render_multiple_files(self): + renderer = WadoMultipartApplicationNiftiRenderer() + files = [ + ("scan.json", BytesIO(b'{"key": "value"}')), + ("scan.nii.gz", BytesIO(b"nifti data")), + ] + + output = await _collect_rendered_output(renderer, _async_iter(files)) + + # First file starts without leading \r\n + assert output.startswith(b"--nifti-boundary\r\n") + # Second file separated by \r\n--boundary\r\n + assert b"\r\n--nifti-boundary\r\n" in output + # Both filenames present + assert b'filename="scan.json"' in output + assert b'filename="scan.nii.gz"' in output + assert output.endswith(b"\r\n--nifti-boundary--\r\n") + + @pytest.mark.asyncio + async def test_render_empty_iterator(self): + renderer = WadoMultipartApplicationNiftiRenderer() + + output = await _collect_rendered_output(renderer, _async_iter([])) + + # Only the closing boundary + assert output == b"\r\n--nifti-boundary--\r\n" + + @pytest.mark.asyncio + async def test_render_filename_sanitization(self): + renderer = WadoMultipartApplicationNiftiRenderer() + malicious_name = 'bad\r\nname"file.nii.gz' + files = [(malicious_name, BytesIO(b"data"))] + + output = await _collect_rendered_output(renderer, _async_iter(files)) + + # \r, \n, and " should be stripped + assert b'filename="badnamefile.nii.gz"' in output + assert b"\r\nname" not in output.split(b"Content-Disposition")[1].split(b"\r\n\r\n")[0] + + def test_content_type_property(self): + renderer = WadoMultipartApplicationNiftiRenderer() + + assert renderer.content_type == ( + "multipart/related; type=application/octet-stream; boundary=nifti-boundary" + ) diff --git a/adit/dicom_web/tests/utils/test_wado_retrieve_nifti.py b/adit/dicom_web/tests/utils/test_wado_retrieve_nifti.py new file mode 100644 index 000000000..01ae44eb7 --- /dev/null +++ b/adit/dicom_web/tests/utils/test_wado_retrieve_nifti.py @@ -0,0 +1,435 @@ +import asyncio +import logging +from io import 
BytesIO +from typing import cast +from unittest.mock import MagicMock + +import pytest +from pydicom import Dataset + +from adit.core.errors import ( + DcmToNiftiConversionError, + DicomError, + NoSpatialDataError, + NoValidDicomError, + RetriableDicomError, +) +from adit.core.models import DicomServer +from adit.dicom_web.errors import BadGatewayApiError, ServiceUnavailableApiError +from adit.dicom_web.utils import wadors_utils + +WADORS_LOGGER = "adit.dicom_web.utils.wadors_utils" + + +@pytest.fixture(autouse=True) +def _enable_log_propagation(): + """Enable propagation on the adit logger so caplog can capture log messages.""" + adit_logger = logging.getLogger("adit") + original = adit_logger.propagate + adit_logger.propagate = True + yield + adit_logger.propagate = original + + +# --- Fakes reused across tests (following test_wado_retrieve.py pattern) --- + + +class FakeDicomServer: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + for attr in [ + "patient_root_find_support", + "patient_root_get_support", + "patient_root_move_support", + "study_root_find_support", + "study_root_get_support", + "study_root_move_support", + "store_scp_support", + ]: + if not hasattr(self, attr): + setattr(self, attr, False) + + +def _make_server() -> DicomServer: + return cast( + DicomServer, FakeDicomServer(name="Test", ae_title="TEST", host="localhost", port=104) + ) + + +def immediate_sync_to_async(func, *, thread_sensitive=False): + async def wrapper(*args, **kwargs): + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, lambda: func(*args, **kwargs)) + + return wrapper + + +def _make_series_dataset(series_uid: str, modality: str) -> Dataset: + ds = Dataset() + ds.SeriesInstanceUID = series_uid + ds.Modality = modality + return ds + + +# --- wado_retrieve_nifti tests --- + + +class TestWadoRetrieveNifti: + @pytest.mark.asyncio + async def test_study_filters_non_image_modalities(self, monkeypatch): + """SR, 
KO, PR series should be skipped; CT series should be processed.""" + series_list = [ + _make_series_dataset("1.1", "CT"), + _make_series_dataset("1.2", "SR"), + _make_series_dataset("1.3", "KO"), + _make_series_dataset("1.4", "PR"), + ] + fetched_series_uids = [] + + class FakeOperator: + def __init__(self, server): + pass + + def find_series(self, query_ds): + return series_list + + def fake_fetch_dicom_data(source_server, query, level): + fetched_series_uids.append(query["SeriesInstanceUID"]) + return [Dataset()] + + async def fake_process_single_fetch(dicom_images): + yield ("test.nii.gz", BytesIO(b"nifti")) + + monkeypatch.setattr(wadors_utils, "DicomOperator", FakeOperator) + monkeypatch.setattr(wadors_utils, "_fetch_dicom_data", fake_fetch_dicom_data) + monkeypatch.setattr(wadors_utils, "_process_single_fetch", fake_process_single_fetch) + monkeypatch.setattr(wadors_utils, "sync_to_async", immediate_sync_to_async) + + query = {"PatientID": "P1", "StudyInstanceUID": "1.2.3"} + results = [] + async for item in wadors_utils.wado_retrieve_nifti(_make_server(), query, "STUDY"): + results.append(item) + + # Only the CT series should have been fetched + assert fetched_series_uids == ["1.1"] + assert len(results) == 1 + + @pytest.mark.asyncio + async def test_study_all_non_image_modalities(self, monkeypatch): + """Study with only SR/KO/PR series should yield nothing.""" + series_list = [ + _make_series_dataset("1.1", "SR"), + _make_series_dataset("1.2", "KO"), + ] + + class FakeOperator: + def __init__(self, server): + pass + + def find_series(self, query_ds): + return series_list + + monkeypatch.setattr(wadors_utils, "DicomOperator", FakeOperator) + monkeypatch.setattr(wadors_utils, "sync_to_async", immediate_sync_to_async) + + query = {"PatientID": "P1", "StudyInstanceUID": "1.2.3"} + results = [] + async for item in wadors_utils.wado_retrieve_nifti(_make_server(), query, "STUDY"): + results.append(item) + + assert results == [] + + @pytest.mark.asyncio + async def 
test_series_level(self, monkeypatch): + """Series-level should fetch directly without modality filtering.""" + + def fake_fetch_dicom_data(source_server, query, level): + assert level == "SERIES" + return [Dataset()] + + async def fake_process_single_fetch(dicom_images): + yield ("series.nii.gz", BytesIO(b"data")) + + monkeypatch.setattr(wadors_utils, "DicomOperator", lambda s: None) + monkeypatch.setattr(wadors_utils, "_fetch_dicom_data", fake_fetch_dicom_data) + monkeypatch.setattr(wadors_utils, "_process_single_fetch", fake_process_single_fetch) + monkeypatch.setattr(wadors_utils, "sync_to_async", immediate_sync_to_async) + + query = { + "PatientID": "P1", + "StudyInstanceUID": "1.2.3", + "SeriesInstanceUID": "1.2.3.4", + } + results = [] + async for item in wadors_utils.wado_retrieve_nifti(_make_server(), query, "SERIES"): + results.append(item) + + assert len(results) == 1 + assert results[0][0] == "series.nii.gz" + + @pytest.mark.asyncio + async def test_image_level(self, monkeypatch): + """Image-level should fetch directly without modality filtering.""" + + def fake_fetch_dicom_data(source_server, query, level): + assert level == "IMAGE" + return [Dataset()] + + async def fake_process_single_fetch(dicom_images): + yield ("image.nii.gz", BytesIO(b"data")) + + monkeypatch.setattr(wadors_utils, "DicomOperator", lambda s: None) + monkeypatch.setattr(wadors_utils, "_fetch_dicom_data", fake_fetch_dicom_data) + monkeypatch.setattr(wadors_utils, "_process_single_fetch", fake_process_single_fetch) + monkeypatch.setattr(wadors_utils, "sync_to_async", immediate_sync_to_async) + + query = { + "PatientID": "P1", + "StudyInstanceUID": "1.2.3", + "SeriesInstanceUID": "1.2.3.4", + "SOPInstanceUID": "1.2.3.4.5", + } + results = [] + async for item in wadors_utils.wado_retrieve_nifti(_make_server(), query, "IMAGE"): + results.append(item) + + assert len(results) == 1 + assert results[0][0] == "image.nii.gz" + + @pytest.mark.asyncio + async def test_retriable_error(self, 
monkeypatch): + """RetriableDicomError should be wrapped as ServiceUnavailableApiError.""" + + def fake_fetch_dicom_data(source_server, query, level): + raise RetriableDicomError("timeout") + + monkeypatch.setattr(wadors_utils, "DicomOperator", lambda s: None) + monkeypatch.setattr(wadors_utils, "_fetch_dicom_data", fake_fetch_dicom_data) + monkeypatch.setattr(wadors_utils, "sync_to_async", immediate_sync_to_async) + + query = { + "PatientID": "P1", + "StudyInstanceUID": "1.2.3", + "SeriesInstanceUID": "1.2.3.4", + } + + with pytest.raises(ServiceUnavailableApiError): + async for _ in wadors_utils.wado_retrieve_nifti(_make_server(), query, "SERIES"): + pass + + @pytest.mark.asyncio + async def test_non_retriable_error(self, monkeypatch): + """DicomError should be wrapped as BadGatewayApiError.""" + + def fake_fetch_dicom_data(source_server, query, level): + raise DicomError("permanent failure") + + monkeypatch.setattr(wadors_utils, "DicomOperator", lambda s: None) + monkeypatch.setattr(wadors_utils, "_fetch_dicom_data", fake_fetch_dicom_data) + monkeypatch.setattr(wadors_utils, "sync_to_async", immediate_sync_to_async) + + query = { + "PatientID": "P1", + "StudyInstanceUID": "1.2.3", + "SeriesInstanceUID": "1.2.3.4", + } + + with pytest.raises(BadGatewayApiError): + async for _ in wadors_utils.wado_retrieve_nifti(_make_server(), query, "SERIES"): + pass + + +# --- _process_single_fetch tests --- + + +class TestProcessSingleFetch: + @pytest.mark.asyncio + async def test_yields_files_in_order(self, tmp_path, monkeypatch): + """Files should be yielded in order: json, nifti, bval, bvec.""" + nifti_output_dir = tmp_path / "nifti_output" + nifti_output_dir.mkdir() + + # Create fake output files + (nifti_output_dir / "scan.json").write_text('{"key": "val"}') + (nifti_output_dir / "scan.nii.gz").write_bytes(b"nifti data") + (nifti_output_dir / "scan.bval").write_text("0 1000") + (nifti_output_dir / "scan.bvec").write_text("1 0 0") + + monkeypatch.setattr( + wadors_utils, 
"DicomToNiftiConverter", lambda: MagicMock(convert=MagicMock()) + ) + monkeypatch.setattr(wadors_utils, "write_dataset", lambda ds, path: None) + monkeypatch.setattr(wadors_utils, "sync_to_async", immediate_sync_to_async) + + # Patch TemporaryDirectory to use our tmp_path + from contextlib import asynccontextmanager + + @asynccontextmanager + async def fake_temp_dir(): + yield str(tmp_path) + + monkeypatch.setattr(wadors_utils, "TemporaryDirectory", fake_temp_dir) + + ds = Dataset() + ds.PatientID = "P1" + results = [] + async for filename, content in wadors_utils._process_single_fetch([ds]): + results.append((filename, content.read())) + + filenames = [r[0] for r in results] + assert filenames == ["scan.json", "scan.nii.gz", "scan.bval", "scan.bvec"] + + @pytest.mark.asyncio + async def test_handles_nii_without_gz(self, tmp_path, monkeypatch): + """Uncompressed .nii files should also be yielded.""" + nifti_output_dir = tmp_path / "nifti_output" + nifti_output_dir.mkdir() + (nifti_output_dir / "scan.nii").write_bytes(b"nifti data") + + monkeypatch.setattr( + wadors_utils, "DicomToNiftiConverter", lambda: MagicMock(convert=MagicMock()) + ) + monkeypatch.setattr(wadors_utils, "write_dataset", lambda ds, path: None) + monkeypatch.setattr(wadors_utils, "sync_to_async", immediate_sync_to_async) + + from contextlib import asynccontextmanager + + @asynccontextmanager + async def fake_temp_dir(): + yield str(tmp_path) + + monkeypatch.setattr(wadors_utils, "TemporaryDirectory", fake_temp_dir) + + results = [] + async for filename, content in wadors_utils._process_single_fetch([Dataset()]): + results.append(filename) + + assert results == ["scan.nii"] + + @pytest.mark.asyncio + async def test_no_valid_dicom_logs_warning(self, tmp_path, monkeypatch, caplog): + """NoValidDicomError should log a warning and yield nothing.""" + converter_mock = MagicMock() + converter_mock.convert.side_effect = NoValidDicomError("no dicom") + + monkeypatch.setattr(wadors_utils, 
"DicomToNiftiConverter", lambda: converter_mock) + monkeypatch.setattr(wadors_utils, "write_dataset", lambda ds, path: None) + monkeypatch.setattr(wadors_utils, "sync_to_async", immediate_sync_to_async) + + from contextlib import asynccontextmanager + + @asynccontextmanager + async def fake_temp_dir(): + yield str(tmp_path) + + monkeypatch.setattr(wadors_utils, "TemporaryDirectory", fake_temp_dir) + + results = [] + with caplog.at_level(logging.WARNING, logger=WADORS_LOGGER): + async for item in wadors_utils._process_single_fetch([Dataset()]): + results.append(item) + + assert results == [] + assert any("conversion failed unexpectedly" in msg for msg in caplog.messages) + + @pytest.mark.asyncio + async def test_no_spatial_data_logs_warning(self, tmp_path, monkeypatch, caplog): + """NoSpatialDataError should log a warning and yield nothing.""" + converter_mock = MagicMock() + converter_mock.convert.side_effect = NoSpatialDataError("no spatial") + + monkeypatch.setattr(wadors_utils, "DicomToNiftiConverter", lambda: converter_mock) + monkeypatch.setattr(wadors_utils, "write_dataset", lambda ds, path: None) + monkeypatch.setattr(wadors_utils, "sync_to_async", immediate_sync_to_async) + + from contextlib import asynccontextmanager + + @asynccontextmanager + async def fake_temp_dir(): + yield str(tmp_path) + + monkeypatch.setattr(wadors_utils, "TemporaryDirectory", fake_temp_dir) + + results = [] + with caplog.at_level(logging.WARNING, logger=WADORS_LOGGER): + async for item in wadors_utils._process_single_fetch([Dataset()]): + results.append(item) + + assert results == [] + assert any("conversion failed unexpectedly" in msg for msg in caplog.messages) + + @pytest.mark.asyncio + async def test_conversion_error_logs_warning(self, tmp_path, monkeypatch, caplog): + """DcmToNiftiConversionError should log a warning and yield nothing.""" + converter_mock = MagicMock() + converter_mock.convert.side_effect = DcmToNiftiConversionError("convert failed") + + 
monkeypatch.setattr(wadors_utils, "DicomToNiftiConverter", lambda: converter_mock) + monkeypatch.setattr(wadors_utils, "write_dataset", lambda ds, path: None) + monkeypatch.setattr(wadors_utils, "sync_to_async", immediate_sync_to_async) + + from contextlib import asynccontextmanager + + @asynccontextmanager + async def fake_temp_dir(): + yield str(tmp_path) + + monkeypatch.setattr(wadors_utils, "TemporaryDirectory", fake_temp_dir) + + results = [] + with caplog.at_level(logging.WARNING, logger=WADORS_LOGGER): + async for item in wadors_utils._process_single_fetch([Dataset()]): + results.append(item) + + assert results == [] + assert any("Failed to convert DICOM" in msg for msg in caplog.messages) + + @pytest.mark.asyncio + async def test_unexpected_error_propagates(self, tmp_path, monkeypatch): + """Generic exceptions should be re-raised.""" + converter_mock = MagicMock() + converter_mock.convert.side_effect = RuntimeError("unexpected") + + monkeypatch.setattr(wadors_utils, "DicomToNiftiConverter", lambda: converter_mock) + monkeypatch.setattr(wadors_utils, "write_dataset", lambda ds, path: None) + monkeypatch.setattr(wadors_utils, "sync_to_async", immediate_sync_to_async) + + from contextlib import asynccontextmanager + + @asynccontextmanager + async def fake_temp_dir(): + yield str(tmp_path) + + monkeypatch.setattr(wadors_utils, "TemporaryDirectory", fake_temp_dir) + + with pytest.raises(RuntimeError, match="unexpected"): + async for _ in wadors_utils._process_single_fetch([Dataset()]): + pass + + @pytest.mark.asyncio + async def test_empty_dicom_list(self, tmp_path, monkeypatch): + """Empty dicom list should still attempt conversion (dcm2niix handles it).""" + nifti_output_dir = tmp_path / "nifti_output" + nifti_output_dir.mkdir() + + converter_mock = MagicMock() + converter_mock.convert.side_effect = NoValidDicomError("no dicom") + + monkeypatch.setattr(wadors_utils, "DicomToNiftiConverter", lambda: converter_mock) + monkeypatch.setattr(wadors_utils, 
"write_dataset", lambda ds, path: None) + monkeypatch.setattr(wadors_utils, "sync_to_async", immediate_sync_to_async) + + from contextlib import asynccontextmanager + + @asynccontextmanager + async def fake_temp_dir(): + yield str(tmp_path) + + monkeypatch.setattr(wadors_utils, "TemporaryDirectory", fake_temp_dir) + + results = [] + async for item in wadors_utils._process_single_fetch([]): + results.append(item) + + assert results == [] diff --git a/adit/dicom_web/urls.py b/adit/dicom_web/urls.py index d5c2f88ca..30d355845 100644 --- a/adit/dicom_web/urls.py +++ b/adit/dicom_web/urls.py @@ -6,6 +6,9 @@ QueryStudiesAPIView, RetrieveImageAPIView, RetrieveImageMetadataAPIView, + RetrieveNiftiImageAPIView, + RetrieveNiftiSeriesAPIView, + RetrieveNiftiStudyAPIView, RetrieveSeriesAPIView, RetrieveSeriesMetadataAPIView, RetrieveStudyAPIView, @@ -69,6 +72,21 @@ RetrieveImageMetadataAPIView.as_view(), name="wado_rs-image_metadata_with_study_uid_and_series_uid_and_image_uid", ), + path( + "/wadors/studies//nifti", + RetrieveNiftiStudyAPIView.as_view(), + name="wado_rs-nifti_study_with_study_uid", + ), + path( + "/wadors/studies//series//nifti", + RetrieveNiftiSeriesAPIView.as_view(), + name="wado_rs-nifti_series_with_study_uid_and_series_uid", + ), + path( + "/wadors/studies//series//instances//nifti", + RetrieveNiftiImageAPIView.as_view(), + name="wado_rs-nifti_image_with_study_uid_and_series_uid_and_image_uid", + ), path( "/stowrs/studies", StoreImagesAPIView.as_view(), diff --git a/adit/dicom_web/utils/wadors_utils.py b/adit/dicom_web/utils/wadors_utils.py index 732bc6823..6a9cc50b1 100644 --- a/adit/dicom_web/utils/wadors_utils.py +++ b/adit/dicom_web/utils/wadors_utils.py @@ -1,20 +1,38 @@ import asyncio import logging +import os from collections.abc import Callable +from io import BytesIO +from pathlib import Path from typing import AsyncIterator, Literal +import aiofiles +import aiofiles.os from adrf.views import sync_to_async +from aiofiles.tempfile import 
TemporaryDirectory from pydicom import Dataset -from adit.core.errors import DicomError, RetriableDicomError +from adit.core.errors import ( + DcmToNiftiConversionError, + DicomError, + NoSpatialDataError, + NoValidDicomError, + RetriableDicomError, +) from adit.core.models import DicomServer from adit.core.utils.dicom_dataset import QueryDataset from adit.core.utils.dicom_manipulator import DicomManipulator from adit.core.utils.dicom_operator import DicomOperator +from adit.core.utils.dicom_to_nifti_converter import DicomToNiftiConverter +from adit.core.utils.dicom_utils import write_dataset from ..errors import BadGatewayApiError, ServiceUnavailableApiError -logger = logging.getLogger("__name__") +logger = logging.getLogger(__name__) + +# Modalities that are known to not contain image data and cannot be converted to NIfTI. +# SR = Structured Reports, KO = Key Object Selection, PR = Presentation State. +NON_IMAGE_MODALITIES = {"SR", "KO", "PR"} async def wado_retrieve( @@ -86,15 +104,19 @@ def fetch_with_sentinel(fetch_func: Callable[..., None], **kwargs: object) -> No # Start fetch task. Sentinel will be added via call_soon_threadsafe when done. fetch_task = asyncio.create_task(fetch_coro) + sentinel_seen = False try: while True: queue_ds = await queue.get() if queue_ds is None: + sentinel_seen = True break yield queue_ds finally: - # Ensure fetch task is properly awaited even if consumer stops early - if not fetch_task.done(): + # Only cancel if consumer exited early without seeing the sentinel. + # If sentinel was seen, fetch_task will complete on its own — let its + # exception (if any) propagate to the outer except handlers. 
+ if not sentinel_seen and not fetch_task.done(): fetch_task.cancel() try: await fetch_task @@ -105,3 +127,169 @@ def fetch_with_sentinel(fetch_func: Callable[..., None], **kwargs: object) -> No raise ServiceUnavailableApiError(str(exc)) except DicomError as exc: raise BadGatewayApiError(str(exc)) + + +def _fetch_dicom_data( + source_server: DicomServer, + query: dict[str, str], + level: Literal["STUDY", "SERIES", "IMAGE"], +) -> list[Dataset]: + """Fetch DICOM data synchronously and return the list of datasets.""" + operator = DicomOperator(source_server) + query_ds = QueryDataset.from_dict(query) + dicom_images: list[Dataset] = [] + + def callback(ds: Dataset) -> None: + dicom_images.append(ds) + + if level == "STUDY": + operator.fetch_study( + patient_id=query_ds.PatientID, + study_uid=query_ds.StudyInstanceUID, + callback=callback, + ) + elif level == "SERIES": + operator.fetch_series( + patient_id=query_ds.PatientID, + study_uid=query_ds.StudyInstanceUID, + series_uid=query_ds.SeriesInstanceUID, + callback=callback, + ) + elif level == "IMAGE": + assert query_ds.has("SeriesInstanceUID") + operator.fetch_image( + patient_id=query_ds.PatientID, + study_uid=query_ds.StudyInstanceUID, + series_uid=query_ds.SeriesInstanceUID, + image_uid=query_ds.SOPInstanceUID, + callback=callback, + ) + else: + raise ValueError(f"Invalid WADO-RS level: {level}.") + + return dicom_images + + +async def wado_retrieve_nifti( + source_server: DicomServer, + query: dict[str, str], + level: Literal["STUDY", "SERIES", "IMAGE"], +) -> AsyncIterator[tuple[str, BytesIO]]: + """Retrieve DICOM data and convert to NIfTI format. + + Returns the generated files (NIfTI, JSON, bval, bvec) as tuples of + (filename, file_content). + + For study-level requests, fetches each series individually to prevent + loading the entire study into memory at once. Non-image series (SR, KO, PR) + are skipped before fetching. 
+ """ + operator = DicomOperator(source_server) + + try: + if level == "STUDY": + series_list = await sync_to_async(operator.find_series, thread_sensitive=False)( + QueryDataset.create( + StudyInstanceUID=query["StudyInstanceUID"], + ) + ) + + for series in series_list: + modality = series.Modality + if modality in NON_IMAGE_MODALITIES: + logger.debug( + f"Skipping non-image series {series.SeriesInstanceUID} " + f"(modality: {modality})" + ) + continue + + series_query = { + "PatientID": query["PatientID"], + "StudyInstanceUID": query["StudyInstanceUID"], + "SeriesInstanceUID": series.SeriesInstanceUID, + } + + dicom_images = await sync_to_async(_fetch_dicom_data, thread_sensitive=False)( + source_server, series_query, "SERIES" + ) + + async for filename, file_content in _process_single_fetch(dicom_images): + yield filename, file_content + else: + dicom_images = await sync_to_async(_fetch_dicom_data, thread_sensitive=False)( + source_server, query, level + ) + async for filename, file_content in _process_single_fetch(dicom_images): + yield filename, file_content + + except RetriableDicomError as err: + raise ServiceUnavailableApiError(str(err)) + except DicomError as err: + raise BadGatewayApiError(str(err)) + + +async def _process_single_fetch( + dicom_images: list[Dataset], +) -> AsyncIterator[tuple[str, BytesIO]]: + """Convert a list of DICOM datasets to NIfTI format and yield the resulting files. + + For each conversion output group (identified by base filename), yields files in order: + JSON sidecar first, then NIfTI (.nii.gz or .nii), then bval, then bvec. + + If conversion fails with NoValidDicomError or NoSpatialDataError, a warning is logged + because the series was expected to contain image data (non-image modalities are filtered + out before this function is called). 
+ """ + async with TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) + + for file_idx, dicom_image in enumerate(dicom_images): + dicom_file_path = temp_path / f"dicom_file_{file_idx}.dcm" + await sync_to_async(write_dataset, thread_sensitive=False)(dicom_image, dicom_file_path) + + nifti_output_dir = temp_path / "nifti_output" + await aiofiles.os.makedirs(nifti_output_dir, exist_ok=True) + converter = DicomToNiftiConverter() + + try: + await sync_to_async(converter.convert, thread_sensitive=False)( + temp_path, nifti_output_dir + ) + except (NoValidDicomError, NoSpatialDataError) as e: + # The series passed the modality check but still failed conversion. + # This is unexpected and worth logging as a warning. + logger.warning(f"Series conversion failed unexpectedly: {e}") + return + except DcmToNiftiConversionError as e: + logger.warning(f"Failed to convert DICOM files to NIfTI: {e}") + return + except Exception as e: + logger.error(f"Error during DICOM to NIfTI conversion: {e}") + raise + + entries = await aiofiles.os.scandir(nifti_output_dir) + all_files = [entry.name for entry in entries] + + file_pairs: dict[str, dict[str, str]] = {} + for filename in all_files: + base_name, ext = os.path.splitext(filename) + if ext == ".json": + file_pairs.setdefault(base_name, {})["json"] = filename + elif ext == ".gz" and base_name.endswith(".nii"): + actual_base = os.path.splitext(base_name)[0] + file_pairs.setdefault(actual_base, {})["nifti"] = filename + elif ext == ".nii": + file_pairs.setdefault(base_name, {})["nifti"] = filename + elif ext == ".bval": + file_pairs.setdefault(base_name, {})["bval"] = filename + elif ext == ".bvec": + file_pairs.setdefault(base_name, {})["bvec"] = filename + + file_order = ["json", "nifti", "bval", "bvec"] + for _base_name, files in file_pairs.items(): + for file_type in file_order: + if file_type in files: + file_path = os.path.join(nifti_output_dir, files[file_type]) + async with aiofiles.open(file_path, "rb") as f: + 
content = await f.read() + yield files[file_type], BytesIO(content) diff --git a/adit/dicom_web/views.py b/adit/dicom_web/views.py index ca9f43860..589e2cabd 100644 --- a/adit/dicom_web/views.py +++ b/adit/dicom_web/views.py @@ -33,10 +33,11 @@ StowApplicationDicomJsonRenderer, WadoApplicationDicomJsonRenderer, WadoMultipartApplicationDicomRenderer, + WadoMultipartApplicationNiftiRenderer, ) from .utils.qidors_utils import qido_find from .utils.stowrs_utils import stow_store -from .utils.wadors_utils import wado_retrieve +from .utils.wadors_utils import wado_retrieve, wado_retrieve_nifti logger = logging.getLogger(__name__) @@ -323,6 +324,31 @@ async def get( ) +class RetrieveNiftiStudyAPIView(RetrieveAPIView): + renderer_classes = [WadoMultipartApplicationNiftiRenderer] + + async def get( + self, request: AuthenticatedApiRequest, ae_title: str, study_uid: str + ) -> StreamingHttpResponse: + async with self.track_session(request.user) as session: + source_server = await self._get_dicom_server(request, ae_title) + + query = self.query.copy() + query["StudyInstanceUID"] = study_uid + + images = wado_retrieve_nifti(source_server, query, "STUDY") + + renderer = cast( + WadoMultipartApplicationNiftiRenderer, getattr(request, "accepted_renderer") + ) + return StreamingHttpResponse( + streaming_content=_StreamingSessionWrapper( + renderer.render(images), session, self._finalize_statistic + ), + content_type=renderer.content_type, + ) + + class RetrieveStudyMetadataAPIView(RetrieveStudyAPIView): async def get( self, request: AuthenticatedApiRequest, ae_title: str, study_uid: str @@ -388,6 +414,32 @@ async def get( ) +class RetrieveNiftiSeriesAPIView(RetrieveAPIView): + renderer_classes = [WadoMultipartApplicationNiftiRenderer] + + async def get( + self, request: AuthenticatedApiRequest, ae_title: str, study_uid: str, series_uid: str + ) -> StreamingHttpResponse: + async with self.track_session(request.user) as session: + source_server = await 
self._get_dicom_server(request, ae_title) + + query = self.query.copy() + query["StudyInstanceUID"] = study_uid + query["SeriesInstanceUID"] = series_uid + + images = wado_retrieve_nifti(source_server, query, "SERIES") + + renderer = cast( + WadoMultipartApplicationNiftiRenderer, getattr(request, "accepted_renderer") + ) + return StreamingHttpResponse( + streaming_content=_StreamingSessionWrapper( + renderer.render(images), session, self._finalize_statistic + ), + content_type=renderer.content_type, + ) + + class RetrieveSeriesMetadataAPIView(RetrieveSeriesAPIView): async def get( self, request: AuthenticatedApiRequest, ae_title: str, study_uid: str, series_uid: str @@ -460,6 +512,38 @@ async def get( ) +class RetrieveNiftiImageAPIView(RetrieveAPIView): + renderer_classes = [WadoMultipartApplicationNiftiRenderer] + + async def get( + self, + request: AuthenticatedApiRequest, + ae_title: str, + study_uid: str, + series_uid: str, + image_uid: str, + ) -> StreamingHttpResponse: + async with self.track_session(request.user) as session: + source_server = await self._get_dicom_server(request, ae_title) + + query = self.query.copy() + query["StudyInstanceUID"] = study_uid + query["SeriesInstanceUID"] = series_uid + query["SOPInstanceUID"] = image_uid + + images = wado_retrieve_nifti(source_server, query, "IMAGE") + + renderer = cast( + WadoMultipartApplicationNiftiRenderer, getattr(request, "accepted_renderer") + ) + return StreamingHttpResponse( + streaming_content=_StreamingSessionWrapper( + renderer.render(images), session, self._finalize_statistic + ), + content_type=renderer.content_type, + ) + + class RetrieveImageMetadataAPIView(RetrieveImageAPIView): async def get( self,