"""Copyright 2025 DTCG Contributors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=====
Functionality for exporting a GeoZarr file.
"""
from __future__ import annotations
from pathlib import Path
from typing import Optional
import numpy as np
import xarray as xr
from numcodecs import Blosc
from dtcg.datacube.update_metadata import MetadataMapper
class GeoZarrHandler(MetadataMapper):
    """Assemble datasets into an ``xarray.DataTree`` and export it to Zarr.

    Each dataset ("layer") is validated, passed through the metadata
    update inherited from ``MetadataMapper``, and stored as a child node
    of ``data_tree``. Per-variable chunking and compression settings are
    collected in ``encoding`` and handed to ``DataTree.to_zarr`` on export.

    Attributes
    ----------
    ds_name : str
        Name of the layer supplied at construction.
    target_chunk_mb : float
        Approximate chunk size in megabytes.
    compressor : Blosc
        Compressor stored in the encoding of every data variable.
    zarr_format : int
        Zarr format version passed to ``to_zarr``.
    encoding : dict
        Mapping ``"/<layer name>" -> {variable name -> encoding dict}``.
    data_tree : xarray.DataTree
        Tree holding one child node per layer.
    """

    def __init__(
        self: GeoZarrHandler,
        ds: xr.Dataset,
        ds_name: str = "L1",
        target_chunk_mb: float = 5.0,
        compressor: Optional[Blosc] = None,
        metadata_mapping_file_path: Optional[str] = None,
        zarr_format: int = 2,
    ):
        """Initialise a GeoZarrHandler object.

        Parameters
        ----------
        ds : xarray.Dataset
            Input dataset whose dimensions are a subset of
            ('x', 'y', 't'). Must include coordinate variables.
        ds_name : str, default 'L1'
            Name of datacube.
        target_chunk_mb : float, default 5.0
            Approximate chunk size in megabytes for efficient storage.
        compressor : Blosc, default None
            Compressor to apply on arrays. If None, the compression will
            be Blosc with zstd.
        metadata_mapping_file_path : str, default None
            Path to the YAML file containing variable metadata mappings.
            If None, defaults to 'metadata_mapping.yaml' in the current
            directory.
        zarr_format : int, default 2
            Zarr format version to use (2 or 3).

        Raises
        ------
        ValueError
            If ``ds`` fails the checks in ``_validate_dataset``.
        """
        super().__init__(metadata_mapping_file_path=metadata_mapping_file_path)

        self.ds_name = ds_name
        self.target_chunk_mb = target_chunk_mb
        if compressor is None:
            compressor = Blosc(cname="zstd", clevel=3, shuffle=Blosc.BITSHUFFLE)
        self.compressor = compressor
        self.zarr_format = zarr_format

        ds = self._validate_dataset(ds)
        ds = self._update_metadata(ds, ds_name)

        self.encoding = {}
        self._define_encodings(ds, ds_name)

        # convert dataset to datatree
        self.data_tree = xr.DataTree.from_dict({ds_name: ds})

    def _validate_dataset(self: GeoZarrHandler, ds: xr.Dataset) -> xr.Dataset:
        """Validate the dimensions and coordinates of the input dataset.

        Parameters
        ----------
        ds : xarray.Dataset
            Dataset to check. Its dimensions must be a subset of
            ('x', 'y', 't') and each one needs a coordinate variable.

        Returns
        -------
        xarray.Dataset
            The same dataset, unchanged.

        Raises
        ------
        ValueError
            - If the dataset has a dimension other than 'x', 'y' or 't'.
            - If any dimension does not have an associated coordinate
              variable.
        """
        accepted_dims = {"x", "y", "t"}
        if not set(ds.dims).issubset(accepted_dims):
            raise ValueError(
                "Incorrect dataset dimensions."
                f" Accepted data dimensions are: {accepted_dims}"
            )

        for dim in ds.dims:
            if dim not in ds.coords:
                raise ValueError(
                    f"Coordinate variable for dimension '{dim}' is missing in "
                    "the dataset."
                )
        return ds

    def _calculate_chunk_sizes(
        self: GeoZarrHandler, var: xr.DataArray
    ) -> dict[str, int]:
        """Calculate chunk sizes for a variable to approximate the
        target chunk size in megabytes.

        When the variable has both 'x' and 'y', the two are chunked into
        squares sized so that a full 't' column fits the byte budget.
        Every other dimension (including 't') is kept whole.

        Parameters
        ----------
        var : xr.DataArray
            Data array whose dtype and dimensions are used to compute
            chunk sizes.

        Returns
        -------
        dict[str, int]
            Chunk length for every dimension of ``var``. All values are
            integers of at least 1.
        """
        target_bytes = self.target_chunk_mb * 1024 * 1024
        t_size = var.sizes.get("t", 1)  # Defaults to 1 if no 't' dimension
        chunk_sizes: dict[str, int] = {}

        if "x" in var.dims and "y" in var.dims:
            # Bytes taken by one (x, y) cell across the full 't' axis.
            # Clamped so an empty 't' axis cannot cause a division by zero.
            bytes_per_cell = max(var.dtype.itemsize * t_size, 1)
            # Number of (x, y) cells that fit in the budget; at least one,
            # otherwise a very long time series would give a chunk of 0.
            cells_per_chunk = max(target_bytes // bytes_per_cell, 1)
            # Square spatial chunks: side length is the root of the budget.
            side_length = int(np.sqrt(cells_per_chunk))
            chunk_sizes["x"] = max(min(var.sizes["x"], side_length), 1)
            chunk_sizes["y"] = max(min(var.sizes["y"], side_length), 1)

        # Remaining dimensions use their full length. For 't' this allows
        # more efficient loading, assuming the user is always interested in
        # the full time series. A lone 'x' or 'y' is also kept whole so
        # that no dimension is left without an integer chunk size.
        for dim in var.dims:
            if dim not in chunk_sizes:
                chunk_sizes[dim] = max(var.sizes[dim], 1)

        return chunk_sizes

    def _define_encodings(self: GeoZarrHandler, ds: xr.Dataset, ds_name: str) -> None:
        """Define encoding settings for each data variable in the
        dataset, including chunking and compression.

        Parameters
        ----------
        ds : xarray.Dataset
            Dataset whose data variables receive an encoding entry.
        ds_name : str
            Dataset name to be used for this node of the tree.

        Notes
        -----
        Chunk sizes are computed using `_calculate_chunk_sizes`, and the
        compressor is taken from ``self.compressor``. The encoding for
        ``ds_name`` is rebuilt from scratch on every call, so entries of
        a layer that is being replaced are not carried over.
        """
        group_encoding = {}
        for var in ds.data_vars:
            chunk_sizes = self._calculate_chunk_sizes(ds[var])
            group_encoding[var] = {
                "chunks": tuple(chunk_sizes[dim] for dim in ds[var].dims),
                # NOTE(review): "compressor" is the Zarr v2 encoding key;
                # confirm it is accepted when zarr_format is 3.
                "compressor": self.compressor,
            }
        # Keys are group paths ("/<name>").
        self.encoding[f"/{ds_name}"] = group_encoding

    def _update_metadata(self, ds: xr.Dataset, ds_name: str) -> xr.Dataset:
        """Update metadata to Climate and Forecast convention.

        Metadata is first updated using the ``update_metadata`` method.
        Each data variable with an 'x' or 'y' dimension is then tagged
        with ``grid_mapping = "spatial_ref"`` for spatial referencing.

        Parameters
        ----------
        ds : xarray.Dataset
            Dataset whose metadata is updated.
        ds_name : str
            Layer name for this node of the tree.

        Returns
        -------
        xarray.Dataset
            Dataset returned by ``update_metadata`` with the
            ``grid_mapping`` attributes set.
        """
        ds = self.update_metadata(ds, ds_name)

        for var in ds.data_vars:
            var_dims = ds[var].dims
            if "x" in var_dims or "y" in var_dims:
                ds[var].attrs["grid_mapping"] = "spatial_ref"
        return ds

    def export(
        self: GeoZarrHandler, storage_directory: str, overwrite: bool = True
    ) -> None:
        """Write the data tree to a Zarr store.

        Parameters
        ----------
        storage_directory : str
            Path to write the Zarr data.
        overwrite : bool, default True
            If True the store is written with mode ``"w"``, otherwise
            with mode ``"a"``.

        Raises
        ------
        FileNotFoundError
            If the parent directory of ``storage_directory`` does not
            exist.
        """
        # Only the parent is checked; the store itself is created by to_zarr.
        dir_path = Path(storage_directory).parent
        if not dir_path.exists():
            raise FileNotFoundError(
                f"Base directory of 'storage_directory' does not exist: {dir_path}"
            )

        self.data_tree.to_zarr(
            storage_directory,
            mode="w" if overwrite else "a",
            consolidated=True,
            zarr_format=self.zarr_format,
            encoding=self.encoding,
        )

    def add_layer(
        self: GeoZarrHandler, ds: xr.Dataset, ds_name: str, overwrite: bool = False
    ) -> None:
        """Add a new dataset as a child group of the DataTree at the root.

        Parameters
        ----------
        ds : xarray.Dataset
            New dataset layer to be added to the existing data tree.
        ds_name : str
            Layer name to be used for this node of the tree.
        overwrite : bool
            If True, allow a layer of the same name to be overwritten.

        Raises
        ------
        ValueError
            If a layer called ``ds_name`` already exists and
            ``overwrite`` is False, or if ``ds`` fails the checks in
            ``_validate_dataset``.
        """
        if ds_name in self.data_tree.children and not overwrite:
            raise ValueError(f"Group '{ds_name}' already exists.")

        # prepare new dataset
        ds = self._validate_dataset(ds)
        ds = self._update_metadata(ds, ds_name)

        # Validate dataset attributes before touching any handler state, so
        # a failing validation cannot leave encodings for a layer that is
        # not in the tree.
        for var in ds.data_vars:
            attrs = ds[var].attrs.copy()
            # grid_mapping is added by _update_metadata and is removed
            # before the attributes are checked against the schema.
            attrs.pop("grid_mapping", None)
            self.METADATA_SCHEMA.validate(attrs)

        # append additional encodings to the encodings class attribute
        self._define_encodings(ds, ds_name)

        self.data_tree[ds_name] = xr.DataTree(dataset=ds)

    def get_layer(self: GeoZarrHandler, ds_name: str) -> xr.Dataset:
        """Get a dataset from a DataTree.

        Parameters
        ----------
        ds_name : str
            Layer name.

        Returns
        -------
        xr.Dataset
            Dataset layer in tree.

        Raises
        ------
        KeyError
            If the layer name is not present in the data tree.
        AttributeError
            If the layer does not contain a dataset.
        """
        try:
            layer = self.data_tree[ds_name].ds
        except KeyError as err:
            # Chain the original error so the failing lookup stays visible.
            raise KeyError(f"{ds_name} layer not found in the data tree.") from err
        return layer