From f217240a76b6631dd0370c11951cdcf81212ac39 Mon Sep 17 00:00:00 2001 From: Sameer Chaturvedi Date: Mon, 23 Feb 2026 00:22:53 -0500 Subject: [PATCH 1/6] Add DataConnection and ProcessConnection ABCs Introduce abstract base classes for the backend abstraction layer. DataConnection defines the per-shot data access interface (get_data, get_data_with_dims, get_dims, cleanup). ProcessConnection defines the per-process factory interface (get_shot_connection, from_config). Signed-off-by: Sameer Chaturvedi --- disruption_py/inout/__init__.py | 5 ++ disruption_py/inout/base.py | 122 ++++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+) create mode 100644 disruption_py/inout/base.py diff --git a/disruption_py/inout/__init__.py b/disruption_py/inout/__init__.py index e69de29bb..125615ceb 100644 --- a/disruption_py/inout/__init__.py +++ b/disruption_py/inout/__init__.py @@ -0,0 +1,5 @@ +"""Data connection abstractions and implementations.""" + +from disruption_py.inout.base import DataConnection, ProcessConnection + +__all__ = ["DataConnection", "ProcessConnection"] diff --git a/disruption_py/inout/base.py b/disruption_py/inout/base.py new file mode 100644 index 000000000..da4f72d36 --- /dev/null +++ b/disruption_py/inout/base.py @@ -0,0 +1,122 @@ +"""Abstract base classes for data connections. + +DataConnection: per-shot data access (get_data, get_data_with_dims, get_dims). +ProcessConnection: per-process factory that creates DataConnection instances. +""" + +from abc import ABC, abstractmethod +from typing import List, Tuple + +import numpy as np + +from disruption_py.machine.tokamak import Tokamak + + +class DataConnection(ABC): + """Per-shot data access interface. + + Each instance is bound to a single shot. Implementations must provide + get_data, get_data_with_dims, get_dims, and cleanup. The reconnect + method is optional (default no-op). + """ + + @property + @abstractmethod + def shot_id(self) -> int: + """The shot ID this connection is bound to.""" + + @abstractmethod + def get_data(self, path: str, group: str = None, **kwargs) -> np.ndarray: + """Get data at path. + + Parameters + ---------- + path : str + Data path (node path for MDSplus, variable name for Xarray). + group : str, optional + Container name (tree for MDSplus, group for Xarray). + **kwargs + Backend-specific options. + + Returns + ------- + np.ndarray + """ + + @abstractmethod + def get_data_with_dims( + self, + path: str, + group: str = None, + dim_nums: List = None, + **kwargs, + ) -> Tuple: + """Get data and dimension arrays. + + Parameters + ---------- + path : str + Data path. + group : str, optional + Container name. + dim_nums : List, optional + Dimension indices to retrieve. Default [0]. + **kwargs + Backend-specific options. + + Returns + ------- + Tuple + (data, dim0, dim1, ...) as numpy arrays. + """ + + @abstractmethod + def get_dims( + self, + path: str, + group: str = None, + dim_nums: List = None, + **kwargs, + ) -> Tuple: + """Get only dimension arrays. + + Parameters + ---------- + path : str + Data path. + group : str, optional + Container name. + dim_nums : List, optional + Dimension indices to retrieve. Default [0]. + **kwargs + Backend-specific options. + + Returns + ------- + Tuple + Requested dimensions. + """ + + @abstractmethod + def cleanup(self) -> None: + """Release resources for this shot.""" + + def reconnect(self) -> None: + """Reconnect after error. Default no-op. + + Xarray opens a new DataTree per shot so there is nothing to + reconnect. MDSplus overrides this to call conn.reconnect(). + """ + + +class ProcessConnection(ABC): + """Per-process factory that creates DataConnection instances.""" + + @abstractmethod + def get_shot_connection(self, shot_id: int) -> DataConnection: + """Create a per-shot DataConnection for the given shot.""" + + @classmethod + @abstractmethod + def from_config(cls, tokamak: Tokamak) -> "ProcessConnection": + """Create a ProcessConnection from tokamak configuration.""" From 3d0dbdbbe86b22675a91a2d922e7d4681007be1c Mon Sep 17 00:00:00 2001 From: Sameer Chaturvedi Date: Mon, 23 Feb 2026 00:39:50 -0500 Subject: [PATCH 2/6] Adapt MDSConnection to implement DataConnection ABC ProcessMDSConnection now inherits from ProcessConnection. MDSConnection now inherits from DataConnection. shot_id is a property, get_data/get_data_with_dims/get_dims accept group as alias for tree_name, and get_dims returns tuple instead of list. Signed-off-by: Sameer Chaturvedi --- disruption_py/inout/mds.py | 49 ++++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/disruption_py/inout/mds.py b/disruption_py/inout/mds.py index 74630fbc5..f1f38ce08 100644 --- a/disruption_py/inout/mds.py +++ b/disruption_py/inout/mds.py @@ -14,6 +14,7 @@ from disruption_py.config import config from disruption_py.core.utils.misc import shot_msg from disruption_py.core.utils.shared_instance import SharedInstance +from disruption_py.inout.base import DataConnection, ProcessConnection from disruption_py.machine.tokamak import Tokamak try: @@ -52,9 +53,9 @@ class MdsException(Exception): mdsExceptions = MDSplus.mdsExceptions -class ProcessMDSConnection: +class ProcessMDSConnection(ProcessConnection): """ - Abstract class for connecting to MDSplus. + Process-level MDSplus connection. Ensure that a single MDSPlus connection is used by each process for all shots retrieved by that process. @@ -73,7 +74,7 @@ def __init__(self, conn_string: str): self.conn = MDSplus.Connection(conn_string) @classmethod - def from_config(cls, tokamak: Tokamak): + def from_config(cls, tokamak: Tokamak) -> "ProcessMDSConnection": """ Create instance of the MDS connection based on the connection string from the configuration. @@ -82,7 +83,7 @@ def from_config(cls, tokamak: Tokamak): config(tokamak).inout.mds.mdsplus_connection_string ) - def get_shot_connection(self, shot_id: int): + def get_shot_connection(self, shot_id: int) -> "MDSConnection": """Get MDSPlus Connection wrapper for individual shot.""" return MDSConnection(self.conn, shot_id) @@ -129,7 +130,7 @@ def wrapper(*args, **kwargs): return wrapper -class MDSConnection: +class MDSConnection(DataConnection): """ Wrapper class for MDSPlus Connection class used for handling individual shots. """ @@ -138,11 +139,16 @@ def __init__( self, conn: MDSplus.Connection, shot_id: int # pylint: disable=no-member ): self.conn = conn - self.shot_id = shot_id + self._shot_id = shot_id self.tree_nickname_funcs = {} self.tree_nicknames = {} self.open_trees = [] + @property + def shot_id(self) -> int: + """The shot ID this connection is bound to.""" + return self._shot_id + def reconnect(self): """ Reconnect to the MDSplus server. @@ -233,8 +239,10 @@ def get(self, expression: str, arguments: Any = None, tree_name: str = None) -> def get_data( self, path: str, + group: str = None, tree_name: str = None, arguments: Any = None, + **kwargs, ) -> np.ndarray: """ Get data for record at specified path. @@ -243,6 +251,8 @@ def get_data( ---------- path : str MDSplus path to record. + group : str, optional + Alias for tree_name (generic interface). tree_name : str, optional The name of the tree that must be open for retrieval. arguments : Any, optional @@ -254,6 +264,7 @@ def get_data( np.ndarray Returns the node data. """ + tree_name = tree_name or group if tree_name is not None: self.open_tree(tree_name) @@ -267,8 +278,10 @@ def get_data( def get_data_with_dims( self, path: str, - tree_name: str = None, + group: str = None, dim_nums: List = None, + tree_name: str = None, + **kwargs, ) -> Tuple: """ Get data and dimension(s) for record at specified path. @@ -277,17 +290,19 @@ def get_data_with_dims( ---------- path : str MDSplus path to record. - tree_name : str, optional - The name of the tree that must be open for retrieval. + group : str, optional + Alias for tree_name (generic interface). dim_nums : List, optional A list of dimensions that should have their size retrieved. Default [0]. + tree_name : str, optional + The name of the tree that must be open for retrieval. Returns ------- Tuple Returns the node data, followed by the requested dimensions. """ - + tree_name = tree_name or group dim_nums = dim_nums or [0] if tree_name is not None: @@ -305,8 +320,10 @@ def get_data_with_dims( def get_dims( self, path: str, - tree_name: str = None, + group: str = None, dim_nums: List = None, + tree_name: str = None, + **kwargs, ) -> Tuple: """ Get the specified dimensions for record at specified path. @@ -315,17 +332,19 @@ def get_dims( ---------- path : str MDSplus path to record. - tree_name : str, optional - The name of the tree that must be open for retrieval. + group : str, optional + Alias for tree_name (generic interface). dim_nums : List, optional A list of dimensions that should have their size retrieved. Default [0]. + tree_name : str, optional + The name of the tree that must be open for retrieval. Returns ------- Tuple Returns the requested dimensions as a tuple. """ - + tree_name = tree_name or group dim_nums = dim_nums or [0] if tree_name is not None: @@ -334,7 +353,7 @@ def get_dims( logger.trace(shot_msg("Getting dims: {path}"), shot=self.shot_id, path=path) dims = [self.conn.get(f"dim_of({path},{d})").data() for d in dim_nums] - return dims + return tuple(dims) # nicknames From 7ec65461c5a12b3608cd086b08a130458fa07896 Mon Sep 17 00:00:00 2001 From: Sameer Chaturvedi Date: Tue, 24 Feb 2026 00:18:53 -0500 Subject: [PATCH 3/6] Split XarrayConnection into process-level and per-shot classes ProcessXarrayConnection(ProcessConnection) holds config and creates per-shot XarrayDataConnection(DataConnection) instances, replacing the old pattern where get_shot_connection() returned self. Adds get_data_with_dims() and get_dims() to the Xarray backend. cleanup() now closes the DataTree. Backward compat alias preserved. Signed-off-by: Sameer Chaturvedi --- disruption_py/inout/xr.py | 163 ++++++++++++++++++++++++++++++++------ 1 file changed, 139 insertions(+), 24 deletions(-) diff --git a/disruption_py/inout/xr.py b/disruption_py/inout/xr.py index d1f634dff..6d9ad7e9a 100644 --- a/disruption_py/inout/xr.py +++ b/disruption_py/inout/xr.py @@ -5,6 +5,7 @@ """ import threading +from typing import List, Tuple import numpy as np import xarray as xr @@ -12,12 +13,16 @@ from disruption_py.config import config from disruption_py.core.utils.misc import shot_msg +from disruption_py.inout.base import DataConnection, ProcessConnection from disruption_py.machine.tokamak import Tokamak -class XarrayConnection: +class ProcessXarrayConnection(ProcessConnection): """ - Class for connecting to Xarray store. + Process-level Xarray connection. + + Holds configuration for the Xarray store and creates per-shot + XarrayDataConnection instances. """ def __init__( @@ -32,7 +37,6 @@ def __init__( self.endpoint_url = endpoint_url self.file_path = file_path self.file_ext = file_ext - self.data_tree: xr.DataTree | None = None logger.debug( "PID #{pid} | Connecting to Xarray store: {server}", @@ -49,16 +53,20 @@ def folder_path(self): return self.file_path @classmethod - def from_config(cls, tokamak: Tokamak): + def from_config(cls, tokamak: Tokamak) -> "ProcessXarrayConnection": """ - Create instance of the connection based on the file path, file extension, and endpoint URL - from the configuration. + Create instance of the connection based on the file path, file extension, + and endpoint URL from the configuration. """ params = config(tokamak).inout.xarray - return XarrayConnection(**params) + return ProcessXarrayConnection(**params) + + def get_shot_file_path(self, shot_id: int): + """Get file path for individual shot.""" + return f"{self.folder_path}/{shot_id}.{self.file_ext}" - def get_shot_connection(self, shot_id: int): - """Get connection to xarray store for individual shot.""" + def get_shot_connection(self, shot_id: int) -> "XarrayDataConnection": + """Get per-shot data connection for individual shot.""" file_path = self.get_shot_file_path(shot_id) engine = "zarr" if self.file_ext == "zarr" else "netcdf4" @@ -68,20 +76,54 @@ def get_shot_connection(self, shot_id: int): file_path=file_path, ) - self.data_tree = xr.open_datatree( + data_tree = xr.open_datatree( file_path, engine=engine, chunks=None, create_default_indexes=False ) - return self + return XarrayDataConnection(shot_id, data_tree) - def get_shot_file_path(self, shot_id: int): - """Get file path for individual shot.""" - file_path = f"{self.folder_path}/{shot_id}.{self.file_ext}" - return file_path - def get_data(self, shot_id: int, path: str, return_xarray: bool = False): - """Get data from the connection.""" - if self.data_tree is None: - self.get_shot_connection(shot_id) +class XarrayDataConnection(DataConnection): + """ + Per-shot Xarray data connection wrapping a DataTree. + """ + + def __init__(self, shot_id: int, data_tree: xr.DataTree): + self._shot_id = shot_id + self.data_tree = data_tree + + @property + def shot_id(self) -> int: + """The shot ID this connection is bound to.""" + return self._shot_id + + def _resolve_path(self, path: str, group: str = None) -> str: + """Prepend group to path if group is provided and path has no '/'.""" + if group is not None and "/" not in path: + return f"{group}/{path}" + return path + + def get_data( + self, path: str, group: str = None, return_xarray: bool = False, **kwargs + ) -> np.ndarray: + """Get data from the connection. + + Parameters + ---------- + path : str + Variable path, e.g. "summary/ip" or just "ip" if group is provided. + group : str, optional + Group prefix. If provided and path has no "/", prepends group. + return_xarray : bool, optional + If True, return the raw xarray DataArray instead of numpy values. + **kwargs + Backend-specific options (ignored). + + Returns + ------- + np.ndarray or xr.DataArray + """ + path = self._resolve_path(path, group) + logger.trace(shot_msg("Getting data: {path}"), shot=self._shot_id, path=path) try: item = self.data_tree[path] @@ -94,7 +136,7 @@ def get_data(self, shot_id: int, path: str, return_xarray: bool = False): logger.warning( shot_msg("Variable not found: {path}"), path=path, - shot=shot_id, + shot=self._shot_id, ) if return_xarray: @@ -102,8 +144,81 @@ def get_data(self, shot_id: int, path: str, return_xarray: bool = False): return np.array([np.nan]) - def cleanup(self): - """Cleanup the connection.""" + def get_data_with_dims( + self, + path: str, + group: str = None, + dim_nums: List = None, + **kwargs, + ) -> Tuple: + """Get data and dimension arrays from the DataTree. + + Parameters + ---------- + path : str + Variable path. + group : str, optional + Group prefix. + dim_nums : List, optional + Dimension indices to retrieve. Default [0]. + **kwargs + Backend-specific options (ignored). + + Returns + ------- + Tuple + (data, dim0, dim1, ...) as numpy arrays. + """ + path = self._resolve_path(path, group) + dim_nums = dim_nums or [0] + logger.trace( + shot_msg("Getting data and dims: {path}"), shot=self._shot_id, path=path + ) + + item = self.data_tree[path] + data = item.values + dim_names = list(item.dims) + dims = [] + for d in dim_nums: + dim_name = dim_names[d] + dims.append(item.coords[dim_name].values) + + return data, *dims + + def get_dims( + self, + path: str, + group: str = None, + dim_nums: List = None, + **kwargs, + ) -> Tuple: + """Get only dimension arrays. + + Parameters + ---------- + path : str + Variable path. + group : str, optional + Group prefix. + dim_nums : List, optional + Dimension indices to retrieve. Default [0]. + **kwargs + Backend-specific options (ignored). + + Returns + ------- + Tuple + Requested dimensions. + """ + result = self.get_data_with_dims(path, group=group, dim_nums=dim_nums, **kwargs) + return result[1:] + + def cleanup(self) -> None: + """Close the DataTree.""" + if self.data_tree is not None: + self.data_tree.close() + self.data_tree = None + - def reconnect(self): - """Reconnect the connection.""" +# Backward compatibility alias +XarrayConnection = ProcessXarrayConnection From e457c5aa6ff285885263f28f16737e097ee3ff2b Mon Sep 17 00:00:00 2001 From: Sameer Chaturvedi Date: Wed, 25 Feb 2026 19:00:58 -0500 Subject: [PATCH 4/6] Update consumers to use DataConnection and ProcessConnection ABCs Migrates all type annotations, imports, and call sites from concrete MDSConnection/XarrayConnection to the ABCs introduced in commits 1-3. MAST physics methods now use the per-shot DataConnection interface (dropping shot_id from call signatures), and get_mdsplus_class() is renamed to get_process_connection() with a deprecated alias. Signed-off-by: Sameer Chaturvedi --- disruption_py/core/physics_method/params.py | 4 +- disruption_py/core/retrieval_manager.py | 25 ++++----- disruption_py/machine/east/util.py | 8 +-- disruption_py/machine/mast/efit.py | 7 ++- disruption_py/machine/mast/physics.py | 60 ++++++++++----------- disruption_py/machine/mast/util.py | 24 ++++----- disruption_py/settings/nickname_setting.py | 8 +-- disruption_py/settings/time_setting.py | 15 +++--- disruption_py/workflow.py | 19 ++++--- 9 files changed, 84 insertions(+), 86 deletions(-) diff --git a/disruption_py/core/physics_method/params.py b/disruption_py/core/physics_method/params.py index 7a067dcaa..b770319ee 100644 --- a/disruption_py/core/physics_method/params.py +++ b/disruption_py/core/physics_method/params.py @@ -11,7 +11,7 @@ from loguru import logger from disruption_py.core.utils.misc import shot_msg_patch, to_tuple -from disruption_py.inout.mds import MDSConnection +from disruption_py.inout.base import DataConnection from disruption_py.machine.tokamak import Tokamak @@ -25,7 +25,7 @@ class PhysicsMethodParams: shot_id: int tokamak: Tokamak disruption_time: float - mds_conn: MDSConnection + mds_conn: DataConnection times: np.ndarray def __post_init__(self): diff --git a/disruption_py/core/retrieval_manager.py b/disruption_py/core/retrieval_manager.py index cfb45b795..7c92954b0 100644 --- a/disruption_py/core/retrieval_manager.py +++ b/disruption_py/core/retrieval_manager.py @@ -11,7 +11,8 @@ from disruption_py.core.physics_method.params import PhysicsMethodParams from disruption_py.core.physics_method.runner import populate_shot from disruption_py.core.utils.misc import shot_msg -from disruption_py.inout.mds import MDSConnection, ProcessMDSConnection, mdsExceptions +from disruption_py.inout.base import DataConnection, ProcessConnection +from disruption_py.inout.mds import MDSConnection, mdsExceptions from disruption_py.inout.sql import ShotDatabase from disruption_py.machine.tokamak import Tokamak from disruption_py.settings.nickname_setting import NicknameSettingParams @@ -29,15 +30,15 @@ class RetrievalManager: The tokamak instance. process_database : ShotDatabase The SQL database - process_mds_conn : ProcessMDSConnection - The MDS connection + process_mds_conn : ProcessConnection + The process-level data connection """ def __init__( self, tokamak: Tokamak, process_database: ShotDatabase, - process_mds_conn: ProcessMDSConnection, + process_mds_conn: ProcessConnection, ): """ Parameters @@ -46,8 +47,8 @@ def __init__( The tokamak instance. process_database : ShotDatabase The SQL database. - process_mds_conn : ProcessMDSConnection - The MDS connection. + process_mds_conn : ProcessConnection + The process-level data connection. """ self.tokamak = tokamak self.process_database = process_database @@ -187,7 +188,7 @@ def shot_cleanup( def setup_physics_method_params( self, shot_id: int, - mds_conn: MDSConnection, + mds_conn: DataConnection, disruption_time: float, retrieval_settings: RetrievalSettings, ) -> PhysicsMethodParams: @@ -198,8 +199,8 @@ def setup_physics_method_params( ---------- shot_id : int The ID of the shot. - mds_conn : MDSConnection - The MDS connection for the shot. + mds_conn : DataConnection + The data connection for the shot. disruption_time : float The disruption time of the shot. retrieval_settings : RetrievalSettings @@ -231,7 +232,7 @@ def setup_physics_method_params( def _init_times( self, shot_id: int, - mds_conn: MDSConnection, + mds_conn: DataConnection, disruption_time: float, retrieval_settings: RetrievalSettings, ) -> np.ndarray: @@ -242,8 +243,8 @@ def _init_times( ---------- shot_id : int The ID of the shot. - mds_conn : MDSConnection - The MDS connection for the shot. + mds_conn : DataConnection + The data connection for the shot. disruption_time : float The disruption time of the shot. retrieval_settings : RetrievalSettings diff --git a/disruption_py/machine/east/util.py b/disruption_py/machine/east/util.py index 08b593e34..d96d8dac4 100644 --- a/disruption_py/machine/east/util.py +++ b/disruption_py/machine/east/util.py @@ -7,7 +7,7 @@ import numpy as np import scipy -from disruption_py.inout.mds import MDSConnection +from disruption_py.inout.base import DataConnection class EastUtilMethods: @@ -42,7 +42,7 @@ def subtract_ip_baseline_offset(ip, ip_time): return ip @staticmethod - def retrieve_ip(mds_conn: MDSConnection, shot_id: int): + def retrieve_ip(mds_conn: DataConnection, shot_id: int): """ Read in the measured plasma current, Ip. There are several different measurements of Ip: IPE, IPG, IPM (all in the EAST tree), and PCRL01 @@ -54,8 +54,8 @@ def retrieve_ip(mds_conn: MDSConnection, shot_id: int): Parameters ---------- - mds_conn : MDSConnection - Connection to MDSplus server. + mds_conn : DataConnection + Data connection for the shot. shot_id : int Shot number. diff --git a/disruption_py/machine/mast/efit.py b/disruption_py/machine/mast/efit.py index e4d637216..a24912c6b 100644 --- a/disruption_py/machine/mast/efit.py +++ b/disruption_py/machine/mast/efit.py @@ -6,7 +6,6 @@ from disruption_py.core.physics_method.decorator import physics_method from disruption_py.core.physics_method.params import PhysicsMethodParams -from disruption_py.inout.xr import XarrayConnection from disruption_py.machine.mast.util import MastUtilMethods from disruption_py.machine.tokamak import Tokamak @@ -51,13 +50,13 @@ def get_efit_parameters(params: PhysicsMethodParams): dict A dictionary containing the retrieved EFIT parameters. """ - conn: XarrayConnection = params.mds_conn - eq_time = conn.get_data(params.shot_id, "equilibrium/time") + conn = params.mds_conn + eq_time = conn.get_data("equilibrium/time") times = params.times outputs = {} for key, prop in MastEfitMethods.efit_properties.items(): - signal = conn.get_data(params.shot_id, f"equilibrium/{prop}") + signal = conn.get_data(f"equilibrium/{prop}") item = MastUtilMethods.interpolate_1d(eq_time, signal, times) outputs[key] = item diff --git a/disruption_py/machine/mast/physics.py b/disruption_py/machine/mast/physics.py index e3ef8031f..86d98b3ee 100644 --- a/disruption_py/machine/mast/physics.py +++ b/disruption_py/machine/mast/physics.py @@ -12,7 +12,6 @@ from disruption_py.core.physics_method.errors import CalculationError from disruption_py.core.physics_method.params import PhysicsMethodParams from disruption_py.core.utils.math import interp1 -from disruption_py.inout.xr import XarrayConnection from disruption_py.machine.mast.util import MastUtilMethods from disruption_py.machine.tokamak import Tokamak @@ -43,11 +42,11 @@ def get_ip_parameters(params: PhysicsMethodParams): A dictionary containing plasma current (`ip`), its time derivative (`dip_dt`), programmed plasma current (`ip_prog`), and its time derivative (`dipprog_dt`). """ - conn: XarrayConnection = params.mds_conn - ip = conn.get_data(params.shot_id, "summary/ip") - ip_prog = conn.get_data(params.shot_id, "pulse_schedule/i_plasma") - ip_prog_time = conn.get_data(params.shot_id, "pulse_schedule/time") - magtime = conn.get_data(params.shot_id, "summary/time") + conn = params.mds_conn + ip = conn.get_data("summary/ip") + ip_prog = conn.get_data("pulse_schedule/i_plasma") + ip_prog_time = conn.get_data("pulse_schedule/time") + magtime = conn.get_data("summary/time") dip_dt = np.gradient(ip, magtime) dipprog_dt = np.gradient(ip_prog, ip_prog_time) @@ -85,11 +84,11 @@ def get_power(params: PhysicsMethodParams): A dictionary containing neutral beam injection power (`power_nbi`) and radiated power (`power_radiated`). """ - conn: XarrayConnection = params.mds_conn + conn = params.mds_conn - power_nbi = conn.get_data(params.shot_id, "summary/power_nbi") - power_radiated = conn.get_data(params.shot_id, "summary/power_radiated") - base_time = conn.get_data(params.shot_id, "summary/time") + power_nbi = conn.get_data("summary/power_nbi") + power_radiated = conn.get_data("summary/power_radiated") + base_time = conn.get_data("summary/time") times = params.times power_nbi = MastUtilMethods.interpolate_1d(base_time, power_nbi, times) @@ -117,12 +116,12 @@ def get_gas(params: PhysicsMethodParams): A dictionary containing total injected gas (`total_injected`), inboard total gas (`inboard_total`), and outboard total gas (`outboard_total`). """ - conn: XarrayConnection = params.mds_conn + conn = params.mds_conn - total_injected = conn.get_data(params.shot_id, "gas_injection/total_injected") - inboard_total = conn.get_data(params.shot_id, "gas_injection/inboard_total") - outboard_total = conn.get_data(params.shot_id, "gas_injection/outboard_total") - base_time = conn.get_data(params.shot_id, "gas_injection/time") + total_injected = conn.get_data("gas_injection/total_injected") + inboard_total = conn.get_data("gas_injection/inboard_total") + outboard_total = conn.get_data("gas_injection/outboard_total") + base_time = conn.get_data("gas_injection/time") times = params.times total_injected = MastUtilMethods.interpolate_1d( @@ -158,11 +157,11 @@ def get_ts_parameters(params: PhysicsMethodParams): core electron density (`n_e_core`). """ times = params.times - conn: XarrayConnection = params.mds_conn + conn = params.mds_conn - t_e_core = conn.get_data(params.shot_id, "thomson_scattering/t_e_core") - n_e_core = conn.get_data(params.shot_id, "thomson_scattering/n_e_core") - base_time = conn.get_data(params.shot_id, "thomson_scattering/time") + t_e_core = conn.get_data("thomson_scattering/t_e_core") + n_e_core = conn.get_data("thomson_scattering/n_e_core") + base_time = conn.get_data("thomson_scattering/time") t_e_core = MastUtilMethods.interpolate_1d(base_time, t_e_core, times) n_e_core = MastUtilMethods.interpolate_1d(base_time, n_e_core, times) @@ -201,13 +200,13 @@ def get_densities(params: PhysicsMethodParams): """ - conn: XarrayConnection = params.mds_conn - n_e = conn.get_data(params.shot_id, "summary/line_average_n_e") - t_n = conn.get_data(params.shot_id, "summary/time") - ip = conn.get_data(params.shot_id, "summary/ip") - t_ip = conn.get_data(params.shot_id, "summary/time") - a_minor = conn.get_data(params.shot_id, "equilibrium/minor_radius") - t_a = conn.get_data(params.shot_id, "equilibrium/time") + conn = params.mds_conn + n_e = conn.get_data("summary/line_average_n_e") + t_n = conn.get_data("summary/time") + ip = conn.get_data("summary/ip") + t_ip = conn.get_data("summary/time") + a_minor = conn.get_data("equilibrium/minor_radius") + t_a = conn.get_data("equilibrium/time") return MastPhysicsMethods._get_densities( params.times, n_e, t_n, ip, t_ip, a_minor, t_a @@ -276,10 +275,8 @@ def get_sxr(params: PhysicsMethodParams): A dictionary containing SXR data (`sxr_data`) and corresponding time points (`sxr_time`). """ - conn: XarrayConnection = params.mds_conn - hcam = conn.get_data( - params.shot_id, "soft_x_rays/horizontal_cam_upper", return_xarray=True - ) + conn = params.mds_conn + hcam = conn.get_data("soft_x_rays/horizontal_cam_upper", return_xarray=True) if hcam is not None: hcam = hcam.isel(horizontal_cam_upper_channel=7) @@ -314,10 +311,9 @@ def get_dalpha(params: PhysicsMethodParams): dict A dictionary containing D-alpha signal data (`dalpha`). """ - conn: XarrayConnection = params.mds_conn + conn = params.mds_conn dalpha = conn.get_data( - params.shot_id, "spectrometer_visible/filter_spectrometer_dalpha_voltage", return_xarray=True, ) diff --git a/disruption_py/machine/mast/util.py b/disruption_py/machine/mast/util.py index e4263fd2e..e9ec8b280 100644 --- a/disruption_py/machine/mast/util.py +++ b/disruption_py/machine/mast/util.py @@ -7,7 +7,7 @@ import numpy as np from disruption_py.core.utils.math import interp1 -from disruption_py.inout.xr import XarrayConnection +from disruption_py.inout.base import DataConnection class MastUtilMethods: @@ -17,44 +17,40 @@ class MastUtilMethods: """ @staticmethod - def retrieve_ip(conn: XarrayConnection, shot_id: int): + def retrieve_ip(conn: DataConnection): """ Read in the measured plasma current, Ip. Parameters ---------- - conn : XarrayConnection - Connection to S3 bucket. - shot_id : int - Shot number. + conn : DataConnection + Per-shot data connection. Returns ------- tuple[np.ndarray, np.ndarray] Plasma current [A], time base of plasma current [s]. """ - ip = conn.get_data(shot_id, "summary/ip") # ensure data is cached - ip_time = conn.get_data(shot_id, "summary/time") + ip = conn.get_data("summary/ip") + ip_time = conn.get_data("summary/time") return ip, ip_time @staticmethod - def retrieve_efit_time(conn: XarrayConnection, shot_id: int): + def retrieve_efit_time(conn: DataConnection): """ Read in the EFIT time base. Parameters ---------- - conn : XarrayConnection - Connection to S3 bucket. - shot_id : int - Shot number. + conn : DataConnection + Per-shot data connection. Returns ------- np.ndarray EFIT time base [s]. """ - efit_time = conn.get_data(shot_id, "equilibrium/time") + efit_time = conn.get_data("equilibrium/time") return efit_time @staticmethod diff --git a/disruption_py/settings/nickname_setting.py b/disruption_py/settings/nickname_setting.py index ae9aa5ea6..845976ad3 100644 --- a/disruption_py/settings/nickname_setting.py +++ b/disruption_py/settings/nickname_setting.py @@ -12,7 +12,7 @@ from disruption_py.config import config from disruption_py.core.utils.enums import map_string_to_enum -from disruption_py.inout.mds import MDSConnection +from disruption_py.inout.base import DataConnection from disruption_py.inout.sql import ShotDatabase from disruption_py.machine.tokamak import Tokamak @@ -30,8 +30,8 @@ class NicknameSettingParams: ---------- shot_id : int The shot ID for which to resolve nicknames. - mds_conn : MDSConnection - MDSConnection object for accessing MDSPlus data. + mds_conn : DataConnection + Data connection for the shot. database : ShotDatabase Database connection for querying tokamak shot data. disruption_time : float @@ -41,7 +41,7 @@ class NicknameSettingParams: """ shot_id: int - mds_conn: MDSConnection + mds_conn: DataConnection database: ShotDatabase disruption_time: float tokamak: Tokamak diff --git a/disruption_py/settings/time_setting.py b/disruption_py/settings/time_setting.py index 53cc060c8..d4e346e55 100644 --- a/disruption_py/settings/time_setting.py +++ b/disruption_py/settings/time_setting.py @@ -15,7 +15,8 @@ from disruption_py.config import config from disruption_py.core.utils.enums import map_string_to_enum from disruption_py.core.utils.misc import shot_msg_patch -from disruption_py.inout.mds import MDSConnection, mdsExceptions +from disruption_py.inout.base import DataConnection +from disruption_py.inout.mds import mdsExceptions from disruption_py.inout.sql import ShotDatabase from disruption_py.machine.east.util import EastUtilMethods from disruption_py.machine.mast.util import MastUtilMethods @@ -31,8 +32,8 @@ class TimeSettingParams: ---------- shot_id : int Shot ID for the timebase being created. - mds_conn : MDSConnection - Connection to MDSPlus for retrieving MDSPlus data. + mds_conn : DataConnection + Data connection for the shot. database : ShotDatabase Database object with connection to the SQL database. disruption_time : float @@ -42,7 +43,7 @@ class TimeSettingParams: """ shot_id: int - mds_conn: MDSConnection + mds_conn: DataConnection database: ShotDatabase disruption_time: float tokamak: Tokamak @@ -283,7 +284,7 @@ def mast_times(self, params: TimeSettingParams) -> np.ndarray: np.ndarray Array of times in the timebase. """ - efit_time = MastUtilMethods.retrieve_efit_time(params.mds_conn, params.shot_id) + efit_time = MastUtilMethods.retrieve_efit_time(params.mds_conn) return efit_time @@ -392,7 +393,7 @@ def mast_times(self, params: TimeSettingParams) -> np.ndarray: np.ndarray Array of times in the timebase. """ - ip, ip_time = MastUtilMethods.retrieve_ip(params.mds_conn, params.shot_id) + ip, ip_time = MastUtilMethods.retrieve_ip(params.mds_conn) return self._calculate_disruption_times(params, ip, ip_time) @classmethod @@ -639,7 +640,7 @@ def mast_times(self, params: TimeSettingParams) -> np.ndarray: np.ndarray Array of times in the timebase. """ - _, ip_time = MastUtilMethods.retrieve_ip(params.mds_conn, params.shot_id) + _, ip_time = MastUtilMethods.retrieve_ip(params.mds_conn) return ip_time diff --git a/disruption_py/workflow.py b/disruption_py/workflow.py index 5de284bbc..30bf310f5 100644 --- a/disruption_py/workflow.py +++ b/disruption_py/workflow.py @@ -24,9 +24,10 @@ get_temporary_folder, without_duplicates, ) +from disruption_py.inout.base import ProcessConnection from disruption_py.inout.mds import ProcessMDSConnection from disruption_py.inout.sql import ShotDatabase -from disruption_py.inout.xr import XarrayConnection +from disruption_py.inout.xr import ProcessXarrayConnection from disruption_py.machine.tokamak import Tokamak, resolve_tokamak_from_environment from disruption_py.settings import RetrievalSettings from disruption_py.settings.log_settings import LogSettings, resolve_log_settings @@ -73,7 +74,7 @@ def get_shots_data( shotlist_setting: ShotlistSettingType, tokamak: Tokamak = None, database_initializer: Callable[..., ShotDatabase] = None, - mds_connection_initializer: Callable[..., ProcessMDSConnection] = None, + mds_connection_initializer: Callable[..., ProcessConnection] = None, retrieval_settings: RetrievalSettings = None, output_setting: OutputSetting = "dataset", num_processes: int = 1, @@ -216,11 +217,11 @@ def get_database( return ShotDatabase.from_config(tokamak=tokamak) -def get_mdsplus_class( +def get_process_connection( tokamak: Tokamak = None, -) -> ProcessMDSConnection | XarrayConnection: +) -> ProcessConnection: """ - Get the MDSplus connection for the tokamak. + Get the process-level data connection for the tokamak. """ tokamak = resolve_tokamak_from_environment(tokamak) @@ -229,11 +230,15 @@ def get_mdsplus_class( return ProcessMDSConnection.from_config(tokamak=tokamak) if "xarray" in inout_cfg: - return XarrayConnection.from_config(tokamak=tokamak) + return ProcessXarrayConnection.from_config(tokamak=tokamak) raise ValueError("No valid MDSplus or xarray connection found.") +# Deprecated alias +get_mdsplus_class = get_process_connection + + def _get_database_instance(tokamak, database_initializer): """ Create database instance @@ -249,7 +254,7 @@ def _get_mds_instance(tokamak, mds_connection_initializer): """ if mds_connection_initializer: return mds_connection_initializer() - return get_mdsplus_class(tokamak) + return get_process_connection(tokamak) def run(tokamak, methods, shots, efit_tree, time_base, output, processes, log_level): From f70ae599ff6047a5e0bae4f229190b12293055e6 Mon Sep 17 00:00:00 2001 From: gtrevisan Date: Thu, 5 Mar 2026 09:45:38 -0500 Subject: [PATCH 5/6] add shebangs --- disruption_py/inout/__init__.py | 2 ++ disruption_py/inout/base.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/disruption_py/inout/__init__.py b/disruption_py/inout/__init__.py index 125615ceb..83fa18533 100644 --- a/disruption_py/inout/__init__.py +++ b/disruption_py/inout/__init__.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + """Data connection abstractions and implementations.""" from disruption_py.inout.base import DataConnection, ProcessConnection diff --git a/disruption_py/inout/base.py b/disruption_py/inout/base.py index da4f72d36..650aa8607 100644 --- a/disruption_py/inout/base.py +++ b/disruption_py/inout/base.py @@ -1,4 +1,7 @@ -"""Abstract base classes for data connections. +#!/usr/bin/env python3 + +""" +Abstract base classes for data connections. DataConnection: per-shot data access (get_data, get_data_with_dims, get_dims). ProcessConnection: per-process factory that creates DataConnection instances. From 2f9a71f52447bf06fc22ff06887f31a78cae8a6a Mon Sep 17 00:00:00 2001 From: Sameer Chaturvedi Date: Thu, 12 Mar 2026 16:25:34 -0400 Subject: [PATCH 6/6] Address PR #523 review feedback - Drop deprecated aliases (XarrayConnection, get_mdsplus_class) - Restore machine-specific type hints (MDSConnection for EAST, XarrayDataConnection for MAST) - Fix ProcessXarrayConnection.__init__ to initialize attrs before early return - Update XarrayDataConnection.get_data return type annotation to match actual behavior - Update examples/mdsplus.py to use get_process_connection Signed-off-by: Sameer Chaturvedi --- disruption_py/inout/xr.py | 16 +++++++--------- disruption_py/machine/east/util.py | 8 ++++---- disruption_py/machine/mast/util.py | 14 +++++++------- disruption_py/workflow.py | 4 ---- examples/mdsplus.py | 4 ++-- 5 files changed, 20 insertions(+), 26 deletions(-) diff --git a/disruption_py/inout/xr.py b/disruption_py/inout/xr.py index 6d9ad7e9a..d0dfc81da 100644 --- a/disruption_py/inout/xr.py +++ b/disruption_py/inout/xr.py @@ -31,13 +31,13 @@ def __init__( file_ext: str = "zarr", endpoint_url: str | None = None, ): - if file_path is None: - return - self.endpoint_url = endpoint_url self.file_path = file_path self.file_ext = file_ext + if file_path is None: + return + logger.debug( "PID #{pid} | Connecting to Xarray store: {server}", server=endpoint_url, @@ -104,7 +104,7 @@ def _resolve_path(self, path: str, group: str = None) -> str: def get_data( self, path: str, group: str = None, return_xarray: bool = False, **kwargs - ) -> np.ndarray: + ) -> "np.ndarray | xr.DataArray | None": """Get data from the connection. Parameters @@ -120,7 +120,9 @@ def get_data( Returns ------- - np.ndarray or xr.DataArray + np.ndarray or xr.DataArray or None + numpy array by default, xr.DataArray if return_xarray=True, + or None if variable not found and return_xarray=True. """ path = self._resolve_path(path, group) logger.trace(shot_msg("Getting data: {path}"), shot=self._shot_id, path=path) @@ -218,7 +220,3 @@ def cleanup(self) -> None: if self.data_tree is not None: self.data_tree.close() self.data_tree = None - - -# Backward compatibility alias -XarrayConnection = ProcessXarrayConnection diff --git a/disruption_py/machine/east/util.py b/disruption_py/machine/east/util.py index d96d8dac4..743295ffd 100644 --- a/disruption_py/machine/east/util.py +++ b/disruption_py/machine/east/util.py @@ -7,7 +7,7 @@ import numpy as np import scipy -from disruption_py.inout.base import DataConnection +from disruption_py.inout.mds import MDSConnection class EastUtilMethods: @@ -42,7 +42,7 @@ def subtract_ip_baseline_offset(ip, ip_time): return ip @staticmethod - def retrieve_ip(mds_conn: DataConnection, shot_id: int): + def retrieve_ip(mds_conn: MDSConnection, shot_id: int): """ Read in the measured plasma current, Ip. There are several different measurements of Ip: IPE, IPG, IPM (all in the EAST tree), and PCRL01 @@ -54,8 +54,8 @@ def retrieve_ip(mds_conn: DataConnection, shot_id: int): Parameters ---------- - mds_conn : DataConnection - Data connection for the shot. + mds_conn : MDSConnection + MDSplus connection for the shot. shot_id : int Shot number. diff --git a/disruption_py/machine/mast/util.py b/disruption_py/machine/mast/util.py index e9ec8b280..faba5e944 100644 --- a/disruption_py/machine/mast/util.py +++ b/disruption_py/machine/mast/util.py @@ -7,7 +7,7 @@ import numpy as np from disruption_py.core.utils.math import interp1 -from disruption_py.inout.base import DataConnection +from disruption_py.inout.xr import XarrayDataConnection class MastUtilMethods: @@ -17,14 +17,14 @@ class MastUtilMethods: """ @staticmethod - def retrieve_ip(conn: DataConnection): + def retrieve_ip(conn: XarrayDataConnection): """ Read in the measured plasma current, Ip. Parameters ---------- - conn : DataConnection - Per-shot data connection. + conn : XarrayDataConnection + Per-shot Xarray data connection. Returns ------- @@ -36,14 +36,14 @@ def retrieve_ip(conn: DataConnection): return ip, ip_time @staticmethod - def retrieve_efit_time(conn: DataConnection): + def retrieve_efit_time(conn: XarrayDataConnection): """ Read in the EFIT time base. Parameters ---------- - conn : DataConnection - Per-shot data connection. + conn : XarrayDataConnection + Per-shot Xarray data connection. Returns ------- diff --git a/disruption_py/workflow.py b/disruption_py/workflow.py index 30bf310f5..358d3e772 100644 --- a/disruption_py/workflow.py +++ b/disruption_py/workflow.py @@ -235,10 +235,6 @@ def get_process_connection( raise ValueError("No valid MDSplus or xarray connection found.") -# Deprecated alias -get_mdsplus_class = get_process_connection - - def _get_database_instance(tokamak, database_initializer): """ Create database instance diff --git a/examples/mdsplus.py b/examples/mdsplus.py index 601229fba..9e165ba04 100644 --- a/examples/mdsplus.py +++ b/examples/mdsplus.py @@ -7,7 +7,7 @@ import pytest from disruption_py.machine.tokamak import Tokamak, resolve_tokamak_from_environment -from disruption_py.workflow import get_mdsplus_class +from disruption_py.workflow import get_process_connection def main(): @@ -41,7 +41,7 @@ def main(): else: raise ValueError(f"Unspecified or unsupported tokamak: {tokamak}.") - mds = get_mdsplus_class(tokamak).conn + mds = get_process_connection(tokamak).conn print(f"Initialized MDSplus: {mds.hostspec}") mds.openTree(tree, shot)