# Source code for miranda.treatments._preprocessing

from __future__ import annotations
from functools import partial
from pathlib import Path
from typing import Any

import numpy as np
import xarray as xr

from miranda.convert.utils import date_parser


# Public names exported by this module.
__all__ = [
    "correct_time_entries",
    "correct_var_names",
    "preprocessing_corrections",
]


def correct_time_entries(
    ds: xr.Dataset,
    split: str = "_",
    location: int = -1,
    field: str = "time",
) -> xr.Dataset:
    """
    Correct time entries in dataset.

    The coordinate named by `field` is rebuilt as a sequence of consecutive days
    (standard calendar) starting on the date parsed from the name of the file
    the dataset was opened from.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to correct. ``ds.encoding["source"]`` must hold the path of the source file.
    split : str
        Separator used to split the stem of the source file name.
    location : int
        Index of the part of the split stem that holds the starting date.
    field : str
        Name of the time coordinate to overwrite.

    Returns
    -------
    xarray.Dataset
        Dataset with the recalculated coordinate and a note prepended to its ``history`` attribute.

    Raises
    ------
    KeyError
        If ``ds.encoding`` has no ``"source"`` entry.
    """
    try:
        filename = ds.encoding["source"]
    except KeyError as err:
        raise KeyError(
            "Dataset has no 'source' entry in its encoding; cannot infer the starting date."
        ) from err

    date = date_parser(Path(filename).stem.split(split)[location])

    # One step per existing entry, counted in days from the date found in the file name.
    vals = np.arange(len(ds[field]))
    days_since = f"days since {date}"
    time = xr.coding.times.decode_cf_datetime(vals, units=days_since, calendar="standard")
    ds = ds.assign_coords({field: time})

    note = f"Time index recalculated in preprocessing step ({days_since})."
    prev_history = ds.attrs.get("history", "")
    # Avoid a dangling trailing space when there is no earlier history to keep.
    ds.attrs["history"] = f"{note} {prev_history}" if prev_history else note

    return ds
def correct_var_names(ds: xr.Dataset, split: str = "_", location: int = 0) -> xr.Dataset:
    """
    Correct variable names in dataset.

    The first data variable of the dataset is renamed to one part of the stem
    of the file the dataset was opened from.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to correct. ``ds.encoding["source"]`` must hold the path of the source file.
    split : str
        Separator used to split the stem of the source file name.
    location : int
        Index of the part of the split stem to use as the new variable name.

    Returns
    -------
    xarray.Dataset
        Renamed dataset with a note prepended to its ``history`` attribute.

    Raises
    ------
    KeyError
        If ``ds.encoding`` has no ``"source"`` entry.
    """
    try:
        filename = ds.encoding["source"]
    except KeyError as err:
        raise KeyError(
            "Dataset has no 'source' entry in its encoding; cannot infer the variable name."
        ) from err

    new_name = Path(filename).stem.split(split)[location]
    # Only the first data variable is renamed; any others are left as they are.
    old_name = list(ds.data_vars)[0]

    note = f"Variable renamed in preprocessing step ({old_name}: {new_name})."
    prev_history = ds.attrs.get("history", "")
    # Avoid a dangling trailing space when there is no earlier history to keep.
    history = f"{note} {prev_history}" if prev_history else note

    renamed = ds.rename({old_name: new_name})
    # Rebind the attributes instead of updating them in place so that the
    # attributes of the dataset passed in by the caller are never modified.
    renamed.attrs = {**renamed.attrs, "history": history}
    return renamed
def preprocessing_corrections(ds: xr.Dataset, configuration: dict[str, Any]) -> xr.Dataset:
    """
    Corrections function dispatcher to ensure minimal dataset validity on open.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to correct.
    configuration : dict
        Configuration whose optional ``"_preprocess"`` entry maps correction names
        to the keyword arguments of the matching correction function:
        ``"_variable_name"`` for `correct_var_names` and ``"_time"`` for
        `correct_time_entries`. Other names are ignored.

    Returns
    -------
    xarray.Dataset
        The corrected dataset, or `ds` itself when no correction applies.
    """
    correction_fields = configuration.get("_preprocess")
    if not correction_fields:
        return ds

    # Corrections are applied one after the other, in the order the configuration lists them.
    for field, options in correction_fields.items():
        # A field given without options (e.g. null) falls back to the function defaults.
        kwargs = options or {}
        if field == "_variable_name":
            ds = correct_var_names(ds, **kwargs)
        elif field == "_time":
            ds = correct_time_entries(ds, **kwargs)

    return ds