Source code for miranda.io.utils

"""IO Utilities module."""

from __future__ import annotations
import importlib.util as ilu
import json
import logging
import os
from collections.abc import Sequence
from pathlib import Path
from typing import Any, cast

import dask.delayed
import h5netcdf
import xarray as xr


# True when the optional `netCDF4` package can be located (checked via its import
# spec, without importing it); functions below fall back to `h5netcdf` otherwise.
HAS_NETCDF4 = bool(ilu.find_spec("netCDF4"))

# Module-level logger used by the helpers in this file.
logger = logging.getLogger("miranda.io.utils")


# Public names exported by this module.
__all__ = [
    "delayed_write",
    "get_chunks_on_disk",
    "get_global_attrs",
    "get_time_attrs",
    "name_output_file",
    "sort_variables",
]

_data_folder = Path(__file__).parent / "data"

# File-naming templates, loaded once at import time. `name_output_file` indexes
# this mapping as ``name_configurations[type][project]`` to get a format string.
# The context manager guarantees the file handle is closed after parsing.
with _data_folder.joinpath("ouranos_name_config.json").open("r", encoding="utf-8") as _name_config_file:
    name_configurations = json.load(_name_config_file)


def name_output_file(
    ds_or_dict: xr.Dataset | dict[str, str],
    output_format: str,
    data_vars: str | None = None,
) -> str:
    """
    Name an output file based on facets within a Dataset or a dictionary.

    Parameters
    ----------
    ds_or_dict : xr.Dataset or dict
        A miranda-converted Dataset or a dictionary containing the appropriate facets.
    output_format : {"netcdf", "zarr"}
        Output filetype to be used for generating filename suffix (case-insensitive).
    data_vars : str, optional
        If using a Dataset, the name of the data variable to be used for naming the file.

    Returns
    -------
    str
        The formatted filename.

    Raises
    ------
    NotImplementedError
        If `output_format` is not supported, if `ds_or_dict` is neither a Dataset nor a dict,
        or if the variable name cannot be determined unambiguously from the Dataset.
    KeyError
        If the Dataset `frequency` attribute is not one of "1hr", "day", "month" or "year".
    ValueError
        If a facet required by the filename template is missing.

    Notes
    -----
    If using a dictionary, the following keys must be set:
    * "variable", "frequency", "institution", "time_start", "time_end".

    When no explicit "time" facet is given, it is built from "time_start" and "time_end":
    a single value if both are equal, otherwise "<time_start>-<time_end>".
    """
    suffixes = {"netcdf": "nc", "zarr": "zarr"}
    # Normalize once so validation and lookup agree (e.g. "NetCDF" is accepted).
    format_key = output_format.lower()
    if format_key not in suffixes:
        raise NotImplementedError(f"Format: {output_format}.")

    facets: dict[str, Any] = {"suffix": suffixes[format_key]}

    if isinstance(ds_or_dict, xr.Dataset):
        var_names = list(ds_or_dict.data_vars)
        if data_vars is not None:
            facets["variable"] = data_vars
        elif len(var_names) == 1:
            facets["variable"] = var_names[0]
        elif len(var_names) == 2 and "rotated_pole" in var_names:
            # A grid-mapping variable alongside the single real data variable.
            facets["variable"] = next(v for v in var_names if v != "rotated_pole")
        else:
            raise NotImplementedError(f"Too many `data_vars` in Dataset: {', '.join(map(str, var_names))}.")

        for f in [
            "bias_adjust_project",
            "domain",
            "frequency",
            "institution",
            "source",
            "experiment",
            "member",
            "processing_level",
            "project",
            "type",
            "mip_era",
            "activity",
        ]:
            facets[f] = ds_or_dict.attrs.get(f)

        frequency = facets["frequency"]
        if frequency in ("1hr", "day"):
            date_format = "%Y%m%d"
        elif frequency == "month":
            date_format = "%Y%m"
        elif frequency == "year":
            date_format = "%Y"
        else:
            raise KeyError("`frequency` not found.")

        # First and last time steps bound the period covered by the file.
        endpoints = ds_or_dict.time.isel(time=[0, -1])
        facets["time_start"], facets["time_end"] = endpoints.dt.strftime(date_format).values
        facets["year_start"], facets["year_end"] = endpoints.dt.year.values
    elif isinstance(ds_or_dict, dict):
        for f in [
            "bias_adjust_project",
            "domain",
            "frequency",
            "institution",
            "processing_level",
            "project",
            "type",
            "time",
            "time_end",
            "time_start",
            "variable",
        ]:
            facets[f] = ds_or_dict.get(f)
    else:
        raise NotImplementedError("Must be a Dataset or dictionary.")

    # Build the "time" facet from its bounds when it was not supplied explicitly.
    # Checking for None (not key presence) makes this work for the dictionary
    # path too, where every key is populated via `.get()`.
    if facets.get("time") is None:
        time_start = facets.get("time_start")
        time_end = facets.get("time_end")
        if time_start is not None and time_end is not None:
            if time_start == time_end:
                facets["time"] = time_start
            else:
                facets["time"] = "-".join([time_start, time_end])

    # Default template; overridden by a project-specific one when configured.
    str_name = "{variable}_{frequency}_{institution}_{project}_{time}.{suffix}"
    if facets["type"] in name_configurations:
        if facets["project"] in name_configurations[facets["type"]]:
            str_name = name_configurations[facets["type"]][facets["project"]]

    # A facet only counts as missing if the chosen template actually uses it.
    missing = [k for k, v in facets.items() if v is None and k in str_name]
    if missing:
        raise ValueError(f"The following facets were not found: {', '.join(missing)}.")

    # fill in string with facets
    return str_name.format(**facets)
def delayed_write(
    ds: xr.Dataset,
    outfile: str | os.PathLike,
    output_format: str,
    overwrite: bool,
    target_chunks: dict | None = None,
    **kwargs: Any,
) -> dask.delayed.Delayed:
    """
    Stage a Dataset writing job using `dask.delayed` objects.

    Parameters
    ----------
    ds : xr.Dataset
        The Dataset to be written.
    outfile : str or os.PathLike
        The output file.
    output_format : {"netcdf", "zarr"}
        The output format. Selects the ``to_netcdf`` / ``to_zarr`` writer of the Dataset.
    overwrite : bool
        Whether to overwrite existing files.
    target_chunks : dict, optional
        The target chunks for the output file, keyed by dimension name.
        Dimensions not listed are written as a single chunk spanning the full dimension.
    **kwargs : Any
        Additional keyword arguments passed to the writer.
        Any "encoding" entry is replaced by the encoding built here.

    Returns
    -------
    dask.delayed.delayed
        The delayed write job.

    Raises
    ------
    KeyError
        If the chunk encoding cannot be built from the dataset.
    """
    # The encoding is always rebuilt from scratch (caller-supplied encoding is discarded).
    kwargs["encoding"] = {}
    appending = "append_dim" in kwargs

    # Writer options that do not depend on individual variables are set once.
    if output_format == "netcdf":
        kwargs["compute"] = False
        if Path(outfile).exists() and not overwrite:
            kwargs["mode"] = "a"
    elif output_format == "zarr":
        # `chunk(None)` is deprecated in xarray; an empty mapping is the equivalent.
        ds = ds.chunk(target_chunks if target_chunks is not None else {})
        kwargs["compute"] = False
        if overwrite:
            kwargs["mode"] = "w"

    # Set correct chunks in encoding options
    try:
        for name, da in ds.data_vars.items():
            # One chunk size per dimension, in dimension order: the requested size
            # when available, otherwise the full dimension length. Every dimension
            # must be present or the resulting encoding is invalid.
            chunks = [
                target_chunks[dim] if target_chunks and dim in target_chunks else len(da[dim])
                for dim in da.dims
            ]

            if output_format == "netcdf":
                kwargs["encoding"][name] = {
                    "chunksizes": chunks,
                    "zlib": True,
                }
            elif output_format == "zarr" and not appending:
                # Chunk encoding is not set for variables when appending to an existing store.
                kwargs["encoding"][name] = {
                    "chunks": chunks,
                }

        if kwargs["encoding"] and not appending:
            kwargs["encoding"]["time"] = {"dtype": "int32"}
    except KeyError:
        logger.error("Unable to encode chunks. Verify dataset.")
        raise

    return getattr(ds, f"to_{output_format}")(outfile, **kwargs)
def get_time_attrs(
    file_or_dataset: str | os.PathLike[str] | xr.Dataset,
) -> tuple[str, int]:
    """
    Determine attributes related to time dimensions.

    Parameters
    ----------
    file_or_dataset : str or os.PathLike or xr.Dataset
        The file or dataset to be examined.

    Returns
    -------
    tuple
        The calendar name and the number of time steps.
    """
    if isinstance(file_or_dataset, (str, os.PathLike)):
        # Open only long enough to read the time axis, then release the file handle.
        with xr.open_dataset(Path(file_or_dataset).expanduser()) as ds:
            return ds.time.dt.calendar, len(ds.time)

    ds = cast(xr.Dataset, file_or_dataset)
    return ds.time.dt.calendar, len(ds.time)
def get_global_attrs(
    file_or_dataset: str | os.PathLike[str] | xr.Dataset,
) -> dict[str, str | int]:
    """
    Collect global attributes from NetCDF, Zarr, or Dataset object.

    Parameters
    ----------
    file_or_dataset : str or os.PathLike or xr.Dataset
        The file or dataset to be examined.

    Returns
    -------
    dict
        The global attributes. Empty if the path is not an existing NetCDF file
        (".nc" / ".nc4") or Zarr directory.

    Raises
    ------
    NotImplementedError
        If the input is neither a path nor a Dataset, or if it is a Zarr store.
    """
    if isinstance(file_or_dataset, xr.Dataset):
        return dict(file_or_dataset.attrs)

    if not isinstance(file_or_dataset, (str, os.PathLike)):
        raise NotImplementedError(f"Type: `{type(file_or_dataset)}`.")

    file = Path(file_or_dataset).expanduser()
    # Case-insensitive suffix match, consistent with `get_chunks_on_disk`.
    suffix = file.suffix.lower()

    data = {}
    if file.is_file() and suffix in {".nc", ".nc4"}:
        if HAS_NETCDF4:
            import netCDF4

            with netCDF4.Dataset(file, mode="r") as ds:
                for k in ds.ncattrs():
                    data[k] = getattr(ds, k)
        else:
            # Fallback reader when the optional netCDF4 package is not installed.
            with h5netcdf.File(file, mode="r") as ds:
                for k in ds.attrs:
                    data[k] = ds.attrs[k]
    elif file.is_dir() and suffix == ".zarr":
        raise NotImplementedError("Zarr v3 not yet supported.")

    return data
def sort_variables(files: list[str | os.PathLike[str] | Path], variables: Sequence[str] | None) -> dict[str, list[Path]]:
    """
    Sort all variables within supplied files for treatment.

    Parameters
    ----------
    files : list of str or os.PathLike or Path
        The files to be sorted.
    variables : sequence of str, optional
        The variables to be sorted. A single variable name given as a string is accepted.
        If not provided, all files are grouped under the key "all_variables"
        in their original order.

    Returns
    -------
    dict[str, list[Path]]
        Files sorted by variables. Variables with no matching file are omitted.
    """
    # Normalize once so that every branch returns `Path` objects, as annotated.
    paths = [Path(file) for file in files]

    if not variables:
        return {"all_variables": paths}

    # A bare string would otherwise be iterated character by character.
    if isinstance(variables, str):
        variables = [variables]

    logger.info("Sorting variables into groups. This could take some time.")
    variable_sorted = {}
    for variable in variables:
        # NOTE(review): plain prefix matching means a variable name that is a prefix of
        # another (e.g. "tas" vs "tasmax") also collects the longer one's files - confirm intent.
        var_group = sorted(path for path in paths if path.name.startswith(variable))
        if not var_group:
            msg = f"No files found for {variable}. Continuing..."
            logger.warning(msg)
            continue
        variable_sorted[variable] = var_group

    return variable_sorted
def get_chunks_on_disk(file: str | os.PathLike[str] | Path) -> dict[str, dict[str, int]]:
    """
    Determine the chunks on disk for a given NetCDF or Zarr file.

    Parameters
    ----------
    file : str or os.PathLike or Path
        File to be examined. Supports NetCDF and Zarr.

    Returns
    -------
    dict
        The chunks on disk: for each variable, a mapping of dimension name to chunk size.
        Variables stored without chunking map to an empty dictionary.

    Raises
    ------
    NotImplementedError
        If the file is a Zarr store or has an unsupported suffix.
    """
    chunks: dict[str, dict[str, int]] = {}
    file = Path(file)
    suffix = file.suffix.lower()

    if suffix in {".nc", ".nc4"}:
        if HAS_NETCDF4:
            import netCDF4

            with netCDF4.Dataset(file) as ds:
                for name, var in ds.variables.items():
                    chunking = var.chunking()
                    # netCDF4 reports the string "contiguous" for unchunked variables;
                    # only a sequence of sizes describes actual chunking.
                    if chunking and not isinstance(chunking, str):
                        chunks[name] = dict(zip(var.dimensions, chunking))
                    else:
                        chunks[name] = {}
        else:
            # Fallback reader when the optional netCDF4 package is not installed.
            with h5netcdf.File(file, mode="r") as ds:
                for name, var in ds.variables.items():
                    chunking = var.chunks  # None when the variable is not chunked
                    if chunking:
                        chunks[name] = dict(zip(var.dimensions, chunking))
                    else:
                        chunks[name] = {}
    elif suffix == ".zarr" and file.is_dir():
        # `Path.suffix` includes the leading dot.
        raise NotImplementedError("Zarr v3 not yet supported.")
    else:
        raise NotImplementedError(f"File type: {file.suffix}.")

    return chunks