Source code for miranda.treatments._dimensions

from __future__ import annotations
import logging
import warnings
from typing import Any

import numpy as np
import xarray as xr
from xclim.core.calendar import parse_offset

from miranda.treatments.utils import _get_section_entry_key, _iter_entry_key  # noqa
from miranda.units import check_time_frequency


# Public names exported by this module.
__all__ = [
    "dimensions_compliance",
    "ensure_correct_time_frequency",
    "find_project_variable_codes",
    "get_daily_snapshot",
    "offset_time_dimension",
]


def find_project_variable_codes(code: str, configuration: dict[str, Any]) -> str:
    """
    Resolve a variable name or variable code to the configured variable code.

    Parameters
    ----------
    code : str
        Either a variable name (the `_variable_name` of an entry) or a
        variable code (a key of the `variables` section).
    configuration : dict
        Configuration dictionary holding a `variables` section.

    Returns
    -------
    str
        The matching key of the `variables` section.

    Raises
    ------
    ValueError
        If the configuration has no `variables` section.
    NotImplementedError
        If `code` matches neither a variable name nor a variable code.
    """
    if "variables" not in configuration:
        raise ValueError("No `variables` section found in configuration. Check JSON.")

    # Lookup table: declared variable name -> variable code.
    name_to_code = {}
    for entry_code, entry in configuration["variables"].items():
        declared_name = entry.get("_variable_name")
        if not declared_name:
            warnings.warn(
                f"Variable `{entry_code}` does not have accompanying `variable_name`. "
                f"Verify JSON. Continuing with `{entry_code}` as `variable_name`.",
                stacklevel=2,
            )
            # Entries without a declared name are addressed by their own code.
            declared_name = entry_code
        name_to_code[declared_name] = entry_code

    # A value that already is a variable code is returned unchanged.
    resolved = code if code in name_to_code.values() else name_to_code.get(code)
    if not resolved:
        raise NotImplementedError(f"Variable `{code}` not supported.")
    return resolved
def dimensions_compliance(ds: xr.Dataset, project: str, metadata: dict) -> xr.Dataset:
    """
    Rename dimensions to their CF equivalents and reorder them if needed.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with dimensions to be updated.
    project : str
        Dataset project name.
    metadata : dict
        Metadata definition dictionary for project and variable(s).
        Its `dimensions` section is read.

    Returns
    -------
    xarray.Dataset
        The renamed, transposed and sorted dataset, with updated coordinate
        attributes and a new entry prepended to the `history` attribute.
    """
    # Map each dataset dimension to its CF name when the metadata declares one.
    rename_dims = {}
    for dim in ds.dims:
        if dim in metadata["dimensions"]:
            cf_name = _get_section_entry_key(metadata, "dimensions", dim, "_cf_dimension_name", project)
            if cf_name:
                rename_dims[dim] = cf_name

    # Rename dimensions
    _rename_dims = [str(d) for d in rename_dims]
    msg = f"Renaming dimensions: {', '.join(_rename_dims)}."
    logging.info(msg)
    ds = ds.rename(rename_dims)

    for new in ["lon", "lat"]:
        if new == "lon" and "lon" in ds.coords:
            # Values above 180 are shifted down by 360.
            if np.any(ds.lon > 180):
                lon1 = ds.lon.where(ds.lon <= 180.0, ds.lon - 360.0)
                ds[new] = lon1

        coord_precision = _get_section_entry_key(metadata, "dimensions", new, "_precision", project)
        if coord_precision is not None:
            ds[new] = ds[new].round(coord_precision)

    # Ensure that lon and lat are written in proper order for plotting purposes
    logging.info("Reordering dimensions.")
    transpose_order = []
    if "lat" in ds.dims and "lon" in ds.dims:
        transpose_order = ["lat", "lon"]
    elif "rlat" in ds.dims and "rlon" in ds.dims:
        transpose_order = ["rlat", "rlon"]
    if "time" in ds.dims and transpose_order:
        transpose_order.insert(0, "time")
        # Append the remaining dimensions in dataset order. A set difference
        # would yield them in an arbitrary order that can change between runs.
        remaining = [dim for dim in ds.dims if dim not in transpose_order]
        transpose_order.extend(remaining)

    # NOTE(review): when no horizontal dimension pair is found, `transpose_order`
    # stays empty and `transpose()` is called without arguments - confirm intended.
    ds = ds.transpose(*transpose_order)
    ds = ds.sortby(transpose_order)

    # Add dimension original name and update attrs
    logging.info("Updating dimension attributes.")
    dim_descriptions = metadata["dimensions"]
    for dim, description in dim_descriptions.items():
        cf_name = description.get("_cf_dimension_name")
        if cf_name is not None and cf_name in ds.dims:
            ds[cf_name].attrs.update(dict(original_variable=dim))
        else:
            # variable name already follows CF standards
            cf_name = dim
        # Keys starting with an underscore are not written as attributes.
        for field, value in description.items():
            if not field.startswith("_"):
                ds[cf_name].attrs.update({field: value})

    prev_history = ds.attrs.get("history", "")
    history = f"Transposed and renamed dimensions. {prev_history}"
    ds.attrs.update(dict(history=history))

    return ds
def ensure_correct_time_frequency(d: xr.Dataset, p: str, m: dict) -> xr.Dataset:
    """
    Ensure that time frequency is consistent with expected frequency for project.

    Parameters
    ----------
    d : xarray.Dataset
        Dataset whose `time` coordinate is checked.
    p : str
        Dataset project name.
    m : dict
        Metadata definition dictionary. Its `dimensions` section is read.

    Returns
    -------
    xarray.Dataset
        `d` unchanged when no check applies; otherwise a dataset whose `time`
        coordinate is rebuilt from a resample at the inferred frequency and whose
        `history` attribute has a new entry prepended.

    Raises
    ------
    ValueError
        If `_strict_time` is truthy and no frequency can be inferred, or if the
        inferred frequency is not among the expected frequencies.
    """
    key = "_ensure_correct_time"
    strict_time = "_strict_time"

    if "time" not in m["dimensions"]:
        msg = f"No time corrections listed for project `{p}`. Continuing..."
        warnings.warn(msg, stacklevel=2)
        return d

    if "time" not in d.variables:
        msg = f"No time dimension among data variables: {', '.join([str(v) for v in d.variables.keys()])}. Continuing..."
        logging.info(msg)
        return d

    time_metadata = m["dimensions"]["time"]
    if key not in time_metadata:
        return d

    freq_found = xr.infer_freq(d.time)
    if strict_time in time_metadata and not freq_found:
        msg = "Time frequency could not be found. There may be missing timesteps."
        if time_metadata.get(strict_time):
            raise ValueError(msg)
        warnings.warn(f"{msg} Continuing...", stacklevel=2)
        return d

    correct_time_entry = time_metadata[key]
    if isinstance(correct_time_entry, str):
        correct_times = [parse_offset(correct_time_entry)[1]]
    elif isinstance(correct_time_entry, dict):
        correct_times = correct_time_entry.get(p)
        if correct_times is None:
            # Nothing to compare against for this project: warn and leave the data
            # untouched rather than failing on the membership test further down.
            warnings.warn(f"No expected times set for specified project `{p}`.", stacklevel=2)
            return d
        if isinstance(correct_times, list):
            correct_times = [parse_offset(t)[1] for t in correct_times]
        # NOTE(review): a non-list value is used as-is, so a plain string is
        # matched by substring below - confirm this is intended.
    elif isinstance(correct_time_entry, list):
        correct_times = correct_time_entry
    else:
        warnings.warn("No expected times set for family of projects.", stacklevel=2)
        return d

    if freq_found not in correct_times:
        error_msg = (
            f"Time frequency {freq_found} not among allowed frequencies: "
            f"{', '.join(correct_times) if isinstance(correct_times, list) else correct_times}"
        )
        if isinstance(correct_time_entry, dict):
            error_msg = f"{error_msg} for project `{p}`."
        else:
            error_msg = f"{error_msg}."
        raise ValueError(error_msg)

    msg = f"Resampling dataset with time frequency: {freq_found}."
    logging.info(msg)

    with xr.set_options(keep_attrs=True):
        d_out = d.assign_coords(time=d.time.resample(time=freq_found).mean(dim="time").time)
        d_out.time.attrs.update(d.time.attrs)

    prev_history = d.attrs.get("history", "")
    history = f"Resampled time with `freq={freq_found}`. {prev_history}"
    d_out.attrs.update(dict(history=history))
    return d_out
def get_daily_snapshot(d: xr.Dataset, p: str, m: dict) -> xr.Dataset:
    """
    Get a single hour snapshot per day from a sub-daily dataset.

    For each data variable listed in the `variables` section of `m`, the
    `_use_snapshot` entry is read: an integer selects the timesteps at that hour,
    `True` drops the timesteps where all values are missing, and `None` or `False`
    means no snapshot is applied for that variable.

    Parameters
    ----------
    d : xarray.Dataset
        The dataset.
    p : str
        Dataset project name.
    m : dict
        Metadata definition dictionary. Its `variables` section is read.

    Returns
    -------
    xarray.Dataset
        The dataset with the snapshot applied, or `d` unchanged when no variable
        requests a snapshot.

    Raises
    ------
    ValueError
        If a `_use_snapshot` value is neither an integer nor a boolean, or if the
        time frequency is not daily after the snapshot is applied.
    """
    key = "_use_snapshot"
    processed = False
    for vv in d.data_vars:
        if vv not in m["variables"]:
            continue
        snapvalue = _get_section_entry_key(m, "variables", vv, key, p)
        # `False` is an int in Python and would otherwise select hour 0, so an
        # explicit `False` is treated like a missing entry: no snapshot.
        if snapvalue is None or snapvalue is False:
            continue
        processed = True  # at least one variable is processed
        if snapvalue is True:
            # if True, we assume data is only available at a single hour each day :
            # ex. CaSR snow depth is padded with NaNs
            d = d.dropna(dim="time", how="all")
        elif isinstance(snapvalue, int):
            mask_hour = d.time.dt.hour == int(snapvalue)
            d = d.sel(time=mask_hour)
        else:
            raise ValueError(f"Invalid _use_snapshot value: {snapvalue}.")

    if not processed:
        msg = f"No snapshot processing needed for any variable in `{p}`."
        logging.info(msg)
        return d

    # After applying the snapshot, the time frequency must be daily.
    freq_found = xr.infer_freq(d.time)
    if freq_found != "D":
        msg = f"After applying snapshot, the time frequency is not daily. Found frequency: {freq_found}."
        raise ValueError(msg)

    # anchor time on day start
    d["time"] = d.resample(time="D").mean().time.values
    d.attrs["frequency"] = "day"
    return d
def offset_time_dimension(d: xr.Dataset, p: str, m: dict) -> xr.Dataset:
    """
    Offset time dimension using listed frequency.

    Parameters
    ----------
    d : xarray.Dataset
        Dataset whose variables may need their `time` coordinate shifted.
    p : str
        Dataset project name.
    m : dict
        Metadata definition dictionary. Its `dimensions` section is read.

    Returns
    -------
    xarray.Dataset
        A new dataset built from the coordinates and attributes of `d`. Variables
        flagged with `_offset_time` have their time shifted back by the value
        returned by `check_time_frequency`; all other data variables are copied
        as-is. One `history` entry is prepended per offset variable.

    Raises
    ------
    TypeError
        Re-raised from `check_time_frequency` when the time frequency cannot be parsed.
    """
    key = "_offset_time"
    d_out = xr.Dataset(coords=d.coords, attrs=d.attrs)
    converted = []
    offset, offset_meaning = None, None

    # Only pass an expected period along when the metadata gives a single string.
    time_freq = dict()
    expected_period = _get_section_entry_key(m, "dimensions", "time", "_ensure_correct_time", p)
    if isinstance(expected_period, str):
        time_freq["expected_period"] = expected_period

    for vv, offs in _iter_entry_key(d, m, "dimensions", key, p):
        if offs is False:
            msg = f"No time offsetting needed for `{vv}` in `{p}` (Explicitly set to False)."
            logging.info(msg)
            continue
        if not offs:
            # No offset requested (e.g. missing entry). Skipping avoids writing a
            # history entry from an offset that was never computed.
            continue

        # Offset time by value of one time-step
        if offset is None and offset_meaning is None:
            # Computed once, on the first variable that needs it.
            try:
                offset, offset_meaning = check_time_frequency(d, **time_freq)
            except TypeError:
                msg = "Unable to parse the time frequency. Verify data integrity before retrying."
                logging.error(msg)
                raise

        msg = f"Offsetting data for `{vv}` by `{offset[0]} {offset_meaning}(s)`."
        logging.info(msg)
        with xr.set_options(keep_attrs=True):
            out = d[vv]
            out["time"] = out.time - np.timedelta64(offset[0], offset[1])
            d_out[vv] = out
        converted.append(vv)

        # Build on the history already stored in the output so that every offset
        # variable is recorded, not only the last one.
        prev_history = d_out.attrs.get("history", "")
        history = f"Offset variable `{vv}` values by `{offset[0]} {offset_meaning}(s)`. {prev_history}"
        d_out.attrs.update(dict(history=history))

    # Copy unconverted variables
    for vv in d.data_vars:
        if vv not in converted:
            d_out[vv] = d[vv]
    return d_out