diff --git a/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/bmi_model.py b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/bmi_model.py index 6858056a..3a6f0174 100755 --- a/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/bmi_model.py +++ b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/bmi_model.py @@ -24,15 +24,17 @@ from mpi4py import MPI from NextGen_Forcings_Engine_BMI import esmf_creation, forcing_extraction +from NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.core.config import ( + ConfigOptions, +) +from NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.core.geoMod import GEOGRID +from NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.core.parallel import MpiConfig from .bmi_grid import Grid, GridType from .core import ( - config, err_handler, forcingInputMod, - geoMod, ioMod, - parallel, suppPrecipMod, ) from .model import NWMv3ForcingEngineModel @@ -207,7 +209,7 @@ def initialize(self, config_file: str, output_path: str | None = None) -> None: # If _job_meta was not set by initialize_with_params(), create a default one if self._job_meta is None: - self._job_meta = config.ConfigOptions(self.cfg_bmi) + self._job_meta = ConfigOptions(self.cfg_bmi) # Parse the configuration options try: @@ -231,7 +233,7 @@ def initialize(self, config_file: str, output_path: str | None = None) -> None: self._job_meta.nwmConfig = self.cfg_bmi["NWM_CONFIG"] # Initialize MPI communication - self._mpi_meta = parallel.MpiConfig() + self._mpi_meta = MpiConfig() try: comm = MPI.Comm.f2py(self._comm) if self._comm is not None else None self._mpi_meta.initialize_comm(self._job_meta, comm=comm) @@ -252,24 +254,14 @@ def initialize(self, config_file: str, output_path: str | None = None) -> None: # information about the modeling domain, local processor # grid boundaries, and ESMF grid objects/fields to be used # in regridding. - self._wrf_hydro_geo_meta = geoMod.GeoMetaWrfHydro() - - if self._job_meta.grid_type == "gridded": - self._wrf_hydro_geo_meta.initialize_destination_geo_gridded( - self._job_meta, self._mpi_meta - ) - elif self._job_meta.grid_type == "unstructured": - self._wrf_hydro_geo_meta.initialize_destination_geo_unstructured( - self._job_meta, self._mpi_meta - ) - elif self._job_meta.grid_type == "hydrofabric": - self._wrf_hydro_geo_meta.initialize_destination_geo_hydrofabric( - self._job_meta, self._mpi_meta - ) - else: - self._job_meta.errMsg = "You must specify a proper grid_type (gridded, unstructured, hydrofabric) in the config." + if self._job_meta.grid_type not in GEOGRID: + self._job_meta.errMsg = f"Invalid grid type specified: {self._job_meta.grid_type}. Valid options are: {list(GEOGRID.keys())}" err_handler.err_out_screen_para(self._job_meta.errMsg, self._mpi_meta) + self._wrf_hydro_geo_meta = GEOGRID.get(self._job_meta.grid_type)( + self._job_meta, self._mpi_meta + ) + # Assign grid type to BMI class for grid information self._grid_type = self._job_meta.grid_type.lower() @@ -759,15 +751,6 @@ def initialize(self, config_file: str, output_path: str | None = None) -> None: for long_name in self._var_name_units_map.keys() } - if self._job_meta.spatial_meta is not None: - try: - self._wrf_hydro_geo_meta.initialize_geospatial_metadata( - self._job_meta, self._mpi_meta - ) - except Exception as e: - err_handler.err_out_screen_para(self._job_meta.errMsg, self._mpi_meta) - err_handler.check_program_status(self._job_meta, self._mpi_meta) - # Check to make sure we have enough dimensionality to run regridding. ESMF requires both grids # to have a size of at least 2. if ( @@ -897,9 +880,7 @@ def initialize_with_params( :raises ValueError: If an invalid grid type is specified, an exception is raised. """ # Set the job metadata parameters (b_date, geogrid) using config_options - self._job_meta = config.ConfigOptions( - self.cfg_bmi, b_date=b_date, geogrid_arg=geogrid - ) + self._job_meta = ConfigOptions(self.cfg_bmi, b_date=b_date, geogrid_arg=geogrid) # Now that _job_meta is set, call initialize() to set up the core model self.initialize(config_file, output_path=output_path) diff --git a/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/core/consts.py b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/core/consts.py new file mode 100644 index 00000000..cc54e8d9 --- /dev/null +++ b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/core/consts.py @@ -0,0 +1,88 @@ +CONSTS = { + "geoMod": { + "GeoMeta": [ + # "nx_global", + # "ny_global", + # "nx_global_elem", + # "ny_global_elem", + # "dx_meters", + # "dy_meters", + # "latitude_grid", + # "longitude_grid", + # "element_ids", + # "element_ids_global", + # "latitude_grid_elem", + # "longitude_grid_elem", + # "lat_bounds", + # "lon_bounds", + # "mesh_inds", + # "mesh_inds_elem", + # "height", + # "height_elem", + # "sina_grid", + # "cosa_grid", + "nodeCoords", + "centerCoords", + "inds", + # "slope", + # "slp_azi", + # "slope_elem", + # "slp_azi_elem", + # "esmf_grid", + "esmf_lat", + "esmf_lon", + ], + "handle_exception": { + "esmf_nc": "Unable to open spatial metadata file: :::arg:::", + }, + "UnstructuredGeoMeta": [ + "x_lower_bound", + "x_upper_bound", + "y_lower_bound", + "y_upper_bound", + "dx_meters", + "dy_meters", + "element_ids", + "element_ids_global", + "sina_grid", + "cosa_grid", + "esmf_lat", + "esmf_lon", + ], + "HydrofabricGeoMeta": [ + "nx_local_elem", + "ny_local_elem", + "x_lower_bound", + "x_upper_bound", + "y_lower_bound", + "y_upper_bound", + "nx_global_elem", + "ny_global_elem", + "dx_meters", + "dy_meters", + "mesh_inds_elem", + "height_elem", + "sina_grid", + "cosa_grid", + "slope_elem", + "slp_azi_elem", + "esmf_lat", + "esmf_lon", + ], + "GriddedGeoMeta": [ + "nx_local_elem", + "ny_local_elem", + "nx_global_elem", + "ny_global_elem", + "element_ids", + "element_ids_global", + "lat_bounds", + "lon_bounds", + "mesh_inds", + "mesh_inds_elem", + "height_elem", + "slope_elem", + "slp_azi_elem", + ], + }, +} diff --git a/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/core/forcingInputMod.py b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/core/forcingInputMod.py index 9931100b..5e94f65d 100755 --- a/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/core/forcingInputMod.py +++ b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/core/forcingInputMod.py @@ -12,7 +12,7 @@ ConfigOptions, ) from NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.core.geoMod import ( - GeoMetaWrfHydro, + GeoMeta, ) from NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.core.parallel import MpiConfig from nextgen_forcings_ewts import MODULE_NAME @@ -963,7 +963,7 @@ def regrid_map(self): def regrid_inputs( self, config_options: ConfigOptions, - wrf_hyro_geo_meta: GeoMetaWrfHydro, + wrf_hyro_geo_meta: GeoMeta, mpi_config: MpiConfig, ): """Regrid input forcings to the final output grids for this timestep. @@ -1012,7 +1012,7 @@ def temporal_interpolate_inputs( def init_dict( config_options: ConfigOptions, - geo_meta_wrf_hydro: GeoMetaWrfHydro, + geo_meta_wrf_hydro: GeoMeta, mpi_config: MpiConfig, ) -> dict: """Initialize the input forcing dictionary. diff --git a/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/core/geoMod.py b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/core/geoMod.py index 8902ebbe..0bfe63e3 100755 --- a/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/core/geoMod.py +++ b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/core/geoMod.py @@ -1,7 +1,6 @@ import math -from time import time +from pathlib import Path -import netCDF4 import numpy as np # For ESMF + shapely 2.x, shapely must be imported first, to avoid segfault "address not mapped to object" stemming from calls such as: @@ -9,1360 +8,1459 @@ import shapely from scipy import spatial -from .. import esmf_utils, nc_utils -from . import err_handler - try: import esmpy as ESMF except ImportError: import ESMF import logging +from functools import lru_cache, wraps +from typing import Any + +import xarray as xr +from NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.core.config import ( + ConfigOptions, +) +from NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.core.consts import CONSTS +from NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.core.parallel import MpiConfig from nextgen_forcings_ewts import MODULE_NAME LOG = logging.getLogger(MODULE_NAME) +CONSTS = CONSTS[Path(__file__).stem] -class GeoMetaWrfHydro: - """Abstract class for handling information about the WRF-Hydro domain we are processing forcings too.""" - - def __init__(self): - """Initialize GeoMetaWrfHydro class variables.""" - self.nx_global = None - self.ny_global = None - self.nx_global_elem = None - self.ny_global_elem = None - self.dx_meters = None - self.dy_meters = None - self.nx_local = None - self.ny_local = None - self.nx_local_elem = None - self.ny_local_elem = None - self.x_lower_bound = None - self.x_upper_bound = None - self.y_lower_bound = None - self.y_upper_bound = None - self.latitude_grid = None - self.longitude_grid = None - self.element_ids = None - self.element_ids_global = None - self.latitude_grid_elem = None - self.longitude_grid_elem = None - self.lat_bounds = None - self.lon_bounds = None - self.mesh_inds = None - self.mesh_inds_elem = None - self.height = None - self.height_elem = None - self.sina_grid = None - self.cosa_grid = None - self.nodeCoords = None - self.centerCoords = None - self.inds = None - self.slope = None - self.slp_azi = None - self.slope_elem = None - self.slp_azi_elem = None - self.esmf_grid = None - self.esmf_lat = None - self.esmf_lon = None - self.crs_atts = None - self.x_coord_atts = None - self.x_coords = None - self.y_coord_atts = None - self.y_coords = None - self.spatial_global_atts = None - - def get_processor_bounds(self, config_options): - """Calculate the local grid boundaries for this processor. - - ESMF operates under the hood and the boundary values - are calculated within the ESMF software. - :return: - """ - if config_options.grid_type == "gridded": - self.x_lower_bound = self.esmf_grid.lower_bounds[ESMF.StaggerLoc.CENTER][1] - self.x_upper_bound = self.esmf_grid.upper_bounds[ESMF.StaggerLoc.CENTER][1] - self.y_lower_bound = self.esmf_grid.lower_bounds[ESMF.StaggerLoc.CENTER][0] - self.y_upper_bound = self.esmf_grid.upper_bounds[ESMF.StaggerLoc.CENTER][0] - self.nx_local = self.x_upper_bound - self.x_lower_bound - self.ny_local = self.y_upper_bound - self.y_lower_bound - elif config_options.grid_type == "unstructured": - self.nx_local = len(self.esmf_grid.coords[0][1]) - self.ny_local = len(self.esmf_grid.coords[0][1]) - self.nx_local_elem = len(self.esmf_grid.coords[1][1]) - self.ny_local_elem = len(self.esmf_grid.coords[1][1]) - # LOG.debug("ESMF Mesh nx local node is " + str(self.nx_local)) - # LOG.debug("ESMF Mesh nx local elem is " + str(self.nx_local_elem)) - elif config_options.grid_type == "hydrofabric": - self.nx_local = len(self.esmf_grid.coords[1][1]) - self.ny_local = len(self.esmf_grid.coords[1][1]) - # self.nx_local_poly = len(self.esmf_poly_coords) - # self.ny_local_poly = len(self.esmf_poly_coords) - # LOG.debug("ESMF Mesh nx local elem is " + str(self.nx_local)) - # LOG.debug("ESMF Mesh nx local poly is " + str(self.nx_local_poly)) - # LOG.debug("WRF-HYDRO LOCAL X BOUND 1 = " + str(self.x_lower_bound)) - # LOG.debug("WRF-HYDRO LOCAL X BOUND 2 = " + str(self.x_upper_bound)) - # LOG.debug("WRF-HYDRO LOCAL Y BOUND 1 = " + str(self.y_lower_bound)) - # LOG.debug("WRF-HYDRO LOCAL Y BOUND 2 = " + str(self.y_upper_bound)) - - def initialize_destination_geo_gridded(self, config_options, mpi_config): - """Initialize GeoMetaWrfHydro class variables. +def set_none(func) -> Any: + """Set the output of a function to None if an exception is raised.""" - Initialization function to initialize ESMF through ESMPy, - calculate the global parameters of the WRF-Hydro grid - being processed to, along with the local parameters - for this particular processor. - :return: - """ - # Open the geogrid file and extract necessary information - # to create ESMF fields. - if mpi_config.rank == 0: - try: - idTmp = netCDF4.Dataset(config_options.geogrid, "r") - except Exception as e: - config_options.errMsg = ( - "Unable to open the WRF-Hydro geogrid file: " - + config_options.geogrid - ) - raise Exception - if idTmp.variables[config_options.lat_var].ndim == 3: - try: - self.nx_global = idTmp.variables[config_options.lat_var].shape[2] - except Exception as e: - config_options.errMsg = ( - "Unable to extract X dimension size from latitude variable in: " - + config_options.geogrid - ) - raise Exception + @wraps(func) + def wrapper(self) -> Any: + """Set the output of a function to None if an exception is raised.""" + if self.spatial_metadata_exists: + return func(self) + else: + return None - try: - self.ny_global = idTmp.variables[config_options.lat_var].shape[1] - except Exception as e: - config_options.errMsg = ( - "Unable to extract Y dimension size from latitude in: " - + config_options.geogrid - ) - raise Exception + return wrapper - try: - self.dx_meters = idTmp.DX - except Exception as e: - config_options.errMsg = ( - "Unable to extract DX global attribute in: " - + config_options.geogrid - ) - raise Exception - try: - self.dy_meters = idTmp.DY - except Exception as e: - config_options.errMsg = ( - "Unable to extract DY global attribute in: " - + config_options.geogrid - ) - raise Exception - elif idTmp.variables[config_options.lat_var].ndim == 2: - try: - self.nx_global = idTmp.variables[config_options.lat_var].shape[1] - except Exception as e: - config_options.errMsg = ( - "Unable to extract X dimension size from latitude variable in: " - + config_options.geogrid - ) - raise Exception +def broadcast(prop) -> Any: + """Broadcast the output of a function to all processors.""" - try: - self.ny_global = idTmp.variables[config_options.lat_var].shape[0] - except Exception as e: - config_options.errMsg = ( - "Unable to extract Y dimension size from latitude in: " - + config_options.geogrid - ) - raise Exception + @wraps(prop) + def wrapper(self) -> Any: + """Broadcast the output of a function to all processors.""" + result = prop.fget(self) + return self.mpi_config.comm.bcast(result, root=0) - try: - self.dx_meters = idTmp.variables[config_options.lon_var].dx - except Exception as e: - config_options.errMsg = ( - "Unable to extract DX global attribute in: " - + config_options.geogrid - ) - raise Exception + return property(wrapper) - try: - self.dy_meters = idTmp.variables[config_options.lat_var].dy - except Exception as e: - config_options.errMsg = ( - "Unable to extract DY global attribute in: " - + config_options.geogrid - ) - raise Exception - else: - try: - self.nx_global = idTmp.variables[config_options.lon_var].shape[0] - except Exception as e: - config_options.errMsg = ( - "Unable to extract X dimension size from longitude variable in: " - + config_options.geogrid - ) - raise Exception +def barrier(prop) -> Any: + """Synchronize all processors at a barrier.""" - try: - self.ny_global = idTmp.variables[config_options.lat_var].shape[0] - except Exception as e: - config_options.errMsg = ( - "Unable to extract Y dimension size from latitude in: " - + config_options.geogrid - ) - raise Exception - if config_options.input_forcings[0] != 23: - try: - self.dx_meters = idTmp.variables[config_options.lon_var].dx - except Exception as e: - config_options.errMsg = ( - "Unable to extract dx metadata attribute in: " - + config_options.geogrid - ) - raise Exception - - try: - self.dy_meters = idTmp.variables[config_options.lat_var].dy - except Exception as e: - config_options.errMsg = ( - "Unable to extract dy metadata attribute in: " - + config_options.geogrid - ) - raise Exception - else: - # Manually input the grid spacing since ERA5-Interim does not - # internally have this geospatial information within the netcdf file - self.dx_meters = 31000 - self.dy_meters = 31000 + @wraps(prop) + def wrapper(self) -> Any: + """Synchronize all processors at a barrier.""" + result = prop.fget(self) + self.mpi_config.comm.barrier() + return result - # mpi_config.comm.barrier() + return property(wrapper) - # Broadcast global dimensions to the other processors. - self.nx_global = mpi_config.broadcast_parameter( - self.nx_global, config_options, param_type=int - ) - self.ny_global = mpi_config.broadcast_parameter( - self.ny_global, config_options, param_type=int - ) - self.dx_meters = mpi_config.broadcast_parameter( - self.dx_meters, config_options, param_type=float - ) - self.dy_meters = mpi_config.broadcast_parameter( - self.dy_meters, config_options, param_type=float - ) - # mpi_config.comm.barrier() +def scatter(prop) -> Any: + """Scatter the output of a function to all processors.""" + @wraps(prop) + def wrapper(self) -> Any: + """Scatter the output of a function to all processors.""" try: - self.esmf_grid = ESMF.Grid( - np.array([self.ny_global, self.nx_global]), - staggerloc=ESMF.StaggerLoc.CENTER, - coord_sys=ESMF.CoordSys.SPH_DEG, - ) + var, name, config_options, post_slice = prop.fget(self) + var = self.mpi_config.scatter_array(self, var, config_options) + if post_slice: + return var[:, :] + else: + return var except Exception as e: - config_options.errMsg = ( - "Unable to create ESMF grid for WRF-Hydro geogrid: " - + config_options.geogrid + self.config_options.errMsg = ( + f"Unable to subset {name} from geogrid file into ESMF object" ) - raise Exception + raise e - # mpi_config.comm.barrier() + return property(wrapper) - self.esmf_lat = self.esmf_grid.get_coords(1) - self.esmf_lon = self.esmf_grid.get_coords(0) - # mpi_config.comm.barrier() +class GeoMeta: + """GeoMeta class for handling information about the geometry metadata. - # Obtain the local boundaries for this processor. - self.get_processor_bounds(config_options) + Extract names of variable attributes from each of the input geospatial variables. These + can change, so we are making this as flexible as possible to accomodate future changes. + """ - # Scatter global XLAT_M grid to processors.. - if mpi_config.rank == 0: - if idTmp.variables[config_options.lat_var].ndim == 3: - varTmp = idTmp.variables[config_options.lat_var][0, :, :] - elif idTmp.variables[config_options.lat_var].ndim == 2: - varTmp = idTmp.variables[config_options.lat_var][:, :] - elif idTmp.variables[config_options.lat_var].ndim == 1: - lat = idTmp.variables[config_options.lat_var][:] - lon = idTmp.variables[config_options.lon_var][:] - varTmp = np.meshgrid(lon, lat)[1] - lat = None - lon = None - # Flag to grab entire array for AWS slicing - if config_options.aws: - self.lat_bounds = varTmp - else: - varTmp = None + def __init__(self, config_options: ConfigOptions, mpi_config: MpiConfig) -> None: + """Initialize GeoMeta class variables.""" + self.config_options = config_options + self.mpi_config = mpi_config + for attr in CONSTS[self.__class__.__base__.__name__]: + setattr(self, attr, None) - # mpi_config.comm.barrier() - - varSubTmp = mpi_config.scatter_array(self, varTmp, config_options) - - # mpi_config.comm.barrier() + @property + @lru_cache + def spatial_metadata_exists(self) -> bool: + """Check to make sure the geospatial metadata file exists.""" + if self.config_options.spatial_meta is None: + return False + else: + return True - # Place the local lat/lon grid slices from the parent geogrid file into - # the ESMF lat/lon grids. + @property + @lru_cache + def geogrid_ds(self) -> xr.Dataset: + """Get the geogrid file path.""" try: - self.esmf_lat[:, :] = varSubTmp - self.latitude_grid = varSubTmp - varSubTmp = None - varTmp = None + with xr.open_dataset(self.config_options.geogrid) as ds: + return ds except Exception as e: - config_options.errMsg = ( - "Unable to subset latitude from geogrid file into ESMF object" + self.config_options.errMsg = "Unable to open geogrid file with xarray" + raise e + + @property + @lru_cache + @set_none + def esmf_ds(self) -> xr.Dataset: + """Open the geospatial metadata file and return the xarray dataset object.""" + try: + with xr.open_dataset(self.config_options.spatial_meta) as ds: + esmf_ds = ds.load() + except Exception as e: + self.config_options.errMsg = ( + f"Unable to open esmf file: {self.config_options.spatial_meta}" ) - raise Exception + raise e + self._check_variables_exist(esmf_ds) + return esmf_ds + + def _check_variables_exist(self, esmf_ds: xr.Dataset): + """Check to make sure the expected variables are present in the geospatial metadata file.""" + if self.mpi_config.rank == 0: + for var in ["crs", "x", "y"]: + if var not in esmf_ds.variables.keys(): + self.config_options.errMsg = f"Unable to locate {var} variable in: {self.config_options.spatial_meta}" + raise Exception - # mpi_config.comm.barrier() + def ncattrs(self, var: str) -> list: + """Extract variable attribute names from the geospatial metadata file.""" + return self.get_esmf_var(var).ncattrs() + + def get_var(self, ds: xr.Dataset, var: str) -> xr.DataArray: + """Get a variable from a xr.Dataset.""" + if self.mpi_config.rank == 0: + try: + return ds.variables[var] + except Exception as e: + self.config_options.errMsg = f"Unable to extract {var} variable from: {self.config_options.spatial_meta} due to {str(e)}" + raise e + + def get_geogrid_var(self, var: str) -> xr.DataArray: + """Get a variable from the geogrid file.""" + return self.get_var(self.geogrid_ds, var) + + def get_esmf_var(self, var: str) -> xr.DataArray: + """Get a variable from the geospatial metadata file.""" + return self.get_var(self.esmf_ds, var) + + @property + @lru_cache + @set_none + def _crs_att_names(self) -> list: + """Extract crs attribute names from the geospatial metadata file.""" + return self.ncattrs("crs") + + @property + @lru_cache + @set_none + def _x_coord_att_names(self) -> list: + """Extract x coordinate attribute names from the geospatial metadata file.""" + return self.ncattrs("x") + + @property + @lru_cache + @set_none + def _y_coord_att_names(self) -> list: + """Extract y coordinate attribute names from the geospatial metadata file.""" + return self.ncattrs("y") + + def getncattr(self, var: str) -> dict: + """Extract variable attribute values from the geospatial metadata file.""" + return { + item: self.get_esmf_var(var).getncattr(item) for item in self.ncattrs(var) + } + + @property + @lru_cache + @set_none + def x_coord_atts(self) -> dict: + """Extract x coordinate attribute values from the geospatial metadata file.""" + return self.getncattr("x") + + @property + @lru_cache + @set_none + def y_coord_atts(self) -> dict: + """Extract y coordinate attribute values from the geospatial metadata file.""" + return self.getncattr("y") + + @property + @lru_cache + @set_none + def crs_atts(self) -> dict: + """Extract crs coordinate attribute values from the geospatial metadata file.""" + return self.getncattr("crs") + + @property + @lru_cache + @set_none + def _global_att_names(self) -> list: + """Extract global attribute values from the geospatial metadata file.""" + if self.mpi_config.rank == 0: + try: + return self.esmf_ds.ncattrs() + except Exception as e: + self.config_options.errMsg = f"Unable to extract global attribute names from: {self.config_options.spatial_meta}" + raise e + + @property + @lru_cache + @set_none + def spatial_global_atts(self) -> dict: + """Extract global attribute values from the geospatial metadata file.""" + if self.mpi_config.rank == 0: + try: + return { + item: self.esmf_ds.getncattr(item) + for item in self._global_att_names + } + except Exception as e: + self.config_options.errMsg = f"Unable to extract global attributes from: {self.config_options.spatial_meta}" + raise e + + def extract_coords(self, dimension: str) -> np.ndarray: + """Extract coordinate values from the geospatial metadata file.""" + if self.mpi_config.rank == 0: + if len(self.get_esmf_var(dimension).shape) == 1: + return self.get_esmf_var(dimension)[:].data + elif len(self.get_esmf_var(dimension).shape) == 2: + return self.get_esmf_var(dimension)[:, :].data + + @property + @lru_cache + @set_none + def x_coords(self) -> np.ndarray: + """Extract x coordinate values from the geospatial metadata file.""" + return self.extract_coords("x") + + @property + @lru_cache + @set_none + def y_coords(self) -> np.ndarray: + """Extract y coordinate values from the geospatial metadata file. + + Check to see if the Y coordinates are North-South. If so, flip them. + """ + if self.mpi_config.rank == 0: + y_coords = self.extract_coords("y") + if len(self.get_esmf_var("y").shape) == 1: + if y_coords[1] < y_coords[0]: + y_coords[:] = np.flip(y_coords[:], axis=0) + elif len(self.get_esmf_var("y").shape) == 2: + if y_coords[1, 0] > y_coords[0, 0]: + y_coords[:, :] = np.flipud(y_coords[:, :]) + return y_coords - # Scatter global XLONG_M grid to processors.. - if mpi_config.rank == 0: - if idTmp.variables[config_options.lat_var].ndim == 3: - varTmp = idTmp.variables[config_options.lon_var][0, :, :] - elif idTmp.variables[config_options.lon_var].ndim == 2: - varTmp = idTmp.variables[config_options.lon_var][:, :] - elif idTmp.variables[config_options.lon_var].ndim == 1: - lat = idTmp.variables[config_options.lat_var][:] - lon = idTmp.variables[config_options.lon_var][:] - varTmp = np.meshgrid(lon, lat)[0] - lat = None - lon = None - # Flag to grab entire array for AWS slicing - if config_options.aws: - self.lon_bounds = varTmp - else: - varTmp = None - # mpi_config.comm.barrier() +class GriddedGeoMeta(GeoMeta): + """Class for handling information about the gridded domains for forcing.""" - varSubTmp = mpi_config.scatter_array(self, varTmp, config_options) + def __init__(self, config_options: ConfigOptions, mpi_config: MpiConfig) -> None: + """Initialize GriddedGeoMeta class variables. - # mpi_config.comm.barrier() + Initialization function to initialize ESMF through ESMPy, + calculate the global parameters of the WRF-Hydro grid + being processed to, along with the local parameters + for this particular processor. + :return: + """ + super().__init__(config_options, mpi_config) + for attr in CONSTS[self.__class__.__name__]: + setattr(self, attr, None) + + @broadcast + @property + @lru_cache + def nx_global(self) -> int: + """Get the global x dimension size for the gridded domain.""" + if self.mpi_config.rank == 0: + try: + if self.ndim_lat == 3: + return self.lat_var.shape[2] + elif self.ndim_lat == 2: + return self.lat_var.shape[1] + else: + # NOTE Is this correct? using lon_var + return self.lon_var.shape[0] + except Exception as e: + self.config_options.errMsg = f"Unable to extract X dimension size from longitude variable in: {self.config_options.geogrid}" + raise e + + @broadcast + @property + @lru_cache + def ny_global(self) -> int: + """Get the global y dimension size for the gridded domain.""" + if self.mpi_config.rank == 0: + try: + if self.ndim_lat == 3: + return self.lat_var.shape[1] + else: + return self.lat_var.shape[0] + except Exception as e: + self.config_options.errMsg = f"Unable to extract Y dimension size from latitude in: {self.config_options.geogrid}" + raise e + + @property + @lru_cache + def ndim_lat(self) -> int: + """Get the number of dimensions for the latitude variable.""" + return self.lat_var.ndim + + @property + @lru_cache + def ndim_lon(self) -> int: + """Get the number of dimensions for the longitude variable.""" + return self.lon_var.ndim + + @broadcast + @property + @lru_cache + def dy_meters(self) -> float: + """Get the DY distance in meters for the latitude variable.""" + if self.mpi_config.rank == 0: + try: + if self.ndim_lat == 3: + return self.geogrid_ds.DY + elif self.ndim_lat == 2: + return self.lat_var.dy + else: + if self.config_options.input_forcings[0] != 23: + return self.lat_var.dy + else: + # Manually input the grid spacing since ERA5-Interim does not + # internally have this geospatial information within the netcdf file + return 31000 + except Exception as e: + self.config_options.errMsg = f"Unable to extract DY global attribute in: {self.config_options.geogrid}" + raise e + + @broadcast + @property + @lru_cache + def dx_meters(self) -> float: + """Get the DX distance in meters for the longitude variable.""" + if self.mpi_config.rank == 0: + try: + if self.ndim_lat == 3: + return self.geogrid_ds.DX + elif self.ndim_lat == 2: + return self.lon_var.dx + else: + if self.config_options.input_forcings[0] != 23: + return self.lon_var.dx + else: + # Manually input the grid spacing since ERA5-Interim does not + # internally have this geospatial information within the netcdf file + return 31000 + except Exception as e: + self.config_options.errMsg = f"Unable to extract dx metadata attribute in: {self.config_options.geogrid}" + raise e + @property + @lru_cache + def esmf_grid(self) -> ESMF.Grid: + """Create the ESMF grid object for the gridded domain.""" try: - self.esmf_lon[:, :] = varSubTmp - self.longitude_grid = varSubTmp - varSubTmp = None - varTmp = None - except Exception as e: - config_options.errMsg = ( - "Unable to subset longitude from geogrid file into ESMF object" + return ESMF.Grid( + np.array([self.ny_global, self.nx_global]), + staggerloc=ESMF.StaggerLoc.CENTER, + coord_sys=ESMF.CoordSys.SPH_DEG, ) - raise Exception + except Exception as e: + self.config_options.errMsg = f"Unable to create ESMF grid for WRF-Hydro geogrid: {self.config_options.geogrid}" + raise e + + @property + @lru_cache + def esmf_lat(self) -> np.ndarray: + """Get the ESMF latitude grid.""" + esmf_lat = self.esmf_grid.get_coords(1) + esmf_lat[:, :] = self.latitude_grid + return esmf_lat + + @property + @lru_cache + def esmf_lon(self) -> np.ndarray: + """Get the ESMF longitude grid.""" + esmf_lon = self.esmf_grid.get_coords(0) + esmf_lon[:, :] = self.longitude_grid + return esmf_lon + + @scatter + @property + @lru_cache + def latitude_grid(self) -> np.ndarray: + """Get the latitude grid for the gridded domain.""" + # Scatter global XLAT_M grid to processors.. + if self.mpi_config.rank == 0: + if self.ndim_lat == 3: + var_tmp = self.lat_var[0, :, :] + elif self.ndim_lat == 2: + var_tmp = self.lat_var[:, :] + elif self.ndim_lat == 1: + lat = self.lat_var[:] + lon = self.lon_var[:] + var_tmp = np.meshgrid(lon, lat)[1] - # mpi_config.comm.barrier() + # Flag to grab entire array for AWS slicing + if self.config_options.aws: + self.lat_bounds = var_tmp + else: + var_tmp = None + return var_tmp, "latitude_grid", self.config_options, False + + @property + @lru_cache + def lon_var(self) -> xr.DataArray: + """Get the longitude variable from the geospatial metadata file.""" + return self.get_geogrid_var(self.config_options.lon_var) + + @property + @lru_cache + def lat_var(self) -> xr.DataArray: + """Get the latitude variable from the geospatial metadata file.""" + return self.get_geogrid_var(self.config_options.lat_var) + + @scatter + @property + @lru_cache + def longitude_grid(self) -> np.ndarray: + """Get the longitude grid for the gridded domain.""" + # Scatter global XLONG_M grid to processors.. + if self.mpi_config.rank == 0: + if self.ndim_lat == 3: + var_tmp = self.lon_var[0, :, :] + elif self.ndim_lat == 2: + var_tmp = self.lon_var[:, :] + elif self.ndim_lat == 1: + lat = self.lat_var[:] + lon = self.lon_var[:] + var_tmp = np.meshgrid(lon, lat)[0] + # Flag to grab entire array for AWS slicing + if self.config_options.aws: + self.lon_bounds = var_tmp + else: + var_tmp = None + + return var_tmp, "longitude_grid", self.config_options, False + + @property + @lru_cache + def cosalpha_var(self) -> xr.DataArray: + """Get the COSALPHA variable from the geospatial metadata file.""" + return self.get_geogrid_var(self.config_options.cosalpha_var) + + @scatter + @property + @lru_cache + def cosa_grid(self) -> np.ndarray: + """Get the COSALPHA grid for the gridded domain.""" if ( - config_options.cosalpha_var is not None - and config_options.sinalpha_var is not None + self.config_options.cosalpha_var is not None + and self.config_options.sinalpha_var is not None ): # Scatter the COSALPHA,SINALPHA grids to the processors. - if mpi_config.rank == 0: - if idTmp.variables[config_options.cosalpha_var].ndim == 3: - varTmp = idTmp.variables[config_options.cosalpha_var][0, :, :] + if self.mpi_config.rank == 0: + if self.cosalpha_var.ndim == 3: + cosa = self.cosa_grid_from_geogrid_n3 else: - varTmp = idTmp.variables[config_options.cosalpha_var][:, :] + cosa = self.cosalpha_var[:, :] else: - varTmp = None - # mpi_config.comm.barrier() + cosa = None - varSubTmp = mpi_config.scatter_array(self, varTmp, config_options) - # mpi_config.comm.barrier() + return cosa, "cosa", self.config_options, True - self.cosa_grid = varSubTmp[:, :] - varSubTmp = None - varTmp = None - - if mpi_config.rank == 0: - if idTmp.variables[config_options.sinalpha_var].ndim == 3: - varTmp = idTmp.variables[config_options.sinalpha_var][0, :, :] - else: - varTmp = idTmp.variables[config_options.sinalpha_var][:, :] - else: - varTmp = None - # mpi_config.comm.barrier() - - varSubTmp = mpi_config.scatter_array(self, varTmp, config_options) - # mpi_config.comm.barrier() - self.sina_grid = varSubTmp[:, :] - varSubTmp = None - varTmp = None - - if config_options.hgt_var is not None: - # Read in a scatter the WRF-Hydro elevation, which is used for downscaling - # purposes. - if mpi_config.rank == 0: - if idTmp.variables[config_options.hgt_var].ndim == 3: - varTmp = idTmp.variables[config_options.hgt_var][0, :, :] - else: - varTmp = idTmp.variables[config_options.hgt_var][:, :] - else: - varTmp = None - # mpi_config.comm.barrier() - - varSubTmp = mpi_config.scatter_array(self, varTmp, config_options) - # mpi_config.comm.barrier() - self.height = varSubTmp - varSubTmp = None - varTmp = None + @property + @lru_cache + def sinalpha_var(self) -> xr.DataArray: + """Get the SINALPHA variable from the geospatial metadata file.""" + return self.get_geogrid_var(self.config_options.sinalpha_var) + @property + @lru_cache + def sina_grid(self) -> np.ndarray: + """Get the SINALPHA grid for the gridded domain.""" if ( - config_options.cosalpha_var is not None - and config_options.sinalpha_var is not None + self.config_options.cosalpha_var is not None + and self.config_options.sinalpha_var is not None ): - # Calculate the slope from the domain using elevation on the WRF-Hydro domain. This will - # be used for downscaling purposes. - if mpi_config.rank == 0: - try: - slopeTmp, slp_azi_tmp = self.calc_slope(idTmp, config_options) - except Exception: - raise Exception + if self.mpi_config.rank == 0: + if self.sinalpha_var.ndim == 3: + sina = self.sina_grid_from_geogrid_n3 + else: + sina = self.sinalpha_var[:, :] else: - slopeTmp = None - slp_azi_tmp = None - # mpi_config.comm.barrier() + sina = None - slopeSubTmp = mpi_config.scatter_array(self, slopeTmp, config_options) - self.slope = slopeSubTmp[:, :] - slopeSubTmp = None + return sina, "sina", self.config_options, True - slp_azi_sub = mpi_config.scatter_array(self, slp_azi_tmp, config_options) - self.slp_azi = slp_azi_sub[:, :] - slp_azi_tmp = None + @property + @lru_cache + def hgt_var(self) -> xr.DataArray: + """Get the HGT variable from the geospatial metadata file.""" + return self.get_geogrid_var(self.config_options.hgt_var) - elif ( - config_options.slope_var is not None - and config_options.slope_azimuth_var is not None - ): - if mpi_config.rank == 0: - if idTmp.variables[config_options.slope_var].ndim == 3: - varTmp = idTmp.variables[config_options.slope_var][0, :, :] - else: - varTmp = idTmp.variables[config_options.slope_var][:, :] - else: - varTmp = None - - slopeSubTmp = mpi_config.scatter_array(self, varTmp, config_options) - self.slope = slopeSubTmp - varTmp = None + @scatter + @property + @lru_cache + def height(self) -> np.ndarray: + """Get the height grid for the gridded domain. - if mpi_config.rank == 0: - if idTmp.variables[config_options.slope_azimuth_var].ndim == 3: - varTmp = idTmp.variables[config_options.slope_azimuth_var][0, :, :] + Used for downscaling purposes. + """ + if self.config_options.hgt_var is not None: + if self.mpi_config.rank == 0: + if self.hgt_var.ndim == 3: + height = self.hgt_grid_from_geogrid_n3 else: - varTmp = idTmp.variables[config_options.slope_azimuth_var][:, :] + height = self.hgt_var[:, :] else: - varTmp = None - - slp_azi_sub = mpi_config.scatter_array(self, varTmp, config_options) - self.slp_azi = slp_azi_sub[:, :] - varTmp = None - - elif config_options.hgt_var is not None: - # Calculate the slope from the domain using elevation of the gridded model and other approximations - if mpi_config.rank == 0: - try: - slopeTmp, slp_azi_tmp = self.calc_slope_gridded( - idTmp, config_options - ) - except Exception: - raise Exception - else: - slopeTmp = None - slp_azi_tmp = None - # mpi_config.comm.barrier() - - slopeSubTmp = mpi_config.scatter_array(self, slopeTmp, config_options) - self.slope = slopeSubTmp[:, :] - slopeSubTmp = None - - slp_azi_sub = mpi_config.scatter_array(self, slp_azi_tmp, config_options) - self.slp_azi = slp_azi_sub[:, :] - slp_azi_tmp = None + height = None + + return height, "height", self.config_options, False + + @property + @lru_cache + def slope_var(self) -> xr.DataArray: + """Get the slope variable from the geospatial metadata file.""" + return self.get_geogrid_var(self.config_options.slope_var) + + @property + @lru_cache + def slope_azimuth_var(self) -> xr.DataArray: + """Get the slope azimuth variable from the geospatial metadata file.""" + return self.get_geogrid_var(self.config_options.slope_azimuth_var) + + @property + @lru_cache + def dx(self) -> np.ndarray: + """Calculate the dx distance in meters for the longitude variable.""" + dx = np.empty( + ( + self.lat_var.shape[0], + self.lon_var.shape[0], + ), + dtype=float, + ) + dx[:] = self.lon_var.dx + return dx + + @property + @lru_cache + def dy(self) -> np.ndarray: + """Calculate the dy distance in meters for the latitude variable.""" + dy = np.empty( + ( + self.lat_var.shape[0], + self.lon_var.shape[0], + ), + dtype=float, + ) + dy[:] = self.lat_var.dy + return dy + + @property + @lru_cache + def dz(self) -> np.ndarray: + """Calculate the dz distance in meters for the height variable.""" + dz_init = np.diff(self.hgt_var, axis=0) + dz = np.empty(self.dx.shape, dtype=float) + dz[0 : dz_init.shape[0], 0 : dz_init.shape[1]] = dz_init + dz[dz_init.shape[0] :, :] = dz_init[-1, :] + return dz - if mpi_config.rank == 0: - # Close the geogrid file - try: - idTmp.close() - except Exception as e: - config_options.errMsg = ( - "Unable to close geogrid file: " + config_options.geogrid - ) - raise Exception + @scatter + @property + @lru_cache + def slope(self) -> np.ndarray: + """Calculate slope grids needed for incoming shortwave radiation downscaling. - # Reset temporary variables to free up memory - slopeTmp = None - slp_azi_tmp = None - varTmp = None + Calculate slope from sina_grid, cosa_grid, and height variables if they are + present in the geogrid file, otherwise calculate slope from slope and slope + azimuth variables, and if those are not present, calculate slope from height variable. - def initialize_geospatial_metadata(self, config_options, mpi_config): - """Initialize GeoMetaWrfHydro class variables. - Function that will read in crs/x/y geospatial metadata and coordinates - from the optional geospatial metadata file IF it was specified by the user in - the configuration file. - :param config_options: - :return: + Calculate grid coordinates dx distances in meters + based on general geospatial formula approximations + on a spherical grid. """ - # We will only read information on processor 0. This data is not necessary for the - # other processors, and is only used in the output routines. - if mpi_config.rank == 0: - # Open the geospatial metadata file. - try: - idTmp = netCDF4.Dataset(config_options.spatial_meta, "r") - except Exception as e: - config_options.errMsg = ( - "Unable to open spatial metadata file: " - + config_options.spatial_meta - ) - raise Exception - - # Make sure the expected variables are present in the file. - if "crs" not in idTmp.variables.keys(): - config_options.errMsg = ( - "Unable to locate crs variable in: " + config_options.spatial_meta - ) - raise Exception - if "x" not in idTmp.variables.keys(): - config_options.errMsg = ( - "Unable to locate x variable in: " + config_options.spatial_meta - ) - raise Exception - if "y" not in idTmp.variables.keys(): - config_options.errMsg = ( - "Unable to locate y variable in: " + config_options.spatial_meta - ) - raise Exception - # Extract names of variable attributes from each of the input geospatial variables. These - # can change, so we are making this as flexible as possible to accomodate future changes. - try: - crs_att_names = idTmp.variables["crs"].ncattrs() - except Exception as e: - config_options.errMsg = ( - "Unable to extract crs attribute names from: " - + config_options.spatial_meta - ) - raise Exception - try: - x_coord_att_names = idTmp.variables["x"].ncattrs() - except Exception as e: - config_options.errMsg = ( - "Unable to extract x attribute names from: " - + config_options.spatial_meta - ) - raise Exception - try: - y_coord_att_names = idTmp.variables["y"].ncattrs() - except Exception as e: - config_options.errMsg = ( - "Unable to extract y attribute names from: " - + config_options.spatial_meta - ) - raise Exception - # Extract attribute values - try: - self.x_coord_atts = { - item: idTmp.variables["x"].getncattr(item) - for item in x_coord_att_names - } - except Exception as e: - config_options.errMsg = ( - "Unable to extract x coordinate attributes from: " - + config_options.spatial_meta - ) - raise Exception - try: - self.y_coord_atts = { - item: idTmp.variables["y"].getncattr(item) - for item in y_coord_att_names - } - except Exception as e: - config_options.errMsg = ( - "Unable to extract y coordinate attributes from: " - + config_options.spatial_meta - ) - raise Exception - try: - self.crs_atts = { - item: idTmp.variables["crs"].getncattr(item) - for item in crs_att_names - } - except Exception as e: - config_options.errMsg = ( - "Unable to extract crs coordinate attributes from: " - + config_options.spatial_meta - ) - raise Exception - - # Extract global attributes - try: - global_att_names = idTmp.ncattrs() - except Exception as e: - config_options.errMsg = ( - "Unable to extract global attribute names from: " - + config_options.spatial_meta - ) - raise Exception - try: - self.spatial_global_atts = { - item: idTmp.getncattr(item) for item in global_att_names - } - except Exception as e: - config_options.errMsg = ( - "Unable to extract global attributes from: " - + config_options.spatial_meta - ) - raise Exception - - # Extract x/y coordinate values - if len(idTmp.variables["x"].shape) == 1: - try: - self.x_coords = idTmp.variables["x"][:].data - except Exception as e: - config_options.errMsg = ( - "Unable to extract x coordinate values from: " - + config_options.spatial_meta - ) - raise Exception - try: - self.y_coords = idTmp.variables["y"][:].data - except Exception as e: - config_options.errMsg = ( - "Unable to extract y coordinate values from: " - + config_options.spatial_meta - ) - raise Exception - # Check to see if the Y coordinates are North-South. If so, flip them. - if self.y_coords[1] < self.y_coords[0]: - self.y_coords[:] = np.flip(self.y_coords[:], axis=0) - - if len(idTmp.variables["x"].shape) == 2: - try: - self.x_coords = idTmp.variables["x"][:, :].data - except Exception as e: - config_options.errMsg = ( - "Unable to extract x coordinate values from: " - + config_options.spatial_meta - ) - raise Exception - try: - self.y_coords = idTmp.variables["y"][:, :].data - except Exception as e: - config_options.errMsg = ( - "Unable to extract y coordinate values from: " - + config_options.spatial_meta - ) - raise Exception - # Check to see if the Y coordinates are North-South. If so, flip them. - if self.y_coords[1, 0] > self.y_coords[0, 0]: - self.y_coords[:, :] = np.flipud(self.y_coords[:, :]) - - # Close the geospatial metadata file. - try: - idTmp.close() - except Exception as e: - config_options.errMsg = ( - "Unable to close spatial metadata file: " - + config_options.spatial_meta - ) - raise Exception - - # mpi_config.comm.barrier() + if ( + self.config_options.cosalpha_var is not None + and self.config_options.sinalpha_var is not None + ): + slope = self.slope_from_cosalpha_sinalpha + elif ( + self.config_options.slope_var is not None + and self.config_options.slope_azimuth_var is not None + ): + slope = self.slope_from_slope_azimuth + elif self.config_options.hgt_var is not None: + slope = self.slope_from_height + else: + raise Exception( + "Unable to calculate slope grid for incoming shortwave radiation downscaling. No geospatial metadata variables provided to calculate slope." + ) + return slope, "slope", self.config_options, True - def calc_slope(self, idTmp, config_options): - """Calculate slope grids needed for incoming shortwave radiation downscaling. + @scatter + @property + @lru_cache + def slp_azi(self) -> np.ndarray: + """Calculate slope azimuth grids needed for incoming shortwave radiation downscaling. - Function to calculate slope grids needed for incoming shortwave radiation downscaling - later during the program. - :param idTmp: - :param config_options: - :return: + Calculate slp_azi from sina_grid, cosa_grid, and height variables if they are + present in the geogrid file, otherwise calculate slope from slope and slope + azimuth variables, and if those are not present, calculate slope from height variable. """ - # First extract the sina,cosa, and elevation variables from the geogrid file. - try: - sinaGrid = idTmp.variables[config_options.sinalpha_var][0, :, :] - except Exception as e: - config_options.errMsg = ( - "Unable to extract SINALPHA from: " + config_options.geogrid - ) - raise + if ( + self.config_options.cosalpha_var is not None + and self.config_options.sinalpha_var is not None + ): + slp_azi = self.slp_azi_from_cosalpha_sinalpha + elif ( + self.config_options.slope_var is not None + and self.config_options.slope_azimuth_var is not None + ): + slp_azi = self.slp_azi_from_slope_azimuth - try: - cosaGrid = idTmp.variables[config_options.cosalpha_var][0, :, :] - except Exception as e: - config_options.errMsg = ( - "Unable to extract COSALPHA from: " + config_options.geogrid + elif self.config_options.hgt_var is not None: + slp_azi = self.slp_azi_from_height + else: + raise Exception( + "Unable to calculate slope azimuth grid for incoming shortwave radiation downscaling. No geospatial metadata variables provided to calculate slope azimuth." ) - raise - try: - heightDest = idTmp.variables[config_options.hgt_var][0, :, :] - except Exception as e: - config_options.errMsg = ( - "Unable to extract HGT_M from: " + config_options.geogrid - ) - raise + return slp_azi, "slp_azi", self.config_options, True - # Ensure cosa/sina are correct dimensions - if sinaGrid.shape[0] != self.ny_global or sinaGrid.shape[1] != self.nx_global: - config_options.errMsg = ( - "SINALPHA dimensions mismatch in: " + config_options.geogrid + @property + @lru_cache + def slp_azi_from_slope_azimuth(self) -> np.ndarray: + """Calculate slope azimuth from slope and slope azimuth variables.""" + if self.mpi_config.rank == 0: + if self.slope_azimuth_var.ndim == 3: + return self.slope_azimuth_var[0, :, :] + else: + return self.slope_azimuth_var[:, :] + + @property + @lru_cache + def slp_azi_from_height(self) -> np.ndarray: + """Calculate slope azimuth from height variable.""" + if self.mpi_config.rank == 0: + return (180 / np.pi) * np.arctan(self.dx / self.dy) + + @property + @lru_cache + def slope_from_height(self) -> np.ndarray: + """Calculate slope from height variable.""" + if self.mpi_config.rank == 0: + return self.dz / np.sqrt((self.dx**2) + (self.dy**2)) + + @property + @lru_cache + def slope_from_slope_azimuth(self) -> np.ndarray: + """Calculate slope from slope and slope azimuth variables.""" + if self.mpi_config.rank == 0: + if self.slope_var.ndim == 3: + return self.slope_var[0, :, :] + else: + return self.slope_var[:, :] + + @property + @lru_cache + def slope_from_cosalpha_sinalpha(self) -> np.ndarray: + """Calculate slope from COSALPHA and SINALPHA variables.""" + if self.mpi_config.rank == 0: + slope_tmp = np.arctan( + (self.hx[self.ind_orig] ** 2 + self.hy[self.ind_orig] ** 2) ** 0.5 ) - raise Exception - if cosaGrid.shape[0] != self.ny_global or cosaGrid.shape[1] != self.nx_global: - config_options.errMsg = ( - "COSALPHA dimensions mismatch in: " + config_options.geogrid + slope_tmp[np.where(slope_tmp < 1e-4)] = 0.0 + return slope_tmp + + @property + @lru_cache + def slp_azi_from_cosalpha_sinalpha(self) -> np.ndarray: + """Calculate slope azimuth from COSALPHA and SINALPHA variables.""" + if self.mpi_config.rank == 0: + slp_azi = np.empty([self.ny_global, self.nx_global], np.float32) + slp_azi[np.where(self.slope_from_cosalpha_sinalpha < 1e-4)] = 0.0 + ind_valesmf_ds = np.where(self.slope_from_cosalpha_sinalpha >= 1e-4) + slp_azi[ind_valesmf_ds] = ( + np.arctan2(self.hx[ind_valesmf_ds], self.hy[ind_valesmf_ds]) + math.pi ) - raise Exception - if ( - heightDest.shape[0] != self.ny_global - or heightDest.shape[1] != self.nx_global - ): - config_options.errMsg = ( - "HGT_M dimension mismatch in: " + config_options.geogrid + ind_valesmf_ds = np.where(self.cosa_grid_from_geogrid_n3 >= 0.0) + slp_azi[ind_valesmf_ds] = slp_azi[ind_valesmf_ds] - np.arcsin( + self.sina_grid_from_geogrid_n3[ind_valesmf_ds] ) - raise Exception - - # Establish constants + ind_valesmf_ds = np.where(self.cosa_grid_from_geogrid_n3 < 0.0) + slp_azi[ind_valesmf_ds] = slp_azi[ind_valesmf_ds] - ( + math.pi - np.arcsin(self.sina_grid_from_geogrid_n3[ind_valesmf_ds]) + ) + return slp_azi + + @property + @lru_cache + def ind_orig(self) -> tuple[np.ndarray, np.ndarray]: + """Calculate the indices of the original grid points for the height variable.""" + return np.where(self.hgt_grid_from_geogrid_n3 == self.hgt_grid_from_geogrid_n3) + + @property + @lru_cache + def hx(self) -> np.ndarray: + """Calculate the slope in the x direction from the height variable.""" rdx = 1.0 / self.dx_meters - rdy = 1.0 / self.dy_meters msftx = 1.0 - msfty = 1.0 - - slopeOut = np.empty([self.ny_global, self.nx_global], np.float32) toposlpx = np.empty([self.ny_global, self.nx_global], np.float32) - toposlpy = np.empty([self.ny_global, self.nx_global], np.float32) - slp_azi = np.empty([self.ny_global, self.nx_global], np.float32) - ipDiff = np.empty([self.ny_global, self.nx_global], np.int32) - jpDiff = np.empty([self.ny_global, self.nx_global], np.int32) + ip_diff = np.empty([self.ny_global, self.nx_global], np.int32) hx = np.empty([self.ny_global, self.nx_global], np.float32) - hy = np.empty([self.ny_global, self.nx_global], np.float32) # Create index arrays that will be used to calculate slope. - xTmp = np.arange(self.nx_global) - yTmp = np.arange(self.ny_global) - xGrid = np.tile(xTmp[:], (self.ny_global, 1)) - yGrid = np.repeat(yTmp[:, np.newaxis], self.nx_global, axis=1) - indOrig = np.where(heightDest == heightDest) - indIp1 = ((indOrig[0]), (indOrig[1] + 1)) - indIm1 = ((indOrig[0]), (indOrig[1] - 1)) - indJp1 = ((indOrig[0] + 1), (indOrig[1])) - indJm1 = ((indOrig[0] - 1), (indOrig[1])) - indIp1[1][np.where(indIp1[1] >= self.nx_global)] = self.nx_global - 1 - indJp1[0][np.where(indJp1[0] >= self.ny_global)] = self.ny_global - 1 - indIm1[1][np.where(indIm1[1] < 0)] = 0 - indJm1[0][np.where(indJm1[0] < 0)] = 0 - - ipDiff[indOrig] = xGrid[indIp1] - xGrid[indIm1] - jpDiff[indOrig] = yGrid[indJp1] - yGrid[indJm1] - - toposlpx[indOrig] = ( - (heightDest[indIp1] - heightDest[indIm1]) * msftx * rdx - ) / ipDiff[indOrig] - toposlpy[indOrig] = ( - (heightDest[indJp1] - heightDest[indJm1]) * msfty * rdy - ) / jpDiff[indOrig] - hx[indOrig] = toposlpx[indOrig] - hy[indOrig] = toposlpy[indOrig] - slopeOut[indOrig] = np.arctan((hx[indOrig] ** 2 + hy[indOrig] ** 2) ** 0.5) - slopeOut[np.where(slopeOut < 1e-4)] = 0.0 - slp_azi[np.where(slopeOut < 1e-4)] = 0.0 - indValidTmp = np.where(slopeOut >= 1e-4) - slp_azi[indValidTmp] = np.arctan2(hx[indValidTmp], hy[indValidTmp]) + math.pi - indValidTmp = np.where(cosaGrid >= 0.0) - slp_azi[indValidTmp] = slp_azi[indValidTmp] - np.arcsin(sinaGrid[indValidTmp]) - indValidTmp = np.where(cosaGrid < 0.0) - slp_azi[indValidTmp] = slp_azi[indValidTmp] - ( - math.pi - np.arcsin(sinaGrid[indValidTmp]) - ) - - # Reset temporary arrays to None to free up memory - toposlpx = None - toposlpy = None - heightDest = None - sinaGrid = None - cosaGrid = None - indValidTmp = None - xTmp = None - yTmp = None - xGrid = None - ipDiff = None - jpDiff = None - indOrig = None - indJm1 = None - indJp1 = None - indIm1 = None - indIp1 = None - hx = None - hy = None - - return slopeOut, slp_azi - - def calc_slope_gridded(self, idTmp, config_options): - """Calculate slope grids needed for incoming shortwave radiation downscaling. + x_tmp = np.arange(self.nx_global) + x_grid = np.tile(x_tmp[:], (self.ny_global, 1)) + ind_ip1 = ((self.ind_orig[0]), (self.ind_orig[1] + 1)) + ind_im1 = ((self.ind_orig[0]), (self.ind_orig[1] - 1)) + ind_ip1[1][np.where(ind_ip1[1] >= self.nx_global)] = self.nx_global - 1 + ind_im1[1][np.where(ind_im1[1] < 0)] = 0 + + ip_diff[self.ind_orig] = x_grid[ind_ip1] - x_grid[ind_im1] + toposlpx[self.ind_orig] = ( + ( + self.hgt_grid_from_geogrid_n3[ind_ip1] + - self.hgt_grid_from_geogrid_n3[ind_im1] + ) + * msftx + * rdx + ) / ip_diff[self.ind_orig] + hx = np.empty([self.ny_global, self.nx_global], np.float32) + hx[self.ind_orig] = toposlpx[self.ind_orig] + return hx - Function to calculate slope grids needed for incoming shortwave radiation downscaling - later during the program. This calculates the slopes for grid cells - :param idTmp: - :param config_options: - :return: - """ - idTmp = netCDF4.Dataset(config_options.geogrid, "r") + @property + @lru_cache + def hy(self) -> np.ndarray: + """Calculate the slope in the y direction from the height variable.""" + rdy = 1.0 / self.dy_meters + msfty = 1.0 + toposlpy = np.empty([self.ny_global, self.nx_global], np.float32) + jp_diff = np.empty([self.ny_global, self.nx_global], np.int32) + hy = np.empty([self.ny_global, self.nx_global], np.float32) + # Create index arrays that will be used to calculate slope. + y_tmp = np.arange(self.ny_global) + y_grid = np.repeat(y_tmp[:, np.newaxis], self.nx_global, axis=1) + ind_jp1 = ((self.ind_orig[0] + 1), (self.ind_orig[1])) + ind_jm1 = ((self.ind_orig[0] - 1), (self.ind_orig[1])) + ind_jp1[0][np.where(ind_jp1[0] >= self.ny_global)] = self.ny_global - 1 + ind_jm1[0][np.where(ind_jm1[0] < 0)] = 0 + + jp_diff[self.ind_orig] = y_grid[ind_jp1] - y_grid[ind_jm1] + toposlpy[self.ind_orig] = ( + ( + self.hgt_grid_from_geogrid_n3[ind_jp1] + - self.hgt_grid_from_geogrid_n3[ind_jm1] + ) + * msfty + * rdy + ) / jp_diff[self.ind_orig] + hy[self.ind_orig] = toposlpy[self.ind_orig] + return hy + + @property + @lru_cache + def x_lower_bound(self) -> float: + """Get the local x lower bound for this processor.""" + return self.esmf_grid.lower_bounds[ESMF.StaggerLoc.CENTER][1] + + @property + @lru_cache + def x_upper_bound(self) -> float: + """Get the local x upper bound for this processor.""" + return self.esmf_grid.upper_bounds[ESMF.StaggerLoc.CENTER][1] + + @property + @lru_cache + def y_lower_bound(self) -> float: + """Get the local y lower bound for this processor.""" + return self.esmf_grid.lower_bounds[ESMF.StaggerLoc.CENTER][0] + + @property + @lru_cache + def y_upper_bound(self) -> float: + """Get the local y upper bound for this processor.""" + return self.esmf_grid.upper_bounds[ESMF.StaggerLoc.CENTER][0] + + @property + @lru_cache + def nx_local(self) -> int: + """Get the local x dimension size for this processor.""" + return self.x_upper_bound - self.x_lower_bound + + @property + @lru_cache + def ny_local(self) -> int: + """Get the local y dimension size for this processor.""" + return self.y_upper_bound - self.y_lower_bound + + @property + @lru_cache + def sina_grid_from_geogrid_n3(self) -> np.ndarray: + """Get the SINALPHA grid for the gridded domain directly from the geogrid file.""" try: - lons = idTmp.variables[config_options.lon_var][:] - lats = idTmp.variables[config_options.lat_var][:] + return self.check_grid(self.sinalpha_var[0, :, :]) except Exception as e: - config_options.errMsg = ( - "Unable to extract gridded coordinates in " + config_options.geogrid + self.config_options.errMsg = ( + f"Unable to extract SINALPHA from: {self.config_options.geogrid}" + ) + raise e + + def check_grid(self, grid: np.ndarray) -> np.ndarray: + """Check to make sure the grid dimensions match the expected dimensions for the gridded domain.""" + if grid.shape[0] != self.ny_global or grid.shape[1] != self.nx_global: + self.config_options.errMsg = ( + f"Grid dimensions mismatch in: {self.config_options.geogrid}" ) raise Exception + return grid + + @property + @lru_cache + def cosa_grid_from_geogrid_n3(self) -> np.ndarray: + """Get the COSALPHA grid for the gridded domain directly from the geogrid file.""" try: - dx = np.empty( - ( - idTmp.variables[config_options.lat_var].shape[0], - idTmp.variables[config_options.lon_var].shape[0], - ), - dtype=float, - ) - dy = np.empty( - ( - idTmp.variables[config_options.lat_var].shape[0], - idTmp.variables[config_options.lon_var].shape[0], - ), - dtype=float, - ) - dx[:] = idTmp.variables[config_options.lon_var].dx - dy[:] = idTmp.variables[config_options.lat_var].dy + return self.check_grid(self.cosalpha_var[0, :, :]) except Exception as e: - config_options.errMsg = ( - "Unable to extract dx and dy distances in " + config_options.geogrid + self.config_options.errMsg = ( + f"Unable to extract COSALPHA from: {self.config_options.geogrid}" ) - raise Exception + raise e + + @property + @lru_cache + def hgt_grid_from_geogrid_n3(self) -> np.ndarray: + """Get the HGT_M grid for the gridded domain directly from the geogrid file.""" try: - heights = idTmp.variables[config_options.hgt_var][:] + return self.check_grid(self.hgt_var[0, :, :]) except Exception as e: - config_options.errMsg = ( - "Unable to extract heights of grid cells in " + config_options.geogrid + self.config_options.errMsg = ( + f"Unable to extract HGT_M from: {self.config_options.geogrid}" ) - raise Exception + raise e - idTmp.close() - # calculate grid coordinates dx distances in meters - # based on general geospatial formula approximations - # on a spherical grid - dz_init = np.diff(heights, axis=0) - dz = np.empty(dx.shape, dtype=float) - dz[0 : dz_init.shape[0], 0 : dz_init.shape[1]] = dz_init - dz[dz_init.shape[0] :, :] = dz_init[-1, :] +class HydrofabricGeoMeta(GeoMeta): + """Class for handling information about the hydrofabric domain forcing.""" - slope = dz / np.sqrt((dx**2) + (dy**2)) - slp_azi = (180 / np.pi) * np.arctan(dx / dy) - - # Reset temporary arrays to None to free up memory - lons = None - lats = None - heights = None - dx = None - dy = None - dz = None - - return slope, slp_azi - - def initialize_destination_geo_unstructured(self, config_options, mpi_config): - """Initialize GeoMetaWrfHydro class variables. + def __init__(self, config_options: ConfigOptions, mpi_config: MpiConfig): + """Initialize HydrofabricGeoMeta class variables. Initialization function to initialize ESMF through ESMPy, - calculate the global parameters of the WRF-Hydro grid + calculate the global parameters of the hydrofabric being processed to, along with the local parameters for this particular processor. :return: """ - # Open the geogrid file and extract necessary information - # to create ESMF fields. - if mpi_config.rank == 0: - try: - idTmp = netCDF4.Dataset(config_options.geogrid, "r") - except Exception as e: - config_options.errMsg = ( - "Unable to open the unstructured mesh file: " - + config_options.geogrid - ) - raise Exception - - try: - self.nx_global = idTmp.variables[config_options.nodecoords_var].shape[0] - except Exception as e: - config_options.errMsg = ( - "Unable to extract X dimension size in " + config_options.geogrid - ) - raise Exception - - try: - self.ny_global = idTmp.variables[config_options.nodecoords_var].shape[0] - except Exception as e: - config_options.errMsg = ( - "Unable to extract Y dimension size in " + config_options.geogrid - ) - raise Exception - - try: - self.nx_global_elem = idTmp.variables[ - config_options.elemcoords_var - ].shape[0] - except Exception as e: - config_options.errMsg = ( - "Unable to extract X dimension size in " + config_options.geogrid - ) - raise Exception + super().__init__(config_options, mpi_config) + for attr in CONSTS[self.__class__.__name__]: + setattr(self, attr, None) + + @property + @lru_cache + def lat_bounds(self) -> np.ndarray: + """Get the latitude bounds for the unstructured domain.""" + return self.get_bound(1) + + @property + @lru_cache + def lon_bounds(self) -> np.ndarray: + """Get the longitude bounds for the unstructured domain.""" + return self.get_bound(0) + + def get_bound(self, dim: int) -> np.ndarray: + """Get the longitude or latitude bounds for the unstructured domain.""" + if self.config_options.aws: + return self.get_geogrid_var(self.config_options.nodecoords_var)[:, dim] + + @broadcast + @property + @lru_cache + def elementcoords_global(self) -> np.ndarray: + """Get the global element coordinates for the unstructured domain.""" + return self.get_geogrid_var(self.config_options.elemcoords_var) + + @barrier + @broadcast + @property + @lru_cache + def nx_global(self) -> int: + """Get the global x dimension size for the unstructured domain.""" + return self.elementcoords_global.shape[0] + + @barrier + @broadcast + @property + @lru_cache + def ny_global(self) -> int: + """Get the global y dimension size for the unstructured domain. + + Same as nx_global. + """ + return self.nx_global + @property + @lru_cache + def esmf_grid(self) -> ESMF.Mesh: + """Create the ESMF Mesh object for the unstructured domain.""" + try: + return ESMF.Mesh( + filename=self.config_options.geogrid, filetype=ESMF.FileFormat.ESMFMESH + ) + except Exception as e: + LOG.critical( + f"Unable to create ESMF Mesh: {self.config_options.geogrid} " + f"due to {str(e)}" + ) + raise e + + @property + @lru_cache + def latitude_grid(self) -> np.ndarray: + """Get the latitude grid for the unstructured domain.""" + return self.esmf_grid.coords[1][1] + + @property + @lru_cache + def longitude_grid(self) -> np.ndarray: + """Get the longitude grid for the unstructured domain.""" + return self.esmf_grid.coords[1][0] + + @property + @lru_cache + def pet_element_inds(self) -> np.ndarray: + """Get the PET element indices for the unstructured domain.""" + if self.mpi_config.rank == 0: try: - self.ny_global_elem = idTmp.variables[ - config_options.elemcoords_var - ].shape[0] + tree = spatial.KDTree(self.elementcoords_global) + return tree.query( + np.column_stack([self.longitude_grid, self.latitude_grid]) + )[1] except Exception as e: - config_options.errMsg = ( - "Unable to extract Y dimension size in " + config_options.geogrid + LOG.critical( + f"Failed to open mesh file: {self.config_options.geogrid} " + f"due to {str(e)}" ) - raise Exception + raise e + + @property + @lru_cache + def element_ids(self) -> np.ndarray: + """Get the element IDs for the unstructured domain.""" + return self.element_ids_global[self.pet_element_inds] + + @broadcast + @property + @lru_cache + def element_ids_global(self) -> np.ndarray: + """Get the global element IDs for the unstructured domain.""" + return self.get_geogrid_var(self.config_options.element_id_var).values + + @broadcast + @property + @lru_cache + def heights_global(self) -> np.ndarray: + """Get the global heights for the unstructured domain.""" + return self.get_geogrid_var(self.config_options.hgt_var) + + @property + @lru_cache + def height(self) -> np.ndarray: + """Get the height grid for the unstructured domain.""" + if self.mpi_config.rank == 0: + if self.config_options.hgt_var is not None: + return self.heights_global[self.pet_element_inds] + + @property + @lru_cache + def slope(self) -> np.ndarray: + """Get the slopes for the unstructured domain.""" + return self.slopes_global[self.pet_element_inds] + + @property + @lru_cache + def slp_azi(self) -> np.ndarray: + """Get the slope azimuths for the unstructured domain.""" + return self.slp_azi_global[self.pet_element_inds] + + @property + @lru_cache + def mesh_inds(self) -> np.ndarray: + """Get the mesh indices for the unstructured domain.""" + return self.pet_element_inds + + @broadcast + @property + @lru_cache + def slopes_global(self) -> np.ndarray: + """Get the global slopes for the unstructured domain.""" + return self.get_geogrid_var(self.config_options.slope_var) + + @property + @lru_cache + def slp_azi_global(self) -> np.ndarray: + """Get the global slope azimuths for the unstructured domain.""" + return self.get_geogrid_var(self.config_options.slope_azimuth_var) + + @property + @lru_cache + def nx_local(self) -> int: + """Get the local x dimension size for this processor.""" + return len(self.esmf_grid.coords[1][1]) + + @property + @lru_cache + def ny_local(self) -> int: + """Get the local y dimension size for this processor.""" + return len(self.esmf_grid.coords[1][1]) + + +class UnstructuredGeoMeta(GeoMeta): + """Class for handling information about the hydrofabric domain forcing.""" + + def __init__(self, config_options: ConfigOptions, mpi_config: MpiConfig) -> None: + """Initialize HydrofabricGeoMeta class variables. + Initialization function to initialize ESMF through ESMPy, + calculate the global parameters of the unstructured mesh + being processed to, along with the local parameters + for this particular processor. + :return: + """ + super().__init__(config_options, mpi_config) + for attr in CONSTS[self.__class__.__name__]: + setattr(self, attr, None) + + @broadcast + @property + @lru_cache + def nx_global(self) -> int: + """Get the global x dimension size for the unstructured domain.""" + return self.get_geogrid_var(self.config_options.nodecoords_var).shape[0] + + @broadcast + @property + @lru_cache + def ny_global(self) -> int: + """Get the global y dimension size for the unstructured domain.""" + return self.get_geogrid_var(self.config_options.nodecoords_var).shape[0] + + @broadcast + @property + @lru_cache + def nx_global_elem(self) -> int: + """Get the global x dimension size for the unstructured domain elements.""" + return self.get_esmf_var(self.config_options.elemcoords_var).shape[0] + + @broadcast + @property + @lru_cache + def ny_global_elem(self) -> int: + """Get the global y dimension size for the unstructured domain elements.""" + return self.get_esmf_var(self.config_options.elemcoords_var).shape[0] + + @property + @lru_cache + def lon_bounds(self) -> np.ndarray: + """Get the longitude bounds for the unstructured domain.""" + return self.get_bound(0) + + @property + @lru_cache + def lat_bounds(self) -> np.ndarray: + """Get the latitude bounds for the unstructured domain.""" + return self.get_bound(1) + + def get_bound(self, dim: int) -> np.ndarray: + """Get the longitude or latitude bounds for the unstructured domain.""" + if self.mpi_config.rank == 0: # Flag to grab entire array for AWS slicing - if config_options.aws: - self.lat_bounds = idTmp.variables[config_options.nodecoords_var][:][ - :, 1 - ] - self.lon_bounds = idTmp.variables[config_options.nodecoords_var][:][ - :, 0 - ] - - # mpi_config.comm.barrier() - - # Broadcast global dimensions to the other processors. - self.nx_global = mpi_config.broadcast_parameter( - self.nx_global, config_options, param_type=int - ) - self.ny_global = mpi_config.broadcast_parameter( - self.ny_global, config_options, param_type=int - ) - self.nx_global_elem = mpi_config.broadcast_parameter( - self.nx_global_elem, config_options, param_type=int - ) - self.ny_global_elem = mpi_config.broadcast_parameter( - self.ny_global_elem, config_options, param_type=int - ) - - # mpi_config.comm.barrier() + if self.config_options.aws: + return self.get_esmf_var(self.config_options.nodecoords_var)[:][:, dim] - if mpi_config.rank == 0: - # Close the geogrid file - try: - idTmp.close() - except Exception as e: - config_options.errMsg = ( - "Unable to close geogrid Mesh file: " + config_options.geogrid - ) - raise Exception + @property + @lru_cache + def esmf_grid(self) -> ESMF.Mesh: + """Create the ESMF grid object for the unstructured domain. + Removed argument coord_sys=ESMF.CoordSys.SPH_DEG since we are always reading from a file + From ESMF documentation + If you create a mesh from a file (like NetCDF/ESMF-Mesh), coord_sys is ignored. The mesh’s coordinate system should be embedded in the file or inferred. + """ try: - # Removed argument coord_sys=ESMF.CoordSys.SPH_DEG since we are always reading from a file - # From ESMF documentation - # If you create a mesh from a file (like NetCDF/ESMF-Mesh), coord_sys is ignored. The mesh’s coordinate system should be embedded in the file or inferred. - self.esmf_grid = ESMF.Mesh( - filename=config_options.geogrid, filetype=ESMF.FileFormat.ESMFMESH + return ESMF.Mesh( + filename=self.config_options.geogrid, filetype=ESMF.FileFormat.ESMFMESH ) except Exception as e: - config_options.errMsg = ( - "Unable to create ESMF Mesh from geogrid file: " - + config_options.geogrid - ) - raise Exception + self.config_options.errMsg = f"Unable to create ESMF Mesh from geogrid file: {self.config_options.geogrid}" + raise e - # mpi_config.comm.barrier() + @property + @lru_cache + def latitude_grid(self) -> np.ndarray: + """Get the latitude grid for the unstructured domain. - # Obtain the local boundaries for this processor. - self.get_processor_bounds(config_options) + Place the local lat/lon grid slices from the parent geogrid file into + the ESMF lat/lon grids that have already been seperated by processors. + """ + return self.esmf_grid.coords[0][1] - # Place the local lat/lon grid slices from the parent geogrid file into - # the ESMF lat/lon grids that have already been seperated by processors. - try: - self.latitude_grid = self.esmf_grid.coords[0][1] - self.latitude_grid_elem = self.esmf_grid.coords[1][1] - varSubTmp = None - varTmp = None - except Exception as e: - config_options.errMsg = ( - "Unable to subset node latitudes from ESMF Mesh object" - ) - raise Exception - try: - self.longitude_grid = self.esmf_grid.coords[0][0] - self.longitude_grid_elem = self.esmf_grid.coords[1][0] - varSubTmp = None - varTmp = None - except Exception as e: - config_options.errMsg = ( - "Unable to subset XLONG_M from geogrid file into ESMF Mesh object" - ) - raise Exception + @property + @lru_cache + def latitude_grid_elem(self) -> np.ndarray: + """Get the latitude grid for the unstructured domain elements. - idTmp = netCDF4.Dataset(config_options.geogrid, "r") + Place the local lat/lon grid slices from the parent geogrid file into + the ESMF lat/lon grids that have already been seperated by processors. + """ + return self.esmf_grid.coords[1][1] - # Get lat and lon global variables for pet extraction of indices - nodecoords_global = idTmp.variables[config_options.nodecoords_var][:].data - elementcoords_global = idTmp.variables[config_options.elemcoords_var][:].data + @property + @lru_cache + def longitude_grid(self) -> np.ndarray: + """Get the longitude grid for the unstructured domain. + + Place the local lat/lon grid slices from the parent geogrid file into + the ESMF lat/lon grids that have already been seperated by processors. + """ + return self.esmf_grid.coords[0][0] + @property + @lru_cache + def longitude_grid_elem(self) -> np.ndarray: + """Get the longitude grid for the unstructured domain elements. + + Place the local lat/lon grid slices from the parent geogrid file into + the ESMF lat/lon grids that have already been seperated by processors. + """ + return self.esmf_grid.coords[1][0] + + @property + @lru_cache + def pet_element_inds(self) -> np.ndarray: + """Get the local node indices for the unstructured domain elements.""" + # Get lat and lon global variables for pet extraction of indices + elementcoords_global = self.get_var( + self.geogrid_ds, self.config_options.elemcoords_var + )[:].data # Find the corresponding local indices to slice global heights and slope # variables that are based on the partitioning on the unstructured mesh - pet_nodecoords = np.empty((len(self.latitude_grid), 2), dtype=float) pet_elementcoords = np.empty((len(self.latitude_grid_elem), 2), dtype=float) - pet_nodecoords[:, 0] = self.longitude_grid - pet_nodecoords[:, 1] = self.latitude_grid pet_elementcoords[:, 0] = self.longitude_grid_elem pet_elementcoords[:, 1] = self.latitude_grid_elem + return spatial.KDTree(elementcoords_global).query(pet_elementcoords)[1] - distance, pet_node_inds = spatial.KDTree(nodecoords_global).query( - pet_nodecoords - ) - distance, pet_element_inds = spatial.KDTree(elementcoords_global).query( - pet_elementcoords - ) + @property + @lru_cache + def pet_node_inds(self) -> np.ndarray: + """Get the local node indices for the unstructured domain nodes.""" + # Get lat and lon global variables for pet extraction of indices + nodecoords_global = self.get_var( + self.geogrid_ds, self.config_options.nodecoords_var + )[:].data + # Find the corresponding local indices to slice global heights and slope + # variables that are based on the partitioning on the unstructured mesh + pet_nodecoords = np.empty((len(self.latitude_grid), 2), dtype=float) + pet_nodecoords[:, 0] = self.longitude_grid + pet_nodecoords[:, 1] = self.latitude_grid - # reset variables to free up memory - nodecoords_global = None - elementcoords_global = None - pet_nodecoords = None - pet_elementcoords = None - distance = None - - # Not accepting cosalpha and sinalpha at this time for unstructured meshes, only - # accepting the pre-calculated slope and slope azmiuth variables if available, - # otherwise calculate slope from height estimates - # if(config_options.cosalpha_var != None and config_options.sinalpha_var != None): - # self.cosa_grid = idTmp.variables[config_options.cosalpha_var][:].data[pet_node_inds] - # self.sina_grid = idTmp.variables[config_options.sinalpha_var][:].data[pet_node_inds] - # slopeTmp, slp_azi_tmp = self.calc_slope(idTmp,config_options) - # self.slope = slope_node_Tmp[pet_node_inds] - # self.slp_azi = slp_azi_node_tmp[pet_node_inds] + return spatial.KDTree(nodecoords_global).query(pet_nodecoords)[1] + + # NOTE this is a note/commented out code from before refactor on 2/19/2026. + # Not accepting cosalpha and sinalpha at this time for unstructured meshes, only + # accepting the pre-calculated slope and slope azmiuth variables if available, + # otherwise calculate slope from height estimates + # if(config_options.cosalpha_var != None and config_options.sinalpha_var != None): + # self.cosa_grid = esmf_ds.variables[config_options.cosalpha_var][:].data[pet_node_inds] + # self.sina_grid = esmf_ds.variables[config_options.sinalpha_var][:].data[pet_node_inds] + # slope_tmp, slp_azi_tmp = self.calc_slope(esmf_ds,config_options) + # self.slope = slope_node_tmp[pet_node_inds] + # self.slp_azi = slp_azi_node_tmp[pet_node_inds] + + @property + @lru_cache + def slope(self) -> np.ndarray: + """Get the slope grid for the unstructured domain.""" if ( - config_options.slope_var is not None - and config_options.slp_azi_var is not None + self.config_options.slope_var is not None + and self.config_options.slp_azi_var is not None ): - self.slope = idTmp.variables[config_options.slope_var][:].data[ - pet_node_inds - ] - self.slp_azi = idTmp.variables[config_options.slope_azimuth_var][:].data[ - pet_node_inds + return self.get_geogrid_var(self.config_options.slope_var)[ + self.pet_node_inds ] - self.slope_elem = idTmp.variables[config_options.slope_var_elem][:].data[ - pet_element_inds + elif self.config_options.hgt_var is not None: + return ( + self.dz_node + / np.sqrt((self.dx_node**2) + (self.dy_node**2))[self.pet_node_inds] + ) + + @property + @lru_cache + def slp_azi(self) -> np.ndarray: + """Get the slope azimuth grid for the unstructured domain.""" + if ( + self.config_options.slope_var is not None + and self.config_options.slp_azi_var is not None + ): + return self.get_geogrid_var(self.config_options.slope_azimuth_var)[ + self.pet_node_inds ] - self.slp_azi_elem = idTmp.variables[config_options.slope_azimuth_var_elem][ - : - ].data[pet_element_inds] - - # Read in a scatter the mesh node elevation, which is used for downscaling purposes - self.height = idTmp.variables[config_options.hgt_var][:].data[pet_node_inds] - # Read in a scatter the mesh element elevation, which is used for downscaling purposes. - self.height_elem = idTmp.variables[config_options.hgt_elem_var][:].data[ - pet_element_inds + elif self.config_options.hgt_var is not None: + return (180 / np.pi) * np.arctan(self.dx_node / self.dy_node)[ + self.pet_node_inds ] - elif config_options.hgt_var is not None: - # Read in a scatter the mesh node elevation, which is used for downscaling purposes - self.height = idTmp.variables[config_options.hgt_var][:].data[pet_node_inds] - - # Read in a scatter the mesh element elevation, which is used for downscaling purposes. - self.height_elem = idTmp.variables[config_options.hgt_elem_var][:].data[ - pet_element_inds + @property + @lru_cache + def slope_elem(self) -> np.ndarray: + """Get the slope grid for the unstructured domain elements.""" + if ( + self.config_options.slope_var is not None + and self.config_options.slp_azi_var is not None + ): + return self.get_geogrid_var(self.config_options.slope_var_elem)[:].data[ + self.pet_element_inds ] - - # Calculate the slope from the domain using elevation on the WRF-Hydro domain. This will - # be used for downscaling purposes. - slope_node_Tmp, slp_azi_node_tmp, slope_elem_Tmp, slp_azi_elem_tmp = ( - self.calc_slope_unstructured(idTmp, config_options) + elif self.config_options.hgt_var is not None: + return ( + self.dz_elem + / np.sqrt((self.dx_elem**2) + (self.dy_elem**2))[self.pet_element_inds] ) - self.slope = slope_node_Tmp[pet_node_inds] - slope_node_Tmp = None - - self.slp_azi = slp_azi_node_tmp[pet_node_inds] - slp_azi__node_tmp = None - - self.slope_elem = slope_elem_Tmp[pet_element_inds] - slope_elem_Tmp = None - - self.slp_azi_elem = slp_azi_elem_tmp[pet_element_inds] - slp_azi_elem_tmp = None - - # save indices where mesh was partition for future scatter functions - self.mesh_inds = pet_node_inds - self.mesh_inds_elem = pet_element_inds - - # reset variables to free up memory - pet_node_inds = None - pet_element_inds = None + @property + @lru_cache + def slp_azi_elem(self) -> np.ndarray: + """Get the slope azimuth grid for the unstructured domain elements.""" + if ( + self.config_options.slope_var is not None + and self.config_options.slp_azi_var is not None + ): + return self.get_var( + self.geogrid_ds, self.config_options.slope_azimuth_var_elem + )[:].data[self.pet_element_inds] + elif self.config_options.hgt_var is not None: + return (180 / np.pi) * np.arctan(self.dx_elem / self.dy_elem)[ + self.pet_element_inds + ] - def calc_slope_unstructured(self, idTmp, config_options): - """Calculate slope grids needed for incoming shortwave radiation downscaling. + @property + @lru_cache + def height(self) -> np.ndarray: + """Get the height grid for the unstructured domain nodes.""" + if ( + self.config_options.slope_var is not None + and self.config_options.slp_azi_var is not None + ): + return self.get_geogrid_var(self.config_options.hgt_var)[:].data[ + self.pet_node_inds + ] + elif self.config_options.hgt_var is not None: + return self.self.get_geogrid_var(self.config_options.hgt_var)[:].data[ + self.pet_node_inds + ] - Function to calculate slope grids needed for incoming shortwave radiation downscaling - later during the program. This calculates the slopes for both nodes and elements - :param idTmp: - :param config_options: - :return: - """ - idTmp = netCDF4.Dataset(config_options.geogrid, "r") + @property + @lru_cache + def height_elem(self) -> np.ndarray: + """Get the height grid for the unstructured domain elements.""" + if ( + self.config_options.slope_var is not None + and self.config_options.slp_azi_var is not None + ): + return self.get_geogrid_var(self.config_options.hgt_elem_var)[:].data[ + self.pet_element_inds + ] + elif self.config_options.hgt_var is not None: + return self.get_geogrid_var(self.config_options.hgt_elem_var)[:].data[ + self.pet_element_inds + ] - try: - node_lons = idTmp.variables[config_options.nodecoords_var][:][:, 0] - node_lats = idTmp.variables[config_options.nodecoords_var][:][:, 1] - except Exception as e: - config_options.errMsg = ( - "Unable to extract node coordinates in " + config_options.geogrid - ) - raise Exception - try: - elem_lons = idTmp.variables[config_options.elemcoords_var][:][:, 0] - elem_lats = idTmp.variables[config_options.elemcoords_var][:][:, 1] - except Exception as e: - config_options.errMsg = ( - "Unable to extract element coordinates in " + config_options.geogrid - ) - raise Exception - try: - elem_conn = idTmp.variables[config_options.elemconn_var][:][:, 0] - except Exception as e: - config_options.errMsg = ( - "Unable to extract element connectivity in " + config_options.geogrid - ) - raise Exception - try: - node_heights = idTmp.variables[config_options.hgt_var][:] - except Exception as e: - config_options.errMsg = ( - "Unable to extract HGT_M from: " + config_options.geogrid - ) - raise + @property + @lru_cache + def node_lons(self) -> np.ndarray: + """Get the longitude grid for the unstructured domain nodes.""" + return self.get_geogrid_var(self.config_options.nodecoords_var)[:][:, 0] + + @property + @lru_cache + def node_lats(self) -> np.ndarray: + """Get the latitude grid for the unstructured domain nodes.""" + return self.get_geogrid_var(self.config_options.nodecoords_var)[:][:, 1] + + @property + @lru_cache + def elem_lons(self) -> np.ndarray: + """Get the longitude grid for the unstructured domain elements.""" + return self.get_geogrid_var(self.config_options.elemcoords_var)[:][:, 0] + + @property + @lru_cache + def elem_lats(self) -> np.ndarray: + """Get the latitude grid for the unstructured domain elements.""" + return self.get_geogrid_var(self.config_options.elemcoords_var)[:][:, 1] + + @property + @lru_cache + def elem_conn(self) -> np.ndarray: + """Get the element connectivity for the unstructured domain.""" + return self.get_geogrid_var(self.config_options.elemconn_var)[:][:, 0] + + @property + @lru_cache + def node_heights(self) -> np.ndarray: + """Get the height grid for the unstructured domain nodes.""" + node_heights = self.get_geogrid_var(self.config_options.hgt_var)[:] if node_heights.shape[0] != self.ny_global: - config_options.errMsg = ( - "HGT_M dimension mismatch in: " + config_options.geogrid + self.config_options.errMsg = ( + f"HGT_M dimension mismatch in: {self.config_options.geogrid}" ) raise Exception - - try: - elem_heights = idTmp.variables[config_options.hgt_elem_var][:] - except Exception as e: - config_options.errMsg = ( - "Unable to extract HGT_M_ELEM from: " + config_options.geogrid - ) - raise - - if elem_heights.shape[0] != len(elem_lons): - config_options.errMsg = ( - "HGT_M_ELEM dimension mismatch in: " + config_options.geogrid + return node_heights + + @property + @lru_cache + def elem_heights(self) -> np.ndarray: + """Get the height grid for the unstructured domain elements.""" + elem_heights = self.get_var(self.geogrid_ds, self.config_options.hgt_elem_var)[ + : + ] + + if elem_heights.shape[0] != len(self.elem_lons): + self.config_options.errMsg = ( + f"HGT_M_ELEM dimension mismatch in: {self.config_options.geogrid}" ) raise Exception - - idTmp.close() - - # calculate node coordinate distances in meters - # based on general geospatial formula approximations - # on a spherical grid - dx = np.diff(node_lons) * 40075160 * np.cos(node_lats[0:-1] * np.pi / 180) / 360 - dx = np.append(dx, dx[-1]) - dy = np.diff(node_lats) * 40008000 / 360 - dy = np.append(dy, dy[-1]) - dz = np.diff(node_heights) - dz = np.append(dz, dz[-1]) - - slope_nodes = dz / np.sqrt((dx**2) + (dy**2)) - slp_azi_nodes = (180 / np.pi) * np.arctan(dx / dy) - - # calculate element coordinate distances in meters - # based on general geospatial formula approximations - # on a spherical grid - dx = np.diff(elem_lons) * 40075160 * np.cos(elem_lats[0:-1] * np.pi / 180) / 360 - dx = np.append(dx, dx[-1]) - dy = np.diff(elem_lats) * 40008000 / 360 - dy = np.append(dy, dy[-1]) - dz = np.diff(elem_heights) - dz = np.append(dz, dz[-1]) - - slope_elem = dz / np.sqrt((dx**2) + (dy**2)) - slp_azi_elem = (180 / np.pi) * np.arctan(dx / dy) - - # Reset temporary arrays to None to free up memory - node_lons = None - node_lats = None - elem_lons = None - elem_lats = None - node_heights = None - elem_heights = None - dx = None - dy = None - dz = None - - return slope_nodes, slp_azi_nodes, slope_elem, slp_azi_elem - - def initialize_destination_geo_hydrofabric(self, config_options, mpi_config): - """Initialize GeoMetaWrfHydro class variables. - - Initialization function to initialize ESMF through ESMPy, - calculate the global parameters of the WRF-Hydro grid - being processed to, along with the local parameters - for this particular processor. - :return: - """ - - if config_options.geogrid is not None: - # Phase 1: Rank 0 extracts all needed global data - if mpi_config.rank == 0: - try: - idTmp = nc_utils.nc_Dataset_retry( - mpi_config, - config_options, - err_handler, - config_options.geogrid, - "r", - ) - - # Extract everything we need with retries - tmp_vars = idTmp.variables - - if config_options.aws: - nodecoords_data = nc_utils.nc_read_var_retry( - mpi_config, - config_options, - err_handler, - tmp_vars[config_options.nodecoords_var], - ) - self.lat_bounds = nodecoords_data[:, 1] - self.lon_bounds = nodecoords_data[:, 0] - - # Store these for later broadcast/scatter - elementcoords_global = nc_utils.nc_read_var_retry( - mpi_config, - config_options, - err_handler, - tmp_vars[config_options.elemcoords_var], - ) - - self.nx_global = elementcoords_global.shape[0] - self.ny_global = self.nx_global - - element_ids_global = nc_utils.nc_read_var_retry( - mpi_config, - config_options, - err_handler, - tmp_vars[config_options.element_id_var], - ) - - heights_global = None - if config_options.hgt_var is not None: - heights_global = nc_utils.nc_read_var_retry( - mpi_config, - config_options, - err_handler, - tmp_vars[config_options.hgt_var], - ) - slopes_global = None - slp_azi_global = None - if config_options.slope_var is not None: - slopes_global = nc_utils.nc_read_var_retry( - mpi_config, - config_options, - err_handler, - tmp_vars[config_options.slope_var], - ) - if config_options.slope_azimuth_var is not None: - slp_azi_global = nc_utils.nc_read_var_retry( - mpi_config, - config_options, - err_handler, - tmp_vars[config_options.slope_azimuth_var], - ) - - except Exception as e: - LOG.critical( - f"Failed to open mesh file: {config_options.geogrid} " - f"due to {str(e)}" - ) - raise - finally: - idTmp.close() - else: - elementcoords_global = None - element_ids_global = None - heights_global = None - slopes_global = None - slp_azi_global = None - - # Broadcast dimensions - self.nx_global = mpi_config.broadcast_parameter( - self.nx_global, config_options, param_type=int - ) - self.ny_global = mpi_config.broadcast_parameter( - self.ny_global, config_options, param_type=int - ) - - mpi_config.comm.barrier() - - # Phase 2: Create ESMF Mesh (collective operation with retry) - try: - self.esmf_grid = esmf_utils.esmf_mesh_retry( - mpi_config, - config_options, - err_handler, - filename=config_options.geogrid, - filetype=ESMF.FileFormat.ESMFMESH, - ) - except Exception as e: - LOG.critical( - f"Unable to create ESMF Mesh: {config_options.geogrid} " - f"due to {str(e)}" - ) - raise - - # Get processor bounds - self.get_processor_bounds(config_options) - - # Extract local coordinates from ESMF mesh - self.latitude_grid = self.esmf_grid.coords[1][1] - self.longitude_grid = self.esmf_grid.coords[1][0] - - # Phase 3: Broadcast global arrays and compute local indices - elementcoords_global = mpi_config.comm.bcast(elementcoords_global, root=0) - element_ids_global = mpi_config.comm.bcast(element_ids_global, root=0) - - # Each rank computes its own local indices - pet_elementcoords = np.column_stack( - [self.longitude_grid, self.latitude_grid] - ) - tree = spatial.KDTree(elementcoords_global) - _, pet_element_inds = tree.query(pet_elementcoords) - - self.element_ids = element_ids_global[pet_element_inds] - self.element_ids_global = element_ids_global - - # Broadcast and extract height/slope data - if config_options.hgt_var is not None: - heights_global = mpi_config.comm.bcast(heights_global, root=0) - self.height = heights_global[pet_element_inds] - - if config_options.slope_var is not None: - slopes_global = mpi_config.comm.bcast(slopes_global, root=0) - slp_azi_global = mpi_config.comm.bcast(slp_azi_global, root=0) - self.slope = slopes_global[pet_element_inds] - self.slp_azi = slp_azi_global[pet_element_inds] - - self.mesh_inds = pet_element_inds + return elem_heights + + @property + @lru_cache + def dx_elem(self) -> np.ndarray: + """Calculate the dx distance in meters for the longitude variable for the unstructured domain elements.""" + dx = ( + np.diff(self.elem_lons) + * 40075160 + * np.cos(self.elem_lats[0:-1] * np.pi / 180) + / 360 + ) + return np.append(dx, dx[-1]) + + @property + @lru_cache + def dy_elem(self) -> np.ndarray: + """Calculate the dy distance in meters for the latitude variable for the unstructured domain elements.""" + dy = np.diff(self.elem_lats) * 40008000 / 360 + return np.append(dy, dy[-1]) + + @property + @lru_cache + def dz_elem(self) -> np.ndarray: + """Calculate the dz distance in meters for the height variable for the unstructured domain elements.""" + dz = np.diff(self.elem_heights) + return np.append(dz, dz[-1]) + + @property + @lru_cache + def dx_node(self) -> np.ndarray: + """Calculate the dx distance in meters for the longitude variable for the unstructured domain nodes.""" + dx = ( + np.diff(self.node_lons) + * 40075160 + * np.cos(self.node_lats[0:-1] * np.pi / 180) + / 360 + ) + return np.append(dx, dx[-1]) + + @property + @lru_cache + def dy_node(self) -> np.ndarray: + """Calculate the dy distance in meters for the latitude variable for the unstructured domain nodes.""" + dy = np.diff(self.node_lats) * 40008000 / 360 + return np.append(dy, dy[-1]) + + @property + @lru_cache + def dz_node(self) -> np.ndarray: + """Calculate the dz distance in meters for the height variable for the unstructured domain nodes.""" + dz = np.diff(self.node_heights) + return np.append(dz, dz[-1]) + + @property + @lru_cache + def mesh_inds(self) -> np.ndarray: + """Get the local mesh node indices for the unstructured domain.""" + return self.pet_node_inds + + @property + @lru_cache + def mesh_inds_elem(self) -> np.ndarray: + """Get the local mesh element indices for the unstructured domain.""" + return self.pet_element_inds + + @property + @lru_cache + def nx_local(self) -> int: + """Get the local x dimension size for this processor.""" + return len(self.esmf_grid.coords[0][1]) + + @property + @lru_cache + def ny_local(self) -> int: + """Get the local y dimension size for this processor.""" + return len(self.esmf_grid.coords[0][1]) + + @property + @lru_cache + def nx_local_elem(self) -> int: + """Get the local x dimension size for this processor.""" + return len(self.esmf_grid.coords[1][1]) + + @property + @lru_cache + def ny_local_elem(self) -> int: + """Get the local y dimension size for this processor.""" + return len(self.esmf_grid.coords[1][1]) + + +GEOGRID = { + # "gridded": GriddedGeoMeta, + # "unstructured": UnstructuredGeoMeta, + "hydrofabric": HydrofabricGeoMeta, +} diff --git a/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/core/regrid.py b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/core/regrid.py index 99126c62..df0b9e89 100755 --- a/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/core/regrid.py +++ b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/core/regrid.py @@ -43,7 +43,7 @@ ConfigOptions, ) from NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.core.geoMod import ( - GeoMetaWrfHydro, + GeoMeta, ) from NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.core.parallel import MpiConfig from nextgen_forcings_ewts import MODULE_NAME @@ -11979,7 +11979,7 @@ def check_supp_pcp_regrid_status( def get_weight_file_names( mpi_config: MpiConfig, config_options: ConfigOptions, - input_forcings: GeoMetaWrfHydro, + input_forcings: GeoMeta, ) -> tuple[str | None, str | None]: """Get weight file names for regridding.""" if not config_options.weightsDir: @@ -12005,7 +12005,7 @@ def get_weight_file_names( def load_weight_file( mpi_config: MpiConfig, config_options: ConfigOptions, - input_forcings: GeoMetaWrfHydro, + input_forcings: GeoMeta, weight_file: str, element_mode: bool, ) -> None: @@ -12051,7 +12051,7 @@ def load_weight_file( def make_regrid( mpi_config: MpiConfig, config_options: ConfigOptions, - input_forcings: GeoMetaWrfHydro, + input_forcings: GeoMeta, weight_file: str | None, fill: bool, element_mode: bool, @@ -12120,7 +12120,7 @@ def make_regrid( def execute_regrid( mpi_config: MpiConfig, config_options: ConfigOptions, - input_forcings: GeoMetaWrfHydro, + input_forcings: GeoMeta, weight_file: str, element_mode: bool, ) -> None: diff --git a/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/model.py b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/model.py index a71a2f8c..d6bd80dc 100755 --- a/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/model.py +++ b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/model.py @@ -18,7 +18,7 @@ ConfigOptions, ) from NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.core.geoMod import ( - GeoMetaWrfHydro, + GeoMeta, ) from NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.core.ioMod import OutputObj from NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.core.parallel import MpiConfig @@ -80,7 +80,7 @@ def run( model: dict, future_time: float, config_options: ConfigOptions, - wrf_hydro_geo_meta: GeoMetaWrfHydro, + wrf_hydro_geo_meta: GeoMeta, input_forcing_mod: dict, supp_pcp_mod: dict, mpi_config: MpiConfig, @@ -347,7 +347,7 @@ def loop_through_forcing_products( self, future_time: float, config_options: ConfigOptions, - wrf_hydro_geo_meta: GeoMetaWrfHydro, + wrf_hydro_geo_meta: GeoMeta, input_forcing_mod: dict, supp_pcp_mod: dict, mpi_config: MpiConfig, @@ -664,7 +664,7 @@ def loop_through_forcing_products( def process_suplemental_precip( self, config_options: ConfigOptions, - wrf_hydro_geo_meta: GeoMetaWrfHydro, + wrf_hydro_geo_meta: GeoMeta, supp_pcp_mod: dict, mpi_config: MpiConfig, output_obj: OutputObj, @@ -736,7 +736,7 @@ def process_suplemental_precip( def write_output( self, config_options: ConfigOptions, - wrf_hydro_geo_meta: GeoMetaWrfHydro, + wrf_hydro_geo_meta: GeoMeta, mpi_config: MpiConfig, output_obj: OutputObj, ): @@ -764,7 +764,7 @@ def update_dict( self, model: dict, config_options: ConfigOptions, - wrf_hydro_geo_meta: GeoMetaWrfHydro, + wrf_hydro_geo_meta: GeoMeta, output_obj: OutputObj, ): """Flatten the Forcings Engine output object and update the BMI dictionary.""" diff --git a/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/retry_utils.py b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/retry_utils.py index 5b2598a2..dc14dfb5 100644 --- a/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/retry_utils.py +++ b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/retry_utils.py @@ -3,8 +3,8 @@ import traceback import types -from .core.config import ConfigOptions -from .core.parallel import MpiConfig +from NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.core.parallel import MpiConfig +from NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.core.config import ConfigOptions def retry_w_mpi_context(