# Source code for the IO utilities module.

"""IO Utilities module."""
from __future__ import annotations

import json
import logging.config
import os
import string
from collections.abc import Sequence
from datetime import date
from pathlib import Path

import dask
import netCDF4 as nc  # noqa
import xarray as xr
import zarr

from miranda.scripting import LOGGING_CONFIG


__all__ = [
    "creation_date",
    "delayed_write",
    "get_chunks_on_disk",
    "get_global_attrs",
    "get_time_attrs",
    "name_output_file",
    "sort_variables",
]

# Directory containing the data files shipped next to this module.
_data_folder = Path(__file__).parent / "data"
# Filename templates, looked up by `name_output_file` as
# name_configurations[<type facet>][<project facet>].
# read_text() opens and closes the file (the previous bare open() leaked the handle);
# JSON is UTF-8 by specification, so decode explicitly rather than with the locale default.
name_configurations = json.loads(
    (_data_folder / "ouranos_name_config.json").read_text(encoding="utf-8")
)

def name_output_file(
    ds_or_dict: xr.Dataset | dict[str, str], output_format: str
) -> str:
    """Name an output file based on facets within a Dataset or a dictionary.

    Parameters
    ----------
    ds_or_dict : xr.Dataset or dict
        A miranda-converted Dataset or a dictionary containing the appropriate facets.
    output_format : {"netcdf", "zarr"}
        Output filetype to be used for generating filename suffix (case-insensitive).

    Returns
    -------
    str

    Raises
    ------
    NotImplementedError
        If `output_format` is unsupported, if the Dataset does not hold exactly one
        data variable (ignoring "rotated_pole"), or if `ds_or_dict` is neither a
        Dataset nor a dict.
    KeyError
        If the Dataset's `frequency` attribute is missing or not one of
        "1hr", "day", "month", "year". Also propagated from `str.format` when the
        template uses a field that is not a known facet.
    ValueError
        If a facet used by the filename template is None.

    Notes
    -----
    If using a dictionary, the following keys must be set:
    * "variable", "frequency", "institution", "time_start", "time_end".

    When no "time" facet is given, it is built from "time_start" and "time_end":
    the single value when both are equal, otherwise "{time_start}-{time_end}".
    """
    format_key = output_format.lower()
    if format_key not in {"netcdf", "zarr"}:
        raise NotImplementedError(f"Format: {output_format}.")

    facets = dict()
    # Use the lower-cased key: the check above accepts e.g. "NetCDF", so the lookup must too.
    facets["suffix"] = dict(netcdf="nc", zarr="zarr")[format_key]

    if isinstance(ds_or_dict, xr.Dataset):
        data_vars = list(ds_or_dict.data_vars)
        if len(data_vars) == 1:
            facets["variable"] = data_vars[0]
        elif len(data_vars) == 2 and "rotated_pole" in data_vars:
            # Ignore "rotated_pole" when picking the variable the file is named after.
            facets["variable"] = [v for v in data_vars if v != "rotated_pole"][0]
        else:
            raise NotImplementedError(
                f"Too many `data_vars` in Dataset: {', '.join(map(str, data_vars))}."
            )
        for f in [
            "bias_adjust_project",
            "domain",
            "frequency",
            "institution",
            "source",
            "experiment",
            "member",
            "processing_level",
            "project",
            "type",
            "mip_era",
            "activity",
        ]:
            facets[f] = ds_or_dict.attrs.get(f)

        # strftime pattern for the time bounds, by frequency.
        date_formats = {"1hr": "%Y%m%d", "day": "%Y%m%d", "month": "%Y%m", "year": "%Y"}
        if facets["frequency"] not in date_formats:
            raise KeyError("`frequency` not found.")
        date_format = date_formats[facets["frequency"]]

        # First and last time steps give the period covered by the file.
        bounds = ds_or_dict.time.isel(time=[0, -1])
        facets["time_start"], facets["time_end"] = (
            bounds.dt.strftime(date_format).values
        )
        facets["year_start"], facets["year_end"] = bounds.dt.year.values

    elif isinstance(ds_or_dict, dict):
        for f in [
            "bias_adjust_project",
            "domain",
            "frequency",
            "institution",
            "processing_level",
            "project",
            "type",
            "time",
            "time_end",
            "time_start",
            "variable",
        ]:
            facets[f] = ds_or_dict.get(f)
    else:
        raise NotImplementedError("Must be a Dataset or dictionary.")

    # Derive "time" when it is absent or None (the dict branch always stores the key).
    if facets.get("time") is None:
        start, end = facets.get("time_start"), facets.get("time_end")
        if start is not None and end is not None:
            facets["time"] = start if start == end else f"{start}-{end}"

    # Default template, overridden by a type/project-specific one when configured.
    str_name = "{variable}_{frequency}_{institution}_{project}_{time}.{suffix}"
    project_templates = name_configurations.get(facets["type"], {})
    str_name = project_templates.get(facets["project"], str_name)

    # Only facets actually used as template fields are required; the base name is
    # taken so that fields such as "{time_start[0]}" still count as "time_start".
    template_fields = {
        field.split(".", 1)[0].split("[", 1)[0]
        for _, field, _, _ in string.Formatter().parse(str_name)
        if field
    }
    missing = [k for k, v in facets.items() if v is None and k in template_fields]
    if missing:
        raise ValueError(f"The following facets were not found: {', '.join(missing)}.")

    # fill in string with facets
    return str_name.format(**facets)
def delayed_write(
    ds: xr.Dataset,
    outfile: str | os.PathLike,
    output_format: str,
    overwrite: bool,
    target_chunks: dict | None = None,
) -> dask.delayed:
    """Stage a Dataset writing job using `dask.delayed` objects.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset to write.
    outfile : str or os.PathLike
        Destination file (NetCDF) or store (Zarr).
    output_format : {"netcdf", "zarr"}
        Selects ``ds.to_netcdf`` or ``ds.to_zarr`` (case-insensitive).
    overwrite : bool
        NetCDF: when False and `outfile` already exists, write in append mode ("a").
        Zarr: when True, write with mode "w".
    target_chunks : dict, optional
        Mapping of dimension name to on-disk chunk size. Dimensions not listed are
        stored as a single chunk spanning their full length.

    Returns
    -------
    dask.delayed.delayed
        The staged (not yet computed) write job.

    Raises
    ------
    NotImplementedError
        If `output_format` is not "netcdf" or "zarr".
    KeyError
        If chunk encodings cannot be built (logged, then re-raised).
    """
    fmt = output_format.lower()
    if fmt not in {"netcdf", "zarr"}:
        raise NotImplementedError(f"Format: {output_format}.")

    # compute=False: xarray returns a delayed job instead of writing immediately.
    kwargs = dict(encoding=dict(), compute=False)
    if fmt == "netcdf":
        if Path(outfile).exists() and not overwrite:
            kwargs["mode"] = "a"
    else:
        # Rechunk the data once (not per variable) to the requested chunking.
        ds = ds.chunk(target_chunks or {})
        if overwrite:
            kwargs["mode"] = "w"

    chunk_sizes = target_chunks or {}
    try:
        for name, da in ds.data_vars.items():
            # Every dimension gets a chunk size; unlisted ones use their full length.
            chunks = [chunk_sizes.get(dim, da.sizes[dim]) for dim in da.dims]
            if fmt == "netcdf":
                kwargs["encoding"][name] = {"chunksizes": chunks, "zlib": True}
            else:
                kwargs["encoding"][name] = {
                    "chunks": chunks,
                    "compressor": zarr.Blosc(),
                }
        # Only encode "time" when the Dataset actually has such a variable.
        if kwargs["encoding"] and "time" in ds.variables:
            kwargs["encoding"]["time"] = {"dtype": "int32"}
    except KeyError:
        logging.error("Unable to encode chunks. Verify dataset.")
        raise

    if fmt == "netcdf":
        return ds.to_netcdf(outfile, **kwargs)
    return ds.to_zarr(outfile, **kwargs)
def get_time_attrs(
    file_or_dataset: str | os.PathLike | xr.Dataset,
) -> tuple[str, int]:
    """Determine attributes related to time dimensions.

    Parameters
    ----------
    file_or_dataset : str, os.PathLike or xr.Dataset
        Path to a file readable by `xarray.open_dataset` (``~`` is expanded),
        or an already-open Dataset.

    Returns
    -------
    tuple of (str, int)
        The calendar of the `time` coordinate and the number of time steps.
    """
    if isinstance(file_or_dataset, (str, os.PathLike)):
        # Read both values eagerly, then close the file we opened ourselves.
        with xr.open_dataset(Path(file_or_dataset).expanduser()) as ds:
            return ds.time.dt.calendar, len(ds.time)

    ds = file_or_dataset
    return ds.time.dt.calendar, len(ds.time)
def get_global_attrs(
    file_or_dataset: str | os.PathLike | xr.Dataset,
) -> dict[str, str | int]:
    """Collect global attributes from NetCDF, Zarr, or Dataset object.

    Parameters
    ----------
    file_or_dataset : str, os.PathLike or xr.Dataset
        A NetCDF file (".nc"/".nc4"), a Zarr store directory (".zarr"), or a Dataset.
        ``~`` in paths is expanded; suffixes are matched case-insensitively.

    Returns
    -------
    dict
        A new dict of global attributes (a copy, so editing it leaves the source alone).

    Raises
    ------
    FileNotFoundError
        If the path does not exist.
    NotImplementedError
        If the input type or the file type is not supported.
    """
    if isinstance(file_or_dataset, (str, os.PathLike)):
        path = Path(file_or_dataset).expanduser()
    elif isinstance(file_or_dataset, xr.Dataset):
        return dict(file_or_dataset.attrs)
    else:
        raise NotImplementedError(f"Type: `{type(file_or_dataset)}`.")

    suffix = path.suffix.lower()
    if path.is_file() and suffix in {".nc", ".nc4"}:
        with nc.Dataset(path, mode="r") as ds:
            # getncattr reads the attribute itself; getattr could return a
            # netCDF4 Python property that happens to share the attribute's name.
            return {k: ds.getncattr(k) for k in ds.ncattrs()}
    if path.is_dir() and suffix == ".zarr":
        return zarr.open(str(path), mode="r").attrs.asdict()
    # Previously these cases fell through with `data` unbound (UnboundLocalError).
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path}.")
    raise NotImplementedError(f"File type: {path.suffix}.")
def sort_variables(
    files: list[Path], variables: Sequence[str]
) -> dict[str, list[Path]]:
    """Sort all variables within supplied files for treatment.

    Parameters
    ----------
    files : list of Path
        Files to group. Only the file name is inspected.
    variables : sequence of str
        Variable names to group by. A file belongs to a variable's group when its
        name starts with the variable name followed by "_" or "." (so "tas" picks up
        "tas_day_....nc" and "tas.nc" but not "tasmax_day_....nc").

    Returns
    -------
    dict[str, list[Path]]
        Sorted file lists keyed by variable. Variables matching no file are left
        out (a warning is logged). If `variables` is empty, all files are returned
        as given under the key "all_variables".
    """
    if not variables:
        return {"all_variables": files}

    logging.info("Sorting variables into groups. This could take some time.")
    variable_sorted = dict()
    for variable in variables:
        # Require a separator after the name to avoid prefix clashes (tas vs tasmax).
        prefixes = (f"{variable}_", f"{variable}.")
        var_group = [f for f in files if Path(f).name.startswith(prefixes)]
        if not var_group:
            logging.warning("No files found for %s. Continuing...", variable)
            continue
        variable_sorted[variable] = sorted(var_group)
    return variable_sorted
def get_chunks_on_disk(file: os.PathLike | str) -> dict:
    """Determine the chunks on disk for a given NetCDF or Zarr file.

    Parameters
    ----------
    file : str or os.PathLike
        File to be examined. Supports NetCDF (".nc"/".nc4") and Zarr (".zarr" directory).

    Returns
    -------
    dict
        Mapping of variable name to a ``{dimension: chunk size}`` mapping.
        NetCDF variables stored without chunking report their full shape.
        For Zarr, dimension names come from the xarray ``_ARRAY_DIMENSIONS``
        attribute when present; otherwise integer axis positions are used as keys.

    Raises
    ------
    NotImplementedError
        If the file type is not supported (or a ".zarr" path is not a directory).
    """
    chunks = dict()
    file = Path(file)
    suffix = file.suffix.lower()
    if suffix in (".nc", ".nc4"):
        with nc.Dataset(file) as ds:
            for name, var in ds.variables.items():
                layout = var.chunking()
                # chunking() returns "contiguous" for unchunked variables (and, to be
                # confirmed, None for NETCDF3 files); report the full shape for those
                # instead of indexing into the string.
                sizes = var.shape if layout in (None, "contiguous") else layout
                chunks[name] = dict(zip(var.dimensions, sizes))
    elif suffix == ".zarr" and file.is_dir():
        group = zarr.open_group(str(file), mode="r")
        for name, array in group.arrays():
            dims = array.attrs.get("_ARRAY_DIMENSIONS") or range(array.ndim)
            chunks[name] = dict(zip(dims, array.chunks))
    else:
        raise NotImplementedError(f"File type: {file.suffix}.")
    return chunks
def creation_date(path_to_file: str | os.PathLike) -> float | date:
    """Return the date that a file was created, falling back to when it was last modified if unable to determine.

    On Windows, the raw ``st_ctime`` timestamp (a float) is returned. Elsewhere,
    ``st_birthtime`` is used when the platform provides it; otherwise (e.g. Linux)
    the last-modification time ``st_mtime`` is used. Both are returned as a `date`.

    Parameters
    ----------
    path_to_file : str or os.PathLike

    Returns
    -------
    float or date
    """
    if os.name == "nt":
        return Path(path_to_file).stat().st_ctime

    stat = Path(path_to_file).stat()
    try:
        # st_ctime always exists but is the metadata-change time on Unix, so it
        # can never trigger the fallback; st_birthtime is the real creation time.
        return date.fromtimestamp(stat.st_birthtime)
    except AttributeError:
        # We're probably on Linux. No easy way to get creation dates here,
        # so we'll settle for when its content was last modified.
        return date.fromtimestamp(stat.st_mtime)