Source code for dtcg.datacube.geozarr

"""Copyright 2026 DTCG Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


=====

Functionality for exporting a GeoZarr file.
"""

from __future__ import annotations

from pathlib import Path
from typing import Optional

import numpy as np
import xarray as xr
from numcodecs import Blosc

from dtcg.datacube.update_metadata import MetadataMapper


class GeoZarrHandler(MetadataMapper):
    """Assemble an ``xarray.DataTree`` of datacube layers and export it as Zarr.

    The handler validates incoming datasets, updates their metadata through
    the inherited ``MetadataMapper`` machinery, computes per-variable chunking
    and compression settings, and writes the resulting tree with
    ``DataTree.to_zarr``.

    Attributes
    ----------
    target_chunk_mb : float
        Approximate chunk size in megabytes used when computing chunk shapes.
    compressor : Blosc
        Compressor stored in every variable encoding.
    zarr_format : int
        Zarr format version forwarded to ``to_zarr``.
    encoding : dict
        Mapping ``"/<group path>" -> {variable name -> encoding dict}``
        forwarded to ``to_zarr``.
    data_tree : xarray.DataTree
        Tree holding all layers.
    ds_name : str
        Name of the initial layer. Only set when the handler is initialised
        from an ``xarray.Dataset``.
    """

    # Datacube types accepted as children of an L2/L3 node.
    _SUPPORTED_DATACUBE_TYPES = ("monthly", "annual_hydro", "daily_smb")

    def __init__(
        self: GeoZarrHandler,
        ds: xr.Dataset | xr.DataTree | None = None,
        ds_name: str = "L1",
        target_chunk_mb: float = 5.0,
        compressor: Optional[Blosc] = None,
        metadata_mapping_data: Optional[str] = None,
        metadata_mapping_coords: Optional[str] = None,
        zarr_format: int = 2,
    ) -> None:
        """Initialise a GeoZarrHandler object.

        Parameters
        ----------
        ds : xarray.DataTree | xarray.Dataset, default None
            Input dataset with dimensions ('x', 'y') or ('t', 'x', 'y').
            Must include coordinate variables. Accepts either a dataset or
            a data tree. Although the default is None, a value must be
            provided.
        ds_name : str, default 'L1'
            Name of datacube. Only used when ``ds`` is a dataset.
        target_chunk_mb : float, default 5.0
            Approximate chunk size in megabytes for efficient storage.
        compressor : Blosc, default None
            Compressor to apply on arrays. If None, the compression will be
            Blosc with zstd.
        metadata_mapping_data : str, default None
            Path to the YAML file containing variable metadata mappings.
            If None, defaults to 'metadata_mapping_data.yaml' in the
            current directory.
        metadata_mapping_coords : str, default None
            Path to the YAML file containing time coordinate metadata
            mappings. If None, the ``MetadataMapper`` default is used.
            NOTE(review): the previous text named
            'metadata_mapping_data.yaml' here, which looks like a
            copy-paste slip - confirm the actual default file name.
        zarr_format : int, default 2
            Zarr format version to use (2 or 3).

        Raises
        ------
        ValueError
            If ``ds`` is None.
        TypeError
            If ``ds`` is neither a dataset nor a data tree.
        """
        super().__init__(
            metadata_mapping_data=metadata_mapping_data,
            metadata_mapping_coords=metadata_mapping_coords,
        )
        self.target_chunk_mb = target_chunk_mb
        self.compressor = compressor or Blosc(
            cname="zstd", clevel=3, shuffle=Blosc.BITSHUFFLE
        )
        self.zarr_format = zarr_format
        self.encoding = {}

        self._set_data(ds=ds, ds_name=ds_name)

    def _check_datacube_type(self, datacube_type: str) -> None:
        """Raise ``ValueError`` if ``datacube_type`` is not a supported type."""
        if datacube_type not in self._SUPPORTED_DATACUBE_TYPES:
            raise ValueError(
                "We currently only support model output datacubes of "
                "the types 'monthly', 'annual_hydro' and 'daily_smb'."
            )

    def _drop_encodings(self, ds_name: str) -> None:
        """Remove the encodings of group ``ds_name`` and of all its sub-groups.

        Encoding keys that do not correspond to a group of the tree make the
        export fail, so they must be cleared before a group is replaced.
        """
        prefix = f"/{ds_name}"
        stale_keys = [
            key
            for key in self.encoding
            if key == prefix or key.startswith(f"{prefix}/")
        ]
        for key in stale_keys:
            del self.encoding[key]

    def _set_data(
        self, ds: xr.Dataset | xr.DataTree | None = None, ds_name: str = "L1"
    ) -> None:
        """Validate and set data.

        Parameters
        ----------
        ds : xarray.DataTree | xarray.Dataset, default None
            Input dataset with dimensions ('x', 'y') or ('t', 'x', 'y').
            Must include coordinate variables. Accepts either a dataset or
            a data tree.
        ds_name : str, default "L1"
            Name of the tree node created for ``ds``. Ignored when ``ds``
            is already a data tree.

        Raises
        ------
        ValueError
            If ``ds`` is None, or if a data tree contains an unsupported
            datacube type below an L2/L3 node.
        TypeError
            If ``ds`` is neither a dataset nor a data tree.
        """
        if ds is None:
            raise ValueError("No dataset provided.")

        if isinstance(ds, xr.Dataset):
            self.ds_name = ds_name
            ds = self._validate_dataset(ds)
            ds = self._update_metadata(ds, ds_name)
            self._define_encodings(ds, ds_name)
            # convert dataset to datatree
            self.data_tree = xr.DataTree.from_dict({ds_name: ds})
        elif isinstance(ds, xr.DataTree):
            # An existing tree is taken as-is (no validation or metadata
            # update); only encodings are defined for a later export.
            self.data_tree = ds
            for tree_level in self.data_tree:
                if tree_level == "L1":
                    self._define_encodings(
                        ds=self.data_tree[tree_level].ds, ds_name=tree_level
                    )
                elif "L2" in tree_level or "L3" in tree_level:
                    for datacube_type in self.data_tree[tree_level]:
                        self._check_datacube_type(datacube_type)
                        self._define_encodings(
                            ds=self.data_tree[tree_level][datacube_type],
                            ds_name=tree_level,
                            ds_type=datacube_type,
                        )
        else:
            raise TypeError("Dataset should either be an xarray Dataset or DataTree.")

    def _validate_dataset(self: GeoZarrHandler, ds: xr.Dataset) -> xr.Dataset:
        """Check that every dimension has an associated coordinate variable.

        Parameters
        ----------
        ds : xarray.Dataset
            Input dataset with dimensions ('x', 'y') or ('t', 'x', 'y').
            Must include coordinate variables.

        Returns
        -------
        xarray.Dataset
            The unchanged input dataset.

        Raises
        ------
        ValueError
            If any dimension does not have an associated coordinate
            variable.

        Notes
        -----
        The set of dimension names is currently not restricted; see the
        TODO below.
        """
        # TODO: get accepted dims from metadata mapping?
        # accepted_dims = {"x", "y", "t", "t_wgms", "t_sfc_type", "snowcover_frac"}
        # if not set(ds.dims).issubset(accepted_dims):
        #     raise ValueError(
        #         "Incorrect dataset dimensions."
        #         f" Accepted data dimensions are: {accepted_dims}"
        #     )

        for dim in ds.dims:
            if dim not in ds.coords:
                raise ValueError(
                    f"Coordinate variable for dimension '{dim}' is missing in "
                    "the dataset."
                )
        return ds

    def _calculate_chunk_sizes(
        self: GeoZarrHandler, var: xr.DataArray
    ) -> dict[str, int]:
        """Calculate chunk sizes for a variable to match the target chunk size.

        The time-like dimension ('t_sfc_type', 't_wgms' or 't') is kept in a
        single chunk, 'x' and 'y' share a square chunk sized so that one
        chunk with the full time axis fits ``target_chunk_mb``, 'member'
        and 'snowcover_frac' are stored one entry per chunk, and every
        other dimension is kept whole.

        Parameters
        ----------
        var : xr.DataArray
            Data array whose dtype and dimensions are used to compute chunk
            sizes.

        Returns
        -------
        dict[str, int]
            Chunk size for every dimension of ``var``. All sizes are at
            least 1.

        Notes
        -----
        Only the time-like dimension is accounted for in the byte budget;
        other non-spatial dimensions kept whole can push the real chunk
        size above ``target_chunk_mb``.
        """
        target_bytes = self.target_chunk_mb * 1024 * 1024

        if "t_sfc_type" in var.dims:
            t_var = "t_sfc_type"
        elif "t_wgms" in var.dims:
            t_var = "t_wgms"
        else:
            t_var = "t"
        t_size = var.sizes.get(t_var, 1)  # Defaults to 1 if no time dimension

        chunk_sizes = {}

        if "x" in var.dims and "y" in var.dims:
            # Number of (x, y) positions that fit in one chunk once each
            # position carries the full time series. An empty time axis
            # must not lead to a division by zero.
            bytes_per_position = var.dtype.itemsize * max(t_size, 1)
            positions_per_chunk = target_bytes // bytes_per_position

            # Square spatial chunk; never smaller than one element, even if
            # a single full time series already exceeds the budget.
            side_length = max(int(np.sqrt(positions_per_chunk)), 1)

            chunk_sizes["x"] = max(min(var.sizes["x"], side_length), 1)
            chunk_sizes["y"] = max(min(var.sizes["y"], side_length), 1)

        if t_var in var.dims:
            # Use the full length of 't' - this allows more efficient loading,
            # assuming the user is always interested in the full time series
            chunk_sizes[t_var] = max(t_size, 1)

        for dim in var.dims:
            if dim in ("member", "snowcover_frac"):
                # use one to save each dimension separately
                chunk_sizes[dim] = 1
            elif dim not in chunk_sizes:
                # Everything not handled above (including a lone 'x' or 'y')
                # is stored whole, so no dimension is left without a size.
                chunk_sizes[dim] = max(var.sizes[dim], 1)

        return chunk_sizes

    def _define_encodings(
        self: GeoZarrHandler,
        ds: xr.Dataset,
        ds_name: str,
        ds_type: Optional[str] = None,
    ) -> None:
        """Define chunking and compression for each data variable of a group.

        Any encoding previously stored for the same group is replaced, so
        variables of an overwritten layer do not linger in ``self.encoding``.

        Parameters
        ----------
        ds : xarray.Dataset
            Input dataset with dimensions ('x', 'y') or ('t', 'x', 'y').
            Must include coordinate variables.
        ds_name : str
            Dataset name to be used for this node of the tree.
        ds_type : str, default None
            Name of the child node below ``ds_name`` (e.g. a datacube
            type). If given, the encoding is stored under
            ``"/<ds_name>/<ds_type>"`` instead of ``"/<ds_name>"``.

        Notes
        -----
        Chunk sizes are computed using `_calculate_chunk_sizes`, and the
        compressor is set according to the instance-level setting.
        """
        encoding_key = f"/{ds_name}"
        if ds_type is not None:
            encoding_key = f"{encoding_key}/{ds_type}"

        group_encoding = {}
        for var in ds.data_vars:
            chunk_sizes = self._calculate_chunk_sizes(ds[var])
            group_encoding[var] = {
                "chunks": tuple(chunk_sizes[dim] for dim in ds[var].dims),
                # NOTE(review): confirm that the "compressor" key is also
                # accepted by the writer when zarr_format=3.
                "compressor": self.compressor,
            }
        self.encoding[encoding_key] = group_encoding

    def _update_metadata(self, ds: xr.Dataset, ds_name: str) -> xr.Dataset:
        """Update metadata to Climate and Forecast convention.

        Metadata is first updated using the inherited ``update_metadata``
        method. Each data variable with an 'x' or 'y' dimension is then
        tagged with ``grid_mapping = "spatial_ref"`` for spatial
        referencing.

        Parameters
        ----------
        ds : xarray.Dataset
            Input dataset with dimensions ('x', 'y') or ('t', 'x', 'y').
            Must include coordinate variables.
        ds_name : str
            Layer name for this node of the tree.

        Returns
        -------
        xarray.Dataset
            Dataset with updated metadata.
        """
        ds = self.update_metadata(ds, ds_name)

        for var in ds.data_vars:
            var_dims = ds[var].dims
            if "x" in var_dims or "y" in var_dims:
                ds[var].attrs["grid_mapping"] = "spatial_ref"

        return ds

    def export(
        self: GeoZarrHandler, storage_directory: str | Path, overwrite: bool = True
    ) -> None:
        """Write the dataset to GeoZarr format.

        Parameters
        ----------
        storage_directory : str
            Path to write the Zarr data.
        overwrite : bool, default True
            Whether to overwrite existing Zarr contents in the target
            location. If False, the store is opened in append mode.

        Raises
        ------
        FileNotFoundError
            If the parent directory of ``storage_directory`` does not
            exist.
        """
        dir_path = Path(storage_directory).parent
        if not dir_path.exists():
            raise FileNotFoundError(
                f"Base directory of 'storage_directory' does not exist: {dir_path}"
            )

        self.data_tree.to_zarr(
            storage_directory,
            mode="w" if overwrite else "a",
            consolidated=True,
            zarr_format=self.zarr_format,
            encoding=self.encoding,
        )

    def add_layer(
        self: GeoZarrHandler, ds: xr.Dataset, ds_name: str, overwrite: bool = False
    ) -> None:
        """Add a new dataset as a child group of the DataTree at the root.

        The dataset is validated completely before the handler state is
        touched, so a failed call leaves tree and encodings unchanged.

        Parameters
        ----------
        ds : xarray.Dataset
            New dataset layer to be added to the existing data tree.
        ds_name : str
            Layer name to be used for this node of the tree.
        overwrite : bool
            If True, allow a layer of the same name to be overwritten.

        Raises
        ------
        ValueError
            If a layer called ``ds_name`` exists and ``overwrite`` is
            False, or if a dimension lacks a coordinate variable.
        """
        if ds_name in self.data_tree.children and not overwrite:
            raise ValueError(f"Group '{ds_name}' already exists.")

        # prepare new dataset
        ds = self._validate_dataset(ds)
        ds = self._update_metadata(ds, ds_name)

        # validate dataset attributes ('grid_mapping' is added by this class
        # and is excluded from the schema check)
        for var in ds.data_vars:
            attrs = ds[var].attrs.copy()
            attrs.pop("grid_mapping", None)
            self.METADATA_SCHEMA_DATA.validate(attrs)

        # only now update the state: clear encodings of a replaced layer
        # (including its sub-groups), then register the new ones
        self._drop_encodings(ds_name)
        self._define_encodings(ds, ds_name)

        self.data_tree[ds_name] = xr.DataTree(dataset=ds)

    def add_datacube(
        self: GeoZarrHandler,
        datacubes: dict,
        datacube_name: str,
        overwrite: bool = False,
    ) -> None:
        """Add a group of model output datacubes as a child of the root.

        All datacubes are validated before the handler state is touched,
        so a failed call leaves tree and encodings unchanged.

        Parameters
        ----------
        datacubes : dict
            A dictionary with keys one of the currently supported L2
            datacubes ('monthly', 'annual_hydro', 'daily_smb') and values
            the corresponding xr.Dataset.
        datacube_name : str
            Layer name to be used for this node of the tree. It should
            contain either 'L2' or 'L3'. If it contains neither, 'L2_' is
            added as prefix.
        overwrite : bool
            If True, allow a layer of the same name to be overwritten.

        Raises
        ------
        ValueError
            If the layer exists and ``overwrite`` is False, if
            ``datacubes`` is not a dict, if a datacube type is not
            supported, or if a dimension lacks a coordinate variable.
        """
        if "L2" not in datacube_name and "L3" not in datacube_name:
            # by default, we assume it is L2
            datacube_name = f"L2_{datacube_name}"

        if datacube_name in self.data_tree.children and not overwrite:
            raise ValueError(f"Group '{datacube_name}' already exists.")

        if not isinstance(datacubes, dict):
            raise ValueError(
                "Datacubes need to be provided as dict with keys "
                "one of 'monthly', 'annual_hydro' or 'daily_smb'."
            )

        # first pass: prepare and validate every datacube
        prepared = {}
        for datacube_type, datacube in datacubes.items():
            self._check_datacube_type(datacube_type)

            datacube = self._validate_dataset(datacube)
            datacube = self._update_metadata(datacube, datacube_name)

            # validate dataset attributes
            for var in datacube.data_vars:
                attrs = datacube[var].attrs.copy()
                attrs.pop("grid_mapping", None)
                attrs.pop("inf_values", None)
                self.METADATA_SCHEMA_DATA.validate(attrs)
            for coord in datacube.coords:
                if coord != "spatial_ref":
                    attrs = datacube[coord].attrs.copy()
                    self.METADATA_SCHEMA_COORDS.validate(attrs)

            prepared[datacube_type] = datacube

        # second pass: everything is valid, so update the state. Encodings
        # of a replaced node are cleared first, otherwise keys of datacube
        # types that no longer exist would remain.
        self._drop_encodings(datacube_name)
        new_leaves = {}
        for datacube_type, datacube in prepared.items():
            self._define_encodings(
                ds=datacube, ds_name=datacube_name, ds_type=datacube_type
            )
            new_leaves[datacube_type] = xr.DataTree(
                name=datacube_type, dataset=datacube
            )

        self.data_tree[datacube_name] = xr.DataTree(
            name=datacube_name, children=new_leaves
        )

    def get_layer(self: GeoZarrHandler, ds_name: str) -> xr.Dataset:
        """Get a dataset from a DataTree.

        Parameters
        ----------
        ds_name : str
            Layer name.

        Returns
        -------
        xr.Dataset
            Dataset layer in tree.

        Raises
        ------
        KeyError
            If the layer name is not present in the data tree.
        AttributeError
            If the layer does not contain a dataset.
        """
        try:
            layer = self.data_tree[ds_name].ds
        except KeyError as err:
            # keep the original lookup error attached for debugging
            raise KeyError(f"{ds_name} layer not found in the data tree.") from err
        return layer