diff --git a/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/general_utils.py b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/general_utils.py index 953af214..93692dc2 100644 --- a/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/general_utils.py +++ b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/general_utils.py @@ -3,6 +3,7 @@ import json import logging import typing +import uuid import numpy as np @@ -168,3 +169,14 @@ def assert_equal_with_tol( if errors: raise ExpectVsActualError(errors) + + +def rand_str(length: int) -> str: + """Build and return a random string of length up to 32. + Note that if this is called by different MPI ranks, each rank will receive a different random string. + The string is not broadcasted.""" + if not (0 < length <= 32): + raise ValueError( + f"length requested was {length}, but this function only supports length 1 through 32" + ) + return str(uuid.uuid4()).replace("-", "")[:length] diff --git a/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/historical_forcing.py b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/historical_forcing.py index 7fdbd06b..4c7491f1 100644 --- a/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/historical_forcing.py +++ b/NextGen_Forcings_Engine_BMI/NextGen_Forcings_Engine/historical_forcing.py @@ -1,15 +1,16 @@ """Module for processing AORC and NWM data.""" import datetime +import gc import os -import re import typing from contextlib import contextmanager from datetime import timedelta from functools import cached_property -from time import perf_counter +from time import perf_counter, sleep -import dask +# Use the Error, Warning, and Trapping System Package for logging +import ewts import geopandas as gpd import matplotlib.pyplot as plt import numpy as np @@ -22,17 +23,17 @@ from pyproj import CRS from zarr.storage import ObjectStore +from NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.general_utils import rand_str from 
NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.core.config import ( ConfigOptions, ) from NextGen_Forcings_Engine_BMI.NextGen_Forcings_Engine.core.parallel import MpiConfig -# Use the Error, Warning, and Trapping System Package for logging -import ewts LOG = ewts.get_logger(ewts.FORCING_ID) zarr.config.set({"async.concurrency": 100}) + class BaseProcessor: """Base class for data processors.""" @@ -148,7 +149,12 @@ def gpkg_name(self) -> str: @property def nc_path(self) -> str: """Construct file path for cached netcdf files.""" - return f"/tmp/{self.dataset_name}_{self.gpkg_name}_{self.current_time_str}_{self.end_time_str}.nc" + return f"/tmp/{self.cache_filename}.nc" + + @property + def cache_filename(self): + """Cache filename.""" + return f"{self.dataset_name}_{self.gpkg_name}_{self.current_time_str}_{self.end_time_str}" @property def end_time_datetime(self) -> pd.Timestamp: @@ -239,11 +245,33 @@ def compute_ds(self) -> xr.Dataset: ds = None if self.mpi_config.rank == 0: with self.timing_block("computing dataset", LOG.info): - ds = self.sliced_ds.compute().rio.write_crs(self.src_crs) + ds = self.sliced_ds.rio.write_crs(self.src_crs) self.mpi_config.comm.barrier() ds = self.mpi_config.comm.bcast(ds, root=0) if self.mpi_config.rank == 0: - ds.to_netcdf(self.nc_path) + if not os.path.exists(self.nc_path): + tmp_file = ( + f"{self.nc_path}.{rand_str(12)}{os.path.splitext(self.nc_path)[1]}" + ) + c = 0 + while c < 10: + LOG.info(f"Writing tmp file: {tmp_file}") + try: + ds.to_netcdf(tmp_file, "w") + LOG.info(f"Renaming: {tmp_file} -> {self.nc_path}") + os.replace(tmp_file, self.nc_path) + LOG.info(f"Renamed: {tmp_file} -> {self.nc_path}") + break + except Exception as e: + LOG.warning( + f"There appears to be a lock on the netcdf cache file while writing. Sleeping 1 second and trying again ({c}). 
| Error: {e}" + ) + sleep(1) + c += 1 + else: + raise PermissionError( + f"Could not write the netcdf cache file within the specified number of retries(10): {self.nc_path}" + ) return ds @cached_property @@ -322,6 +350,27 @@ def slice_ds(self, ds: xr.Dataset) -> xr.Dataset: ) return sliced_ds + def load_cache(self) -> xr.Dataset | None: + """Load the cached netcdf file.""" + if os.path.exists(self.nc_path): + with self.timing_block(f"opening local dataset {self.nc_path}"): + c = 0 + while c < 10: + try: + ds = xr.open_dataset(self.nc_path) + dataset = ds.load() + ds.close() + gc.collect() + return dataset + except Exception as e: + LOG.warning(f"Lock on cache file; sleeping 1s({c}). Error: {e}") + sleep(1) + c += 1 + + error_message = f"Exceeded number of attempts (10) to read local cache file for historical forcing data. File: {self.nc_path}. Deleteing the cache file and recreating from s3" + LOG.warning(error_message) + os.remove(self.nc_path) + class AORCConusProcessor(BaseProcessor): """Processor for CONUS AORC data.""" @@ -367,20 +416,20 @@ def sliced_ds(self) -> xr.Dataset: :return: xarray Dataset :raises Exception: If zarr open fails """ + cached_data = self.load_cache() + if cached_data is not None: + return cached_data try: - if os.path.exists(self.nc_path): - with self.timing_block(f"opening local dataset {self.nc_path}"): - return xr.open_dataset(self.nc_path) - else: - with self.timing_block(f"lazy loading {self.dataset_name} data"): - return self.slice_ds( - self.s3_lazy_ds[self.current_time.year] - ).rename({self.x_label: "x", self.y_label: "y"}) + with self.timing_block(f"lazy loading {self.dataset_name} data"): + return ( + self.slice_ds(self.s3_lazy_ds[self.current_time.year]) + .rename({self.x_label: "x", self.y_label: "y"}) + .load() + ) except Exception as e: - LOG.critical( - f"Error opening {self.dataset_name} data from {self.url(self.current_time.year)}: {e}\n" - ) - raise e + error_message = f"Error opening {self.dataset_name} data from 
{self.url(self.current_time.year)}: {e}\n" + LOG.critical(error_message) + raise ValueError(error_message) @cached_property def s3_lazy_ds(self) -> dict[int, xr.Dataset]: @@ -448,8 +497,10 @@ def sliced_ds(self) -> xr.Dataset: s3 = s3fs.S3FileSystem() with s3.open(self.url(date)) as f: ds = xr.open_dataset(f, engine="h5netcdf") + dataset = ds.load() + ds.close() datasets.append( - self.slice_ds(ds, date, date + np.timedelta64(1, "h")) + self.slice_ds(dataset, date, date + np.timedelta64(1, "h")) ) except Exception as e: LOG.critical( @@ -599,20 +650,20 @@ def sliced_ds(self) -> xr.Dataset: :return: xarray Dataset :raises Exception: If zarr open fails """ + cached_data = self.load_cache() + if cached_data is not None: + return cached_data try: - if os.path.exists(self.nc_path): - with self.timing_block(f"opening local dataset {self.nc_path}"): - return xr.open_dataset(self.nc_path) - else: - with self.timing_block(f"lazy loading {self.dataset_name} data"): - return self.slice_ds(self.s3_lazy_ds).rename( - {self.x_label: "x", self.y_label: "y"} - ) + with self.timing_block(f"lazy loading {self.dataset_name} data"): + return ( + self.slice_ds(self.s3_lazy_ds) + .rename({self.x_label: "x", self.y_label: "y"}) + .load() + ) except Exception as e: - LOG.critical( - f"Error opening {self.dataset_name} data from {self.url}: {e}\n" - ) - raise e + error_message = f"Error opening {self.dataset_name} data from {self.url(self.current_time.year)}: {e}\n" + LOG.critical(error_message) + raise ValueError(error_message) @cached_property def s3_lazy_ds(self) -> xr.Dataset: @@ -653,20 +704,20 @@ def sliced_ds(self) -> xr.Dataset: :return: xarray Dataset :raises Exception: If zarr open fails """ + cached_data = self.load_cache() + if cached_data is not None: + return cached_data try: - if os.path.exists(self.nc_path): - with self.timing_block(f"opening local dataset {self.nc_path}"): - return xr.open_dataset(self.nc_path) - else: - with self.timing_block(f"lazy loading 
{self.dataset_name} data"): - return self.slice_ds(self.s3_lazy_ds).rename( - {self.x_label: "x", self.y_label: "y"} - ) + with self.timing_block(f"lazy loading {self.dataset_name} data"): + return ( + self.slice_ds(self.s3_lazy_ds) + .rename({self.x_label: "x", self.y_label: "y"}) + .load() + ) except Exception as e: - LOG.critical( - f"Error opening {self.dataset_name} data from {self.url}: {e}\n" - ) - raise e + error_message = f"Error opening {self.dataset_name} data from {self.url(self.current_time.year)}: {e}\n" + LOG.critical(error_message) + raise ValueError(error_message) @cached_property def src_crs(self) -> CRS: