#!/usr/bin/env python
"""
dataset.py
Written by Tyler Sutterley (05/2026)
An xarray.Dataset extension for tidal model data
PYTHON DEPENDENCIES:
numpy: Scientific Computing Tools For Python
https://numpy.org
https://numpy.org/doc/stable/user/numpy-for-matlab-users.html
pint: Python package to define, operate and manipulate physical quantities
https://pypi.org/project/Pint/
https://pint.readthedocs.io/en/stable
pyproj: Python interface to PROJ library
https://pypi.org/project/pyproj/
https://pyproj4.github.io/pyproj/
scipy: Scientific Tools for Python
https://docs.scipy.org/doc/
xarray: N-D labeled arrays and datasets in Python
https://docs.xarray.dev/en/stable/
UPDATE HISTORY:
Updated 05/2026: added parameters to allow for extrapolation with
inverse distance weighting (IDW) in addition to nearest-neighbors (NN)
Updated 04/2026: add barycentric interpolation for unstructured grids
add support for unstructured (e.g. finite element) grids
added function to calculate the high and low peaks of a prediction
added function to try to convert units into a pint-friendly format
added combine_attrs to merge conflicts into a list
Updated 03/2026: allow caching of the kd-tree for extrapolation
Updated 02/2026: create subaccessor registration functions
add functions to test if units are compatible with known groups
Updated 01/2026: handle scalar inputs for coordinate transformations
Updated 12/2025: add coords functions to transform coordinates
set units attribute for amplitude and phase data arrays
add functions for assigning coordinates to datasets
Updated 11/2025: get crs directly using pyproj.CRS.from_user_input
set variable name to constituent for to_dataarray method
added is_global property for models covering a global domain
added pad function to pad global datasets along boundaries
added inpaint function to fill missing data in datasets
Updated 09/2025: added argument to limit the list of constituents
when converting to an xarray DataArray
Written 08/2025
"""
import re
import pint
import pyproj
import warnings
import numpy as np
import xarray as xr
from typing import Any
from xarray.core.utils import equivalent
# suppress warnings
# NOTE(review): this filter is installed at import time and silences every
# UserWarning raised anywhere in the process (not only by this module);
# consider scoping it with ``warnings.catch_warnings`` where it is needed
warnings.filterwarnings(action="ignore", category=UserWarning)
# public names exported by this module
__all__ = [
    "DataTree",
    "Dataset",
    "DataArray",
    "combine_attrs",
    "equivalent_attrs",
    "register_datatree_subaccessor",
    "register_dataset_subaccessor",
    "register_dataarray_subaccessor",
    "_transform",
    "_coords",
]
# shared pint unit registry used for all unit parsing and conversion
__ureg__ = pint.UnitRegistry()
# default output units for each variable group of pyTMD outputs
_default_units = dict(
    elevation="m",
    current="cm/s",
    transport="m^2/s",
)
[docs]
@xr.register_datatree_accessor("tmd")
class DataTree:
    """Accessor for extending an ``xarray.DataTree`` for tidal model data"""

    def __init__(self, dtree):
        # initialize DataTree
        self._dtree = dtree

    def _apply_to_datasets(self, func):
        """
        Apply a function to the ``Dataset`` of every item in the ``DataTree``

        Parameters
        ----------
        func: callable
            Function taking and returning an ``xarray.Dataset``

        Returns
        -------
        dtree: xarray.DataTree
            Copy of the ``DataTree`` with each item replaced by the
            output of ``func``
        """
        # create copy of datatree to hold the outputs
        dtree = self._dtree.copy()
        # iterate over the original tree (not the copy) so the tree being
        # modified is never the one being iterated
        for key, node in self._dtree.items():
            dtree[key] = func(node.to_dataset())
        # return the datatree
        return dtree

    def _get_dataset(self, *keys):
        """
        Get the ``Dataset`` of the first of ``keys`` present in the tree

        Raises
        ------
        KeyError
            If none of ``keys`` is present in the ``DataTree``
        """
        for key in keys:
            node = self._dtree.get(key, None)
            if node is not None:
                return node.to_dataset()
        raise KeyError(f"DataTree has none of the nodes: {', '.join(keys)}")

    def assign_coords(
        self,
        x: np.ndarray,
        y: np.ndarray,
        crs: str | int | dict = 4326,
        **kwargs,
    ):
        """
        Assign new coordinates to each ``Dataset`` in the ``DataTree``

        Parameters
        ----------
        x: np.ndarray
            Updated x-coordinates
        y: np.ndarray
            Updated y-coordinates
        crs: str, int, or dict, default 4326 (WGS84 Latitude/Longitude)
            Coordinate reference system of coordinates
        kwargs: dict
            Keyword arguments for ``xarray.Dataset.assign_coords``

        Returns
        -------
        dtree: xarray.DataTree
            ``DataTree`` with updated coordinates
        """

        def _assign(ds):
            # assign coordinates and record the coordinate reference system
            ds = ds.assign_coords(dict(x=x, y=y), **kwargs)
            ds.attrs["crs"] = crs
            return ds

        return self._apply_to_datasets(_assign)

    def coords_as(
        self,
        x: np.ndarray,
        y: np.ndarray,
        crs: str | int | dict = 4326,
        **kwargs,
    ):
        """
        Transform coordinates into ``DataArrays`` in the ``DataTree``
        coordinate reference system

        Parameters
        ----------
        x: np.ndarray
            Input x-coordinates
        y: np.ndarray
            Input y-coordinates
        crs: str, int, or dict, default 4326 (WGS84 Latitude/Longitude)
            Coordinate reference system of input coordinates
        kwargs: dict
            Keyword arguments for ``_coords``

        Returns
        -------
        X: xarray.DataArray
            Transformed x-coordinates
        Y: xarray.DataArray
            Transformed y-coordinates
        """
        # convert coordinate reference system to that of the datatree
        # and format as xarray DataArray with appropriate dimensions
        X, Y = _coords(x, y, source_crs=crs, target_crs=self.crs, **kwargs)
        # return the transformed coordinates
        return X, Y

    def crop(self, *args, **kwargs):
        """
        Crop each ``Dataset`` in the ``DataTree`` to an input bounding box

        Parameters
        ----------
        args, kwargs
            Arguments for the ``Dataset`` accessor ``crop`` method
            (``bounds``, ``buffer``)

        Returns
        -------
        dtree: xarray.DataTree
            Cropped ``DataTree``
        """
        return self._apply_to_datasets(lambda ds: ds.tmd.crop(*args, **kwargs))

    def inpaint(self, **kwargs):
        """
        Inpaint over missing data in each ``Dataset`` of the ``DataTree``

        Parameters
        ----------
        kwargs: dict
            Keyword arguments for the ``Dataset`` accessor ``inpaint`` method

        Returns
        -------
        dtree: xarray.DataTree
            Inpainted ``DataTree``
        """
        return self._apply_to_datasets(lambda ds: ds.tmd.inpaint(**kwargs))

    def interp(
        self,
        x: np.ndarray,
        y: np.ndarray,
        **kwargs,
    ):
        """
        Interpolate each ``Dataset`` in the ``DataTree`` to new coordinates

        Parameters
        ----------
        x: np.ndarray
            Interpolation x-coordinates
        y: np.ndarray
            Interpolation y-coordinates
        kwargs: dict
            Keyword arguments for the ``Dataset`` accessor ``interp`` method

        Returns
        -------
        dtree: xarray.DataTree
            Interpolated ``DataTree``
        """
        return self._apply_to_datasets(lambda ds: ds.tmd.interp(x, y, **kwargs))

    def subset(self, c: str | list):
        """
        Reduce each ``Dataset`` in the ``DataTree`` to a subset of constituents

        Parameters
        ----------
        c: str or list
            List of constituents names

        Returns
        -------
        dtree: xarray.DataTree
            ``DataTree`` reduced to the requested constituents
        """
        return self._apply_to_datasets(lambda ds: ds.tmd.subset(c))

    def to_ellipse(self, **kwargs):
        """
        Expresses tidal currents in terms of four ellipse parameters

        Parameters
        ----------
        kwargs: dict
            Unused; accepted for interface compatibility

        Returns
        -------
        dtree: xr.DataTree
            ``DataTree`` containing:

            - ``major``: amplitude of the semi-major axis
            - ``minor``: amplitude of the semi-minor axis
            - ``incl``: angle of inclination of the northern semi-major axis
            - ``phase``: phase lag of the current behind the tidal potential

        Raises
        ------
        KeyError
            If the ``DataTree`` has no ``u``/``U`` or ``v``/``V`` node
        ValueError
            If the units of a constituent differ between components
        """
        from pyTMD.ellipse import ellipse

        # get u and v components from datatree (lowercase node preferred)
        dsu = self._get_dataset("u", "U")
        dsv = self._get_dataset("v", "V")
        # calculate ellipse parameters for each constituent
        dmajor = xr.Dataset()
        dminor = xr.Dataset()
        dincl = xr.Dataset()
        dphase = xr.Dataset()
        # for each constituent in the u-component
        for c in dsu.tmd.constituents:
            units = dsu[c].attrs.get("units", "")
            # assert units between datasets are the same
            if units != dsv[c].attrs.get("units", ""):
                raise ValueError(
                    f"Incompatible units for {c} in u and v datasets"
                )
            # calculate ellipse parameters
            major, minor, incl, phase = ellipse(dsu[c].values, dsv[c].values)
            # use the variable's own dimensions and coordinates
            dims, coords = dsu[c].dims, dsu[c].coords
            dmajor[c] = xr.DataArray(
                major, dims=dims, coords=coords, attrs=dict(units=units)
            )
            dminor[c] = xr.DataArray(
                minor, dims=dims, coords=coords, attrs=dict(units=units)
            )
            dincl[c] = xr.DataArray(
                incl, dims=dims, coords=coords, attrs=dict(units="degrees")
            )
            dphase[c] = xr.DataArray(
                phase, dims=dims, coords=coords, attrs=dict(units="degrees")
            )
        # create output datatree
        dtree = xr.DataTree()
        dtree["major"] = dmajor
        dtree["minor"] = dminor
        dtree["incl"] = dincl
        dtree["phase"] = dphase
        # return datatree
        return dtree

    def from_ellipse(self, **kwargs):
        """
        Calculates tidal currents from the four ellipse parameters

        - ``major``: amplitude of the semi-major axis
        - ``minor``: amplitude of the semi-minor axis
        - ``incl``: angle of inclination of the northern semi-major axis
        - ``phase``: phase lag of the current behind the tidal potential

        Parameters
        ----------
        kwargs: dict
            Unused; accepted for interface compatibility

        Returns
        -------
        dtree: xr.DataTree
            ``DataTree`` containing transports (``U``, ``V``) or
            currents (``u``, ``v``)

        Raises
        ------
        ValueError
            If there are no constituents, or if the units of ``major`` are
            neither a current nor a transport
        """
        from pyTMD.ellipse import inverse

        # get ellipse parameters from datatree
        dmajor = self._dtree["major"].to_dataset()
        dminor = self._dtree["minor"].to_dataset()
        dincl = self._dtree["incl"].to_dataset()
        dphase = self._dtree["phase"].to_dataset()
        # calculate currents for each constituent
        dsu = xr.Dataset()
        dsv = xr.Dataset()
        # output node names (set from the unit group of the constituents)
        ukey = vkey = None
        # for each constituent in the major parameter
        for c in dmajor.tmd.constituents:
            # calculate u and v components
            u, v = inverse(
                dmajor[c].values,
                dminor[c].values,
                dincl[c].values,
                dphase[c].values,
            )
            units = dmajor[c].attrs.get("units", "")
            dims, coords = dmajor[c].dims, dmajor[c].coords
            dsu[c] = xr.DataArray(
                u, dims=dims, coords=coords, attrs=dict(units=units)
            )
            dsv[c] = xr.DataArray(
                v, dims=dims, coords=coords, attrs=dict(units=units)
            )
            # node names depend on whether these are currents or transports
            group = dmajor[c].tmd.group
            if group == "current":
                ukey, vkey = "u", "v"
            elif group == "transport":
                ukey, vkey = "U", "V"
            else:
                raise ValueError(
                    f"Unsupported variable group for {c}: {group}"
                )
        if ukey is None:
            raise ValueError("No tidal constituents in ellipse parameters")
        # create output datatree
        dtree = xr.DataTree()
        dtree[ukey] = dsu
        dtree[vkey] = dsv
        # return the datatree
        return dtree

    @property
    def crs(self):
        """
        Coordinate reference system of the ``DataTree``

        Inherited from the first item of the tree; ``None`` if it is empty
        """
        for _, node in self._dtree.items():
            return node.to_dataset().tmd.crs
        return None
[docs]
@xr.register_dataset_accessor("tmd")
class Dataset:
    """Accessor for extending an ``xarray.Dataset`` for tidal model data"""

    def __init__(self, ds):
        # initialize Dataset
        self._ds = ds

    def to_dataarray(self, **kwargs):
        """
        Converts ``Dataset`` to a ``DataArray`` with constituents as a dimension

        Parameters
        ----------
        constituents: list, default all constituents of the ``Dataset``
            Constituents to include in the ``DataArray``

        Returns
        -------
        da: xarray.DataArray
            ``DataArray`` with ``constituent`` as the last dimension
        """
        kwargs.setdefault("constituents", self.constituents)
        constituents = kwargs["constituents"]
        # reduce dataset to constituents and convert to dataarray
        da = self._ds[constituents].to_dataarray(dim="constituent")
        # stack constituents as the last dimension
        da = da.transpose(*da.dims[1:], da.dims[0])
        da = da.assign_coords(constituent=constituents)
        return da

    def assign_coords(
        self,
        x: np.ndarray,
        y: np.ndarray,
        crs: str | int | dict = 4326,
        **kwargs,
    ):
        """
        Assign new coordinates to the ``Dataset``

        Parameters
        ----------
        x: np.ndarray
            Updated x-coordinates
        y: np.ndarray
            Updated y-coordinates
        crs: str, int, or dict, default 4326 (WGS84 Latitude/Longitude)
            Coordinate reference system of coordinates
        kwargs: dict
            Keyword arguments for ``xarray.Dataset.assign_coords``

        Returns
        -------
        ds: xarray.Dataset
            ``Dataset`` with updated coordinates
        """
        # assign new coordinates to dataset
        ds = self._ds.assign_coords(dict(x=x, y=y), **kwargs)
        ds.attrs["crs"] = crs
        # return the dataset
        return ds

    def barycentric_interp(
        self,
        x: np.ndarray,
        y: np.ndarray,
        **kwargs,
    ):
        """
        Interpolate unstructured ``Datasets`` using a barycentric
        method with first or second order triangular finite elements

        Parameters
        ----------
        x: np.ndarray
            Interpolation x-coordinates; must be xarray objects
            (``xr.full_like`` and ``copy(deep=...)`` are applied to them),
            e.g. as returned by ``coords_as``
        y: np.ndarray
            Interpolation y-coordinates
        order: int, default the ``order`` attribute of ``element`` (or 1)
            Polynomial order of the triangular elements

            - ``1``: linear
            - ``2``: quadratic
        cutoff: int or float, default np.inf
            Maximum distance (kilometers) to check for elements; when
            finite the model is first cropped to the interpolation bounds
            plus twice this distance

        Returns
        -------
        other: xarray.Dataset
            Interpolated ``Dataset``
        """
        # import barycentric interpolation functions
        from pyTMD.interpolate import (
            _to_barycentric,
            _inside_triangle,
            _shape_functions,
            _winding_number,
        )

        # default order is same as the tide model elements
        order = self._ds["element"].attrs.get("order", 1)
        kwargs.setdefault("order", order)
        # get cutoff distance to crop elements to bounding box
        cutoff = kwargs.get("cutoff", np.inf)
        if np.isfinite(cutoff):
            cutoff_km = cutoff * __ureg__.parse_units("km")
            if self.crs.is_geographic:
                # convert the distance to an angle using the equatorial
                # radius of the WGS84 ellipsoid
                a_axis = 6378.137 * __ureg__.parse_units("km")
                angle = (cutoff_km / a_axis).to(self.axis_units).magnitude
                buffer = 2.0 * angle
            else:
                buffer = 2.0 * cutoff_km.to(self.axis_units).magnitude
            # crop dataset to bounding box of interpolation points plus
            # twice the cutoff distance as a buffer
            bounds = [np.min(x), np.max(x), np.min(y), np.max(y)]
            ds = self.crop(bounds=bounds, buffer=buffer)
        else:
            # copy dataset without cropping
            ds = self._ds.copy()
        # allocate for barycentric coordinates
        xi = xr.full_like(x, np.nan)
        eta = xr.full_like(x, np.nan)
        null_points = xi.isnull()
        # allocate for indices of valid elements
        element = xr.zeros_like(x, dtype="i")
        # find the valid elements and barycentric coordinates
        for i in range(ds.sizes["element"]):
            # x and y coordinates of element vertices
            x_elem = ds.x.isel(element=i).drop_vars("element")
            y_elem = ds.y.isel(element=i).drop_vars("element")
            # copy x-coordinates to not affect outside array
            xtmp = x.copy(deep=False)
            # if model is geographic: check if element crosses a meridian
            if self.crs.is_geographic:
                # winding number of triangle (negative is clockwise)
                wind = _winding_number(x_elem, y_elem)
                # shift coordinates for meridian crossings
                if (wind < 0) & (x_elem < 0.0).any():
                    # adjust points to be 0:360
                    x_elem = x_elem.where(x_elem >= 0.0, x_elem + 360.0)
                    xtmp = xtmp.where(xtmp >= 0, xtmp + 360.0)
                elif (wind < 0) & (x_elem > 180.0).any():
                    # adjust points to be -180:180
                    x_elem = x_elem.where(x_elem <= 180.0, x_elem - 360.0)
                    xtmp = xtmp.where(xtmp <= 180.0, xtmp - 360.0)
            # convert model coordinates to barycentric
            xi_elem, eta_elem = _to_barycentric(x_elem, y_elem, xtmp, y)
            # drop dimensions
            xi_elem = xi_elem.drop_vars("vertex", errors="ignore")
            eta_elem = eta_elem.drop_vars("vertex", errors="ignore")
            # determine if points are within element and still need values
            inside_element = _inside_triangle(xi_elem, eta_elem)
            if not np.any(inside_element & null_points):
                continue
            # save barycentric coordinates and indices for new points only
            update_element = np.logical_not(inside_element & null_points)
            xi = xi.where(update_element, xi_elem, drop=False)
            eta = eta.where(update_element, eta_elem, drop=False)
            element = element.where(update_element, i, drop=False)
            # can quit search if all interpolation points have values
            null_points = xi.isnull()
            if not null_points.any():
                break
        # get shape functions and convert to DataArray
        N = _shape_functions(xi, eta, kwargs["order"])
        beta = xr.concat(N, dim="node")
        # allocate for output dataset with copied global attributes
        other = xr.Dataset()
        other.attrs.update(self._ds.attrs)
        # iterate over variables in dataset
        for v in ds.data_vars:
            # tide model variable for valid elements
            var = ds[v].isel(element=element)
            # calculate dot product over elements and nodes
            other[v] = var.dot(beta, dim="node")
            # copy variable attributes
            other[v].attrs.update(self._ds[v].attrs)
        # add coordinates to output dataset
        other.coords["x"] = x
        other.coords["y"] = y
        # drop empty vertex coordinates and return the interpolated dataset
        return other.drop_vars("vertex", errors="ignore").compute()

    def coords_as(
        self,
        x: np.ndarray,
        y: np.ndarray,
        crs: str | int | dict = 4326,
        **kwargs,
    ):
        """
        Transform coordinates into ``DataArrays`` in the ``Dataset``
        coordinate reference system

        Parameters
        ----------
        x: np.ndarray
            Input x-coordinates
        y: np.ndarray
            Input y-coordinates
        crs: str, int, or dict, default 4326 (WGS84 Latitude/Longitude)
            Coordinate reference system of input coordinates
        kwargs: dict
            Keyword arguments for ``_coords``

        Returns
        -------
        X: xarray.DataArray
            Transformed x-coordinates
        Y: xarray.DataArray
            Transformed y-coordinates
        """
        # convert coordinate reference system to that of the dataset
        # and format as xarray DataArray with appropriate dimensions
        X, Y = _coords(x, y, source_crs=crs, target_crs=self.crs, **kwargs)
        # return the transformed coordinates
        return X, Y

    def crop(
        self,
        bounds: list | tuple,
        buffer: int | float = 0,
    ):
        """
        Crop ``Dataset`` to input bounding box

        Parameters
        ----------
        bounds: list, tuple
            Bounding box ``[min_x, max_x, min_y, max_y]``
        buffer: int or float, default 0
            Buffer to add to bounds for cropping

        Returns
        -------
        ds: xarray.Dataset
            Cropped ``Dataset``
        """
        # longitude wrapping of the CRS (``+lon_wrap``, default 0)
        lon_wrap = self.crs.to_dict().get("lon_wrap", 0)
        if self.grid_type == "unstructured":
            # copy unstructured dataset
            ds = self._ds.copy()
        elif self.is_global and (lon_wrap == 180) and (np.min(bounds[:2]) < 0):
            # prepend ~180 degrees of wrapped columns to the global grid
            n = int(180 // (self._x[1] - self._x[0]))
            ds = self.pad(n=(n, 0))
        elif self.is_global and (lon_wrap == 0) and (np.max(bounds[:2]) > 180):
            # append ~180 degrees of wrapped columns to the global grid
            n = int(180 // (self._x[1] - self._x[0]))
            ds = self.pad(n=(0, n))
        else:
            # copy dataset
            ds = self._ds.copy()
        # load chunked data into memory before masking
        # ``Dataset.chunks`` is an empty mapping when nothing is chunked,
        # so in-memory datasets are not needlessly converted to dask
        if ds.chunks:
            ds = ds.chunk(-1).compute()
        # unpack bounds and buffer
        xmin = bounds[0] - buffer
        xmax = bounds[1] + buffer
        ymin = bounds[2] - buffer
        ymax = bounds[3] + buffer
        # crop dataset to bounding box
        if self.grid_type == "unstructured":
            # include elements that cross the bounding box
            ds = ds.where(
                (ds.x.max(dim="vertex") >= xmin)
                & (ds.x.min(dim="vertex") <= xmax)
                & (ds.y.max(dim="vertex") >= ymin)
                & (ds.y.min(dim="vertex") <= ymax),
                drop=True,
            )
        else:
            # crop gridded datasets
            ds = ds.where(
                (ds.x >= xmin)
                & (ds.x <= xmax)
                & (ds.y >= ymin)
                & (ds.y <= ymax),
                drop=True,
            )
        # return the cropped dataset
        return ds

    def grid_interp(
        self,
        x: np.ndarray,
        y: np.ndarray,
        method="linear",
        **kwargs,
    ):
        """
        Interpolate a regular or rectilinear ``Dataset`` to new coordinates

        Parameters
        ----------
        x: np.ndarray
            Interpolation x-coordinates
        y: np.ndarray
            Interpolation y-coordinates
        method: str, default 'linear'
            Interpolation method for ``xarray.Dataset.interp``
        kwargs: dict
            Unused; accepted so ``interp`` can forward its keywords

        Returns
        -------
        other: xarray.Dataset
            Interpolated ``Dataset``
        """
        # pad global grids along x-dimension (if necessary)
        # pad into a local variable: xarray caches accessor instances on
        # their parent object, so reassigning ``self._ds`` would make every
        # later ``.tmd`` call silently operate on the padded dataset
        ds = self.pad(n=1) if self.is_global else self._ds
        # verify longitudinal convention for geographic models
        if self.crs.is_geographic:
            model_x = ds.x.values
            # grid spacing in x-direction
            dx = model_x[1] - model_x[0]
            # adjust input longitudes to be consistent with model
            if (np.min(x) < 0.0) & (model_x.max() > (180.0 + dx)):
                # input points (-180:180), tide model (0:360)
                x = xr.where(x < 0.0, x + 360.0, x)
            elif (np.max(x) > 180.0) & (model_x.min() < (0.0 - dx)):
                # input points (0:360), tide model (-180:180)
                x = xr.where(x > 180.0, x - 360.0, x)
        # interpolate dataset using built-in xarray methods
        other = ds.interp(x=x, y=y, method=method)
        # return xarray dataset
        return other

    def infer(self, t: float | np.ndarray, **kwargs):
        """
        Infer minor tides from ``Dataset`` at times

        Parameters
        ----------
        t: float or np.ndarray
            Days relative to 1992-01-01T00:00:00 UTC
        kwargs: dict
            Keyword arguments for :py:func:`pyTMD.predict.infer_minor`

        Returns
        -------
        darr: xarray.DataArray
            Predicted tides
        """
        from pyTMD.predict import infer_minor

        # infer minor tides at times
        darr = infer_minor(t, self._ds, **kwargs)
        # return the inferred tides
        return darr

    def inpaint(self, **kwargs):
        """
        Inpaint over missing data in ``Dataset``

        Parameters
        ----------
        kwargs: dict
            Keyword arguments for :py:func:`pyTMD.interpolate.inpaint`

        Returns
        -------
        ds: xarray.Dataset
            Interpolated ``Dataset``
        """
        from pyTMD.interpolate import inpaint

        # create copy of dataset
        ds = self._ds.copy()
        # inpaint each variable in the dataset
        for v in ds.data_vars:
            ds[v].values = inpaint(
                self._x, self._y, self._ds[v].values, **kwargs
            )
        # return the dataset
        return ds

    def interp(
        self,
        x: np.ndarray,
        y: np.ndarray,
        **kwargs,
    ):
        """
        Interpolate ``Dataset`` to new coordinates

        Parameters
        ----------
        x: np.ndarray
            Interpolation x-coordinates
        y: np.ndarray
            Interpolation y-coordinates
        method: str, default 'linear'
            Interpolation method for gridded datasets
        extrapolate: bool, default False
            Spatially extrapolate values beyond model domain
        cutoff: int or float, default np.inf
            Maximum distance for extrapolation (for unstructured grids
            also the element search distance in kilometers)
        k: int, default 1
            Number of nearest neighbors to use for extrapolation

            - ``1``: nearest neighbor (NN)
            - ``>1``: inverse distance weighting (IDW)
        power: int or float, default 2
            Power parameter for inverse distance weighting extrapolation
        workers: int, default 1
            Number of workers to use for parallel extrapolation
        kwargs: dict
            Keyword arguments for interpolation functions

        Returns
        -------
        other: xarray.Dataset
            Interpolated ``Dataset``
        """
        # set default keyword arguments
        kwargs.setdefault("method", "linear")
        kwargs.setdefault("extrapolate", False)
        # check if interpolating from a grid or mesh
        if self.grid_type == "unstructured":
            # use barycentric interpolation if data is unstructured
            other = self.barycentric_interp(x, y, **kwargs)
        else:
            # use built-in xarray interpolation methods
            other = self.grid_interp(x, y, **kwargs)
        # extrapolate using nearest-neighbors or inverse distance weighting
        if kwargs["extrapolate"]:
            # NOTE(review): ``extrap_like`` is not defined in this accessor
            # as shown; confirm it is provided elsewhere, otherwise
            # ``extrapolate=True`` raises AttributeError
            other = self.extrap_like(
                other,
                k=kwargs.get("k", 1),
                cutoff=kwargs.get("cutoff", np.inf),
                power=kwargs.get("power", 2),
                workers=kwargs.get("workers", 1),
            )
        # return xarray dataset
        return other

    def node_equilibrium(self):
        """
        Compute the equilibrium amplitude and phase of the 18.6 year
        node tide :cite:p:`Cartwright:1971iz,Cartwright:1973em`

        Returns
        -------
        ds: xarray.Dataset
            Copy of the ``Dataset`` with an added complex ``node``
            variable (units of meters)
        """
        # copy dataset
        ds = self._ds.copy()
        # Cartwright and Edden potential amplitude
        amajor = 0.027929  # node
        # Love numbers for long-period tides (Wahr, 1981)
        k2 = 0.299
        h2 = 0.606
        # tilt factor: response with respect to the solid earth
        gamma_2 = 1.0 + k2 - h2
        # check dimensions
        if (ds.x.ndim == 1) and (ds.y.ndim == 1) and (ds.x.dims != ds.y.dims):
            # rectilinear grid: expand the 1D axes into a 2D (y, x) grid
            x, y = np.meshgrid(self._x, self._y)
            dims = (ds.y.dims[0], ds.x.dims[0])
        else:
            # coordinates already span the spatial domain (2D or points)
            x, y = ds.x.values, ds.y.values
            dims = ds.x.dims
        # transform model coordinates to lat/lon coordinates
        lon, lat = _transform(
            x, y, source_crs=self.crs, target_crs=4326, direction="FORWARD"
        )
        # colatitude in radians
        th = np.radians(90.0 - lat)
        # 2nd degree Legendre polynomials
        P20 = 0.5 * (3.0 * np.cos(th) ** 2 - 1.0)
        # normalization for spherical harmonics
        dfactor = np.sqrt((4.0 + 1.0) / (4.0 * np.pi))
        # calculate equilibrium node constants
        hc = dfactor * P20 * gamma_2 * amajor * np.exp(-1j * np.pi)
        # dims match the shape of hc; coordinates come from the dataset
        ds["node"] = xr.DataArray(hc, dims=dims)
        ds["node"].attrs["units"] = "m"
        # return xarray dataset
        return ds

    def pad(
        self,
        n: int | tuple = 1,
        chunks=None,
    ):
        """
        Pad ``Dataset`` by wrapping values in the x-direction

        Parameters
        ----------
        n: int or tuple, default 1
            Number of padding values to add on each side, or a
            ``(before, after)`` pair
        chunks: optional
            Chunks for ``xarray.Dataset.chunk`` applied after padding

        Returns
        -------
        ds: xarray.Dataset
            Padded ``Dataset``
        """
        # (possibly) unchunk x-coordinates and extend them linearly
        # (odd reflection) past the ends of the grid
        x = xr.DataArray(self._x, dims="x").pad(
            x=n, mode="reflect", reflect_type="odd"
        )
        # pad data by wrapping around and re-assign x-coordinates
        ds = self._ds.copy()
        ds = ds.pad(x=n, mode="wrap").assign_coords(x=x)
        # rechunk dataset (if specified)
        if chunks is not None:
            ds = ds.chunk(chunks)
        # return the dataset
        return ds

    def predict(self, t: float | np.ndarray, **kwargs):
        """
        Predict tides from ``Dataset`` at times

        Parameters
        ----------
        t: float or np.ndarray
            Days relative to 1992-01-01T00:00:00 UTC
        kwargs: dict
            Keyword arguments for :py:func:`pyTMD.predict.time_series`

        Returns
        -------
        darr: xarray.DataArray
            Predicted tides
        """
        from pyTMD.predict import time_series

        # predict tides at times
        darr = time_series(t, self._ds, **kwargs)
        # return the predicted tides
        return darr

    def subset(self, c: str | list | None):
        """
        Reduce to a subset of constituents

        Parameters
        ----------
        c: str, list or None
            Constituent name or list of names; ``None`` keeps all variables

        Returns
        -------
        ds: xarray.Dataset
            Reduced copy of the ``Dataset``
        """
        # create copy of dataset
        ds = self._ds.copy()
        if c is None:
            return ds
        if isinstance(c, str):
            # wrap in a list so a Dataset (not a DataArray) is returned
            return ds[[c]]
        return ds[c]

    def to_units(
        self,
        units: str,
        value: float = 1.0,
    ):
        """
        Convert ``Dataset`` to specified tide units

        Parameters
        ----------
        units: str
            Output units
        value: float, default 1.0
            Scaling factor to apply

        Returns
        -------
        ds: xarray.Dataset
            Copy with each constituent converted
        """
        # create copy of dataset
        ds = self._ds.copy()
        # convert each constituent in the dataset
        for c in self.constituents:
            ds[c] = ds[c].tmd.to_units(units, value=value)
        # return the dataset
        return ds

    def to_base_units(self):
        """Convert each constituent of the ``Dataset`` to base units"""
        # create copy of dataset
        ds = self._ds.copy()
        # convert each constituent in the dataset
        for c in self.constituents:
            ds[c] = ds[c].tmd.to_base_units()
        # return the dataset
        return ds

    def to_default_units(self):
        """Convert each constituent of the ``Dataset`` to default tide units"""
        # create copy of dataset
        ds = self._ds.copy()
        # convert each constituent in the dataset
        for c in self.constituents:
            ds[c] = ds[c].tmd.to_default_units()
        # return the dataset
        return ds

    @property
    def constituents(self):
        """List of tidal constituent names in the ``Dataset``"""
        from pyTMD.constituents import _parse_name

        # keep only variables whose names parse as tidal constituents
        cons = []
        for c in self._ds.data_vars:
            try:
                cons.append(_parse_name(c))
            except ValueError:
                pass
        # return list of constituents
        return cons

    @property
    def crs(self):
        """
        Coordinate reference system of the ``Dataset``

        Read from the ``crs`` attribute; default is EPSG:4326 (WGS84)
        """
        CRS = self._ds.attrs.get("crs", 4326)
        return pyproj.CRS.from_user_input(CRS)

    @property
    def is_global(self) -> bool:
        """
        Determine if ``Dataset`` covers a global domain

        A geographic grid whose x-coordinates span ``360 - dx`` degrees;
        assumes 1D, evenly-spaced x-coordinates
        """
        # grid spacing in x-direction
        dx = self._x[1] - self._x[0]
        # check if global grid
        cyclic = np.isclose(self._x[-1] - self._x[0], 360.0 - dx)
        return self.crs.is_geographic and cyclic

    @property
    def area_of_use(self) -> str | None:
        """Area of use from the ``Dataset`` CRS (``None`` if undefined)"""
        if self.crs.area_of_use is not None:
            return self.crs.area_of_use.name.replace(".", "").lower()
        return None

    @property
    def axis_units(self) -> str:
        """Units of the coordinate axes from the ``Dataset`` CRS"""
        return self.crs.axis_info[0].unit_name

    @property
    def grid_type(self) -> str:
        """Spatial structure of the ``Dataset`` (default ``'grid'``)"""
        return self._ds.attrs.get("grid_type", "grid")

    @property
    def _x(self):
        """x-coordinates of the ``Dataset``"""
        return self._ds.x.values

    @property
    def _y(self):
        """y-coordinates of the ``Dataset``"""
        return self._ds.y.values
[docs]
@xr.register_dataarray_accessor("tmd")
class DataArray:
    """Accessor for extending an ``xarray.DataArray`` for tidal model data"""

    def __init__(self, da):
        # initialize DataArray
        self._da = da

    @property
    def amplitude(self):
        """
        Calculate the amplitude of a tide model constituent

        Returns
        -------
        amp: xarray.DataArray
            Tide model constituent amplitude (same units as the input)
        """
        # modulus of the complex constituent
        amp = np.sqrt(self._da.real**2 + self._da.imag**2)
        amp.attrs["units"] = self._da.attrs.get("units", "")
        return amp

    @property
    def phase(self):
        """
        Calculate the phase of a tide model constituent

        Returns
        -------
        ph: xarray.DataArray
            Tide model constituent phase (degrees), with negative
            angles wrapped by adding 360
        """
        # calculate constituent phase lag and convert to degrees
        ph = np.degrees(np.arctan2(-self._da.imag, self._da.real))
        ph = ph.where(ph >= 0, ph + 360.0, drop=False)
        ph.attrs["units"] = "degrees"
        return ph

    def find_peaks(self, **kwargs):
        """
        Find peaks along the ``time`` dimension of the ``DataArray``

        Parameters
        ----------
        kwargs: dict
            Keyword arguments for ``xarray.DataArray.differentiate``

        Returns
        -------
        high_peaks: xarray.DataArray
            Boolean array indicating locations of high tide peaks
        low_peaks: xarray.DataArray
            Boolean array indicating locations of low tide peaks
        """
        # differentiate to calculate high and low tides
        diff = self._da.differentiate("time", **kwargs)
        # compare the sign of the derivative with the next time step
        sign = np.sign(diff)
        next_sign = sign.shift(time=-1)
        # zero crossings of the derivative mark the high and low tides
        high_peaks = (sign >= 0) & (next_sign < 0)
        low_peaks = (sign <= 0) & (next_sign > 0)
        return (high_peaks, low_peaks)

    def to_units(
        self,
        units: str,
        value: float = 1.0,
    ):
        """
        Convert ``DataArray`` to specified tide units

        Parameters
        ----------
        units: str
            Output units
        value: float, default 1.0
            Scaling factor to apply

        Returns
        -------
        da: xarray.DataArray
            Scaled ``DataArray`` with an updated ``units`` attribute
        """
        conversion = value * self.quantity.to(units)
        da = self._da * conversion.magnitude
        da.attrs["units"] = str(conversion.units)
        return da

    def to_base_units(self, value=1.0):
        """
        Convert ``DataArray`` to base units

        Parameters
        ----------
        value: float, default 1.0
            Scaling factor to apply

        Returns
        -------
        da: xarray.DataArray
            Scaled ``DataArray`` with an updated ``units`` attribute
        """
        conversion = value * self.quantity.to_base_units()
        da = self._da * conversion.magnitude
        da.attrs["units"] = str(conversion.units)
        return da

    def to_default_units(self, value=1.0):
        """
        Convert ``DataArray`` to default tide units

        Variables whose group has no default keep their current units

        Parameters
        ----------
        value: float, default 1.0
            Scaling factor to apply
        """
        default_units = _default_units.get(self.group, self.units)
        return self.to_units(default_units, value=value)

    @property
    def units(self):
        """
        Units of the ``DataArray`` as ``pint`` units

        Raises
        ------
        ValueError
            If the ``units`` attribute is missing or not a string
        AttributeError
            If the units string is not recognized by ``pint``
            (``pint.UndefinedUnitError`` subclasses ``AttributeError``)
        """
        units = self._units
        if units is None:
            raise ValueError("DataArray has no attribute 'units'")
        try:
            return self._parse_units(units)
        except TypeError as exc:
            raise ValueError(f"Unknown units: {units}") from exc
        except AttributeError as exc:
            raise AttributeError(f"Unknown units: {units}") from exc

    @property
    def quantity(self):
        """``Pint`` Quantity of the ``DataArray``"""
        return 1.0 * self.units

    @property
    def group(self):
        """Variable group of the ``DataArray`` inferred from its units"""
        # parse the units once for all compatibility checks
        units = self.units
        if units.is_compatible_with("m"):
            return "elevation"
        elif units.is_compatible_with("m/s"):
            return "current"
        elif units.is_compatible_with("m^2/s"):
            return "transport"
        elif units.is_compatible_with("m/s^2"):
            return "acceleration"
        elif units.is_compatible_with("degrees"):
            return "angle"
        raise ValueError(f"Unknown unit group: {self._units}")

    @staticmethod
    def _parse_units(units: str):
        """
        Convert units attributes to ``pint`` units

        Integer exponents written directly after a unit symbol are made
        explicit (e.g. ``m2 s-1`` becomes ``m^2 s^-1``)
        """
        # only a letter may carry an exponent, so digits of numeric
        # factors (e.g. ``10`` or ``0.01``) are not split apart
        units = re.sub(
            r"([a-z])(-?\d+)",
            lambda m: m.group(1) + "^" + m.group(2),
            units,
            flags=re.IGNORECASE,
        )
        # parse units string using pint
        return __ureg__.parse_units(units.lower())

    @property
    def _units(self):
        """Units attribute of the ``DataArray`` (``None`` if missing)"""
        return self._da.attrs.get("units")

    @property
    def _has_compatible_units(self):
        """Tests that units are compatible with known groups"""
        try:
            self.group
        except (TypeError, ValueError, AttributeError):
            return False
        return True
[docs]
def _is_null_attr(value: Any) -> bool:
    """Test if an attribute value is missing (``None`` or an empty string)"""
    return value is None or (isinstance(value, str) and not value)


def _unique_attr_values(values) -> list:
    """Remove duplicated values, keeping order of first appearance"""
    # a linear scan (not ``set``) so unhashable values such as arrays work
    unique = []
    for value in values:
        try:
            duplicate = any(equivalent(value, other) for other in unique)
        except (TypeError, ValueError):
            duplicate = False
        if not duplicate:
            unique.append(value)
    return unique


def _sorted_attr_values(values: list) -> list:
    """Sort values when they are mutually orderable, else keep their order"""
    try:
        return sorted(values)
    except (TypeError, ValueError):
        return list(values)


def combine_attrs(
    attrs_list: list[dict],
    context: str | None,
    **kwargs,
) -> dict:
    """
    Combine attributes from multiple datasets into a single dictionary
    merging conflicting values into a list

    Parameters
    ----------
    attrs_list: list of dict
        List of attribute dictionaries from multiple datasets
    context: str
        Context for the attributes being combined (unused)
    skip_keys: list of str, default ["units"]
        List of attribute keys to skip from comparison

    Returns
    -------
    result: dict
        Combined attributes dictionary. Conflicting keys hold the unique
        non-null values (``None`` and empty strings are dropped; ``0`` and
        ``False`` are kept), sorted when orderable, or a single value if
        only one remains
    """
    # set default keyword arguments
    skip_keys = kwargs.get("skip_keys", ["units"])
    # return an empty dictionary when no attributes are provided
    if not attrs_list:
        return {}
    # initialize combined attributes with the first dictionary in the list
    result = attrs_list[0].copy()
    # conflicting keys (a dict keeps them in a deterministic order)
    append_keys = {}
    for attrs in attrs_list:
        for key, value in attrs.items():
            # skip keys already identified as conflicts or excluded
            if key in append_keys or key in skip_keys:
                continue
            if not equivalent_attrs(result.get(key), value):
                append_keys[key] = None
    # combine conflicting attributes into lists
    for key in append_keys:
        combined_values = []
        for attrs in attrs_list:
            if key not in attrs:
                continue
            value = attrs[key]
            # flatten list/tuple values; append single values
            if isinstance(value, (list, tuple)):
                combined_values.extend(value)
            else:
                combined_values.append(value)
        # remove null values and duplicates
        values = _unique_attr_values(
            v for v in combined_values if not _is_null_attr(v)
        )
        values = _sorted_attr_values(values)
        # if only one unique value remains, simplify back to a single value
        result[key] = values[0] if len(values) == 1 else values
    # return the combined attributes
    return result
[docs]
def equivalent_attrs(a: Any, b: Any) -> bool:
    """
    Test whether two attribute values should be treated as equal

    Pairs of strings are compared without regard to case (``casefold``).
    All other pairs are delegated to ``equivalent``.  If that comparison
    raises ``TypeError`` or ``ValueError``, the values are treated as
    not equivalent.
    Adapted from ``xarray.structure.merge.equivalent_attrs``

    Parameters
    ----------
    a: Any
        First attribute value
    b: Any
        Second attribute value

    Returns
    -------
    bool
        ``True`` if the two values are considered equivalent
    """
    # case-insensitive path when both values are strings
    if all(isinstance(value, str) for value in (a, b)):
        return equivalent(a.casefold(), b.casefold())
    # general path: a comparison that cannot be evaluated is a mismatch
    try:
        outcome = equivalent(a, b)
    except (TypeError, ValueError):
        outcome = False
    return outcome
[docs]
def register_datatree_subaccessor(name):
    """
    Register a custom subaccessor on ``DataTree`` objects

    Thin wrapper around the private xarray helper
    ``xarray.core.extensions._register_accessor`` with the target
    class fixed to ``DataTree``

    Parameters
    ----------
    name: str
        Name of the subaccessor

    Returns
    -------
    Whatever ``_register_accessor(name, DataTree)`` returns
    (presumably a class decorator, as with xarray's public
    ``register_*_accessor`` functions)
    """
    register = xr.core.extensions._register_accessor
    return register(name, DataTree)
[docs]
def register_dataset_subaccessor(name):
    """
    Register a custom subaccessor on ``Dataset`` objects

    Thin wrapper around the private xarray helper
    ``xarray.core.extensions._register_accessor`` with the target
    class fixed to ``Dataset``

    Parameters
    ----------
    name: str
        Name of the subaccessor

    Returns
    -------
    Whatever ``_register_accessor(name, Dataset)`` returns
    (presumably a class decorator, as with xarray's public
    ``register_*_accessor`` functions)
    """
    register = xr.core.extensions._register_accessor
    return register(name, Dataset)
[docs]
def register_dataarray_subaccessor(name):
    """
    Register a custom subaccessor on ``DataArray`` objects

    Thin wrapper around the private xarray helper
    ``xarray.core.extensions._register_accessor`` with the target
    class fixed to ``DataArray``

    Parameters
    ----------
    name: str
        Name of the subaccessor

    Returns
    -------
    Whatever ``_register_accessor(name, DataArray)`` returns
    (presumably a class decorator, as with xarray's public
    ``register_*_accessor`` functions)
    """
    register = xr.core.extensions._register_accessor
    return register(name, DataArray)
[docs]
def _coords(
    x: np.ndarray,
    y: np.ndarray,
    source_crs: str | int | dict = 4326,
    target_crs: str | int | dict | None = None,
    **kwargs,
):
    """
    Transform coordinates into DataArrays in a new
    coordinate reference system

    Parameters
    ----------
    x: np.ndarray
        Input x-coordinates
    y: np.ndarray
        Input y-coordinates
    source_crs: str, int, or dict, default 4326 (WGS84 Latitude/Longitude)
        Coordinate reference system of input coordinates
    target_crs: str, int, dict or None, default None
        Coordinate reference system of output coordinates
    type: str or None, default None
        Coordinate data type (case-insensitive)
        If not provided: must specify ``time`` parameter to auto-detect
        Ignored when both ``x`` and ``y`` are scalars
            - ``None``: determined from input variable dimensions
            - ``'drift'``: drift buoys or satellite/airborne altimetry
            - ``'grid'``: spatial grids or images
            - ``'time series'``: time series at a single point
    time: np.ndarray or None, default None
        Time variable for determining coordinate data type

    Returns
    -------
    X: xarray.DataArray
        Transformed x-coordinates
    Y: xarray.DataArray
        Transformed y-coordinates

    Raises
    ------
    ValueError
        If neither ``type`` nor ``time`` is given for non-scalar
        coordinates, or if the coordinate data type is unknown
    """
    from pyTMD.spatial import data_type

    # set default keyword arguments
    kwargs.setdefault("type", None)
    kwargs.setdefault("time", None)
    # scalar coordinates are always treated as a single point
    is_scalar = (np.ndim(x) == 0) and (np.ndim(y) == 0)
    # determine coordinate data type
    if is_scalar:
        coord_type = "time series"
    elif kwargs["type"] is None:
        # must provide time variable to determine data type
        # (explicit exception: assert statements are removed under -O)
        if kwargs["time"] is None:
            raise ValueError(
                "Must provide time parameter when type is not specified"
            )
        coord_type = data_type(x, y, np.ravel(kwargs["time"]))
    else:
        # use provided coordinate data type and verify that it is lowercase
        coord_type = kwargs["type"].lower()
    # output dimensions for each supported coordinate data type
    dimensions = {
        "grid": ("y", "x"),
        "drift": ("time",),
        "time series": ("station",),
    }
    # reject unknown types before doing any transformation work
    if coord_type not in dimensions:
        raise ValueError(f"Unknown coordinate data type: {coord_type}")
    # expand grid axis vectors into 2D coordinate arrays, both when the
    # axes differ in size and when both are 1D with equal length
    # (a square grid would otherwise stay 1D and not fit ("y", "x"))
    if (coord_type == "grid") and (
        (np.size(x) != np.size(y)) or (np.ndim(x) == 1 and np.ndim(y) == 1)
    ):
        x, y = np.meshgrid(x, y)
    # convert coordinates to a new coordinate reference system
    mx, my = _transform(
        x,
        y,
        source_crs=source_crs,
        target_crs=target_crs,
        direction="FORWARD",
    )
    # convert to xarray DataArray with appropriate dimensions
    if is_scalar:
        return (xr.DataArray(mx), xr.DataArray(my))
    dims = dimensions[coord_type]
    X = xr.DataArray(mx, dims=dims)
    Y = xr.DataArray(my, dims=dims)
    # return the transformed coordinates
    return (X, Y)