# Source code for miranda.gis.utils
"""Utility functions for GIS operations."""
from __future__ import annotations
import datetime
import logging
import warnings
import numpy as np
import xarray as xr
# Logger for this module, named after its import path.
logger = logging.getLogger("miranda.gis.utils")
# Names exported as the public API; `_simple_fix_dims` is intentionally not listed.
__all__ = [
"conservative_regrid",
"threshold_mask",
]
def _simple_fix_dims(d: xr.Dataset | xr.DataArray) -> xr.Dataset | xr.DataArray:
    """
    Adjust dimensions found in a file so that it can be used for regridding purposes.

    The following adjustments are applied, in order:

    1. If ``lon`` or ``lat`` is missing from the dimensions, every dimension whose
       name starts with "lon" or "lat" (case-insensitive) is renamed to ``lon`` or
       ``lat`` respectively.
    2. If any longitude is greater than 180, 360 is subtracted from those values
       and the object is sorted by ``lon``.
    3. If a ``time`` dimension is present, only its first step is kept and the
       ``time`` coordinate is dropped.

    Parameters
    ----------
    d : xr.Dataset or xr.DataArray
        The dataset to adjust. The object passed in is not modified.

    Returns
    -------
    xr.Dataset or xr.DataArray
        The adjusted dataset.
    """
    if "lon" not in d.dims or "lat" not in d.dims:
        dim_rename = {}
        for dim in d.dims:
            name = str(dim)
            lowered = name.lower()
            if lowered.startswith("lon"):
                target = "lon"
            elif lowered.startswith("lat"):
                target = "lat"
            else:
                continue
            # A dimension that already carries the canonical name needs no renaming.
            if name != target:
                dim_rename[name] = target
        if dim_rename:
            d = d.rename(dim_rename)

    if np.any(d.lon > 180):
        lon_wrapped = d.lon.where(d.lon <= 180.0, d.lon - 360.0)
        # `assign_coords` returns a new object; item assignment (`d["lon"] = ...`)
        # would overwrite the longitudes of the caller's object in place.
        d = d.assign_coords(lon=lon_wrapped)
        d = d.sortby(["lon"])

    if "time" in d.dims:
        d = d.isel(time=0, drop=True)

    return d
# [docs]
def conservative_regrid(ds: xr.DataArray | xr.Dataset, ref_grid: xr.DataArray | xr.Dataset) -> xr.DataArray | xr.Dataset:
    """
    Perform a conservative_normed regridding.

    The reference grid is first normalised with `_simple_fix_dims`, then `ds` is
    regridded onto it with `xesmf.Regridder` (``periodic=False``). A timestamped
    entry describing the operation is prepended to the ``history`` attribute of
    the result.

    Parameters
    ----------
    ds : xr.DataArray or xr.Dataset
        The dataset to regrid.
    ref_grid : xr.DataArray or xr.Dataset
        The reference grid.

    Returns
    -------
    xr.DataArray or xr.Dataset
        The regridded dataset.

    Raises
    ------
    ModuleNotFoundError
        If the `xesmf` library is not installed.
    """
    try:
        import xesmf as xe  # noqa
    except ModuleNotFoundError:
        logger.error("This function requires the `xesmf` library which is not installed. Regridding step will be skipped.")
        # Regridding cannot proceed without xesmf, so the error is propagated to the caller.
        raise

    ref_grid = _simple_fix_dims(ref_grid)
    method = "conservative_normed"

    msg = f"Performing regridding and masking with `xesmf` using method: {method}."
    # Use the module logger (as the error path above does) rather than the root logger.
    logger.info(msg)

    # Remember the input's history before `ds` is rebound to the regridded result.
    # NOTE(review): whether the regridded object keeps the input attributes depends
    # on xesmf (`keep_attrs`) - confirm; the input's history is used as a fallback.
    source_history = ds.attrs.get("history") or ""

    regridder = xe.Regridder(ds, ref_grid, method, periodic=False)
    ds = regridder(ds)

    # Default to an empty string so a missing history is not rendered as "None".
    prev_history = ds.attrs.get("history") or source_history
    ds.attrs["history"] = f"{datetime.datetime.now()}:Regridded dataset using xesmf with method: {method}. {prev_history}".strip()
    return ds
# [docs]
def threshold_mask(
    ds: xr.Dataset | xr.DataArray,
    *,
    mask: xr.Dataset | xr.DataArray,
    mask_cutoff: float | bool = False,
) -> xr.Dataset | xr.DataArray:
    """
    Land-Sea mask operations.

    The mask is normalised with `_simple_fix_dims`, cropped to the lon/lat extent
    of `ds` when `clisops` is available, and thresholded. Values of `ds` are kept
    only where the thresholded mask is not null. The ``mask_cutoff`` and
    ``history`` attributes of the result are updated.

    Parameters
    ----------
    ds : xr.Dataset or xr.DataArray
        The dataset to be masked. Must expose ``lon`` and ``lat``.
    mask : xr.Dataset or xr.DataArray
        The land-sea mask. A Dataset must contain exactly one data variable.
    mask_cutoff : float or bool
        The mask cutoff value. Mask cells are kept where ``mask >= mask_cutoff``.
        For a boolean mask the cutoff is ignored (a warning is logged) and cells
        are kept where the mask is True.

    Returns
    -------
    xr.Dataset or xr.DataArray
        The masked dataset.

    Raises
    ------
    ValueError
        If `mask` is a Dataset that does not hold exactly one data variable.

    Warns
    -----
    UserWarning
        If `clisops` is not installed; the mask is then used without subsetting.
    """
    mask = _simple_fix_dims(mask)

    if isinstance(mask, xr.Dataset):
        if len(mask.data_vars) != 1:
            raise ValueError("More than one data variable found in land-sea mask. Supply a DataArray instead.")
        mask_variable = list(mask.data_vars)[0]
        mask = mask[mask_variable]
    else:
        mask_variable = mask.name

    log_msg = f"Masking dataset with {mask_variable}."
    if mask_cutoff:
        log_msg = f"{log_msg.strip('.')} at `{mask_cutoff}` cutoff value."
    # Use the module logger rather than the root logger.
    logger.info(log_msg)

    # Only the import is guarded: a ModuleNotFoundError raised while subsetting
    # would otherwise be misreported as "clisops is not installed".
    try:
        from clisops.core import subset_bbox  # noqa
    except ModuleNotFoundError:
        log_msg = "This function requires the `clisops` library which is not installed. subsetting step will be skipped."
        warnings.warn(log_msg, stacklevel=2)
        mask_subset = mask.load()
    else:
        # Crop the mask to the bounding box of the dataset before loading it.
        lon_bounds = np.array([ds.lon.min(), ds.lon.max()])
        lat_bounds = np.array([ds.lat.min(), ds.lat.max()])
        mask_subset = subset_bbox(
            mask,
            lon_bnds=lon_bounds,
            lat_bnds=lat_bounds,
        ).load()

    # Threshold against the loaded subset itself instead of the full mask.
    if mask_subset.dtype == bool:
        if mask_cutoff:
            logger.warning("Mask value cutoff set for boolean mask. Ignoring.")
        mask_subset = mask_subset.where(mask_subset)
    else:
        mask_subset = mask_subset.where(mask_subset >= mask_cutoff)

    ds = ds.where(mask_subset.notnull())

    # Express the cutoff according to the apparent range of the remaining mask
    # values: within [0, 1] it is scaled to a percentage, within [0, 100] it is
    # reported as a percentage as-is, and otherwise it is reported unchanged.
    lowest = float(mask_subset.min())
    highest = float(mask_subset.max())
    if lowest >= 0 and highest <= 1.00000001:
        cutoff_info = f"{mask_cutoff * 100} %"
    elif lowest >= 0 and highest <= 100.00000001:
        cutoff_info = f"{mask_cutoff} %"
    else:
        cutoff_info = f"{mask_cutoff}"
    ds.attrs["mask_cutoff"] = cutoff_info

    prev_history = ds.attrs.get("history", "")
    history_msg = f"Mask calculated using `{mask_variable}`."
    if mask_cutoff:
        history_msg = f"{history_msg.strip('.')} with cutoff value `{cutoff_info}`."
    # New entry first, followed by any pre-existing history.
    ds.attrs["history"] = f"{history_msg} {prev_history}".strip()

    return ds