Source code for rhg_compute_tools.io_tools.writers

import os
from pathlib import Path
from typing import Union

import dask
import fsspec
import git
import pandas as pd
import xarray as xr
import zarr

import rhg_compute_tools.xarray
from rhg_compute_tools.io_tools.readers import get_fs


def _xr_document_repo_state(ds: xr.Dataset, repo_root=".") -> xr.Dataset:
    """Attach the git state of the repo containing ``repo_root`` to ``ds``'s attrs."""
    repo = git.Repo(repo_root, search_parent_directories=True)
    repo_dir = os.path.dirname(repo.git_dir)

    rhg_compute_tools.xarray.document_dataset(ds, repo_dir)

    return ds


def document_dataset(ds: xr.Dataset, repo_root: str = ".") -> xr.Dataset:
    """Record the git state of ``repo_root`` and a UTC timestamp on ``ds``'s attrs."""
    ds = _xr_document_repo_state(ds, repo_root=repo_root)
    ds.attrs["updated"] = pd.Timestamp.now(tz="UTC").strftime("%c (%Z)")
    return ds
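
# Example usage (a sketch; assumes this is run from inside a git repository
# and that ``ds`` is an existing xr.Dataset -- both are illustrative):
#
#   ds = document_dataset(ds, repo_root=".")
#   print(ds.attrs["updated"])  # e.g. "Mon Jan  1 00:00:00 2024 (UTC)"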

def get_maximal_chunks_encoding(ds: xr.Dataset, **var_chunks) -> dict:
    """Build a zarr encoding that chunks each array by its largest dask chunk.

    Unchunked arrays get a single chunk spanning each full dimension (``-1``);
    for unchunked data variables, per-dimension sizes may be overridden via
    ``var_chunks`` keyword arguments.
    """
    encoding_kwargs = {"encoding": {}}
    for c in ds.coords.keys():
        if ds.coords[c].chunks is None:
            encoding_kwargs["encoding"][c] = {
                "chunks": tuple(-1 for _ in ds.coords[c].dims)
            }
        else:
            encoding_kwargs["encoding"][c] = {
                "chunks": tuple(max(sizes) for sizes in ds.coords[c].chunks)
            }
    for v in ds.data_vars.keys():
        if ds[v].chunks is None:
            encoding_kwargs["encoding"][v] = {
                "chunks": tuple(var_chunks.get(d, -1) for d in ds[v].dims)
            }
        else:
            encoding_kwargs["encoding"][v] = {
                "chunks": tuple(max(sizes) for sizes in ds[v].chunks)
            }
    return encoding_kwargs
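
# Example (a sketch with illustrative names): for a dataset with one
# dask-backed variable chunked along "time", the returned mapping can be
# splatted straight into ``to_zarr``:
#
#   import numpy as np
#
#   demo = xr.Dataset(
#       {"tas": (("time", "lat"), np.zeros((10, 4)))},
#       coords={"time": np.arange(10), "lat": np.arange(4)},
#   ).chunk({"time": 5})
#   encoding = get_maximal_chunks_encoding(demo)
#   # -> "tas" is stored as (5, 4) chunks; the unchunked coords get (-1,)
#   demo.to_zarr("demo.zarr", **encoding)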

def write_zarr(
    out_ds: xr.Dataset,
    out_fp: str,
    fs: Union[fsspec.AbstractFileSystem, None] = None,
    set_maximal_chunks: bool = True,
    writer_kwargs: Union[dict, None] = None,
    encoding_kwargs: Union[dict, None] = None,
) -> None:
    """Write ``out_ds`` to a zarr store at ``out_fp``, retrying failed chunks."""
    if fs is None:
        fs = get_fs(out_fp)
    if writer_kwargs is None:
        writer_kwargs = {}
    assert isinstance(out_ds, xr.Dataset), (
        "Do not write a DataArray. Instead use da.to_dataset(name='variable_name') "
        "to convert to a Dataset, and then assign metadata prior to writing"
    )
    mapper = fs.get_mapper(out_fp)
    if encoding_kwargs is None:
        if set_maximal_chunks:
            encoding_kwargs = get_maximal_chunks_encoding(out_ds)
        else:
            encoding_kwargs = {}
    # zarr cannot serialize object-dtype arrays; cast them to unicode first
    for v in list(out_ds.coords.keys()):
        if out_ds.coords[v].dtype == object and v != "time":
            out_ds.coords[v] = out_ds.coords[v].astype("unicode")
    for v in list(out_ds.variables.keys()):
        if out_ds[v].dtype == object and v != "time":
            out_ds[v] = out_ds[v].astype("unicode")
    try:
        futures = out_ds.to_zarr(
            mapper, compute=False, **writer_kwargs, **encoding_kwargs
        )
        dask.compute(futures, retries=3)
    except zarr.errors.ContainsGroupError:
        # re-raise with the destination path so the offending store is named
        raise zarr.errors.ContainsGroupError(out_fp)
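
# Example usage (a sketch; the dataset and bucket path are illustrative, and
# a dask scheduler is assumed to be available for the distributed write):
#
#   ds = document_dataset(ds)
#   write_zarr(ds, "gs://my-bucket/outputs/ds.zarr")
#
# Re-running against an existing store raises zarr.errors.ContainsGroupError
# naming the destination path.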

def write_netcdf(
    out_ds: xr.Dataset,
    out_fp: str,
    fs: Union[fsspec.AbstractFileSystem, None] = None,
) -> None:
    """Write ``out_ds`` to a netCDF file at ``out_fp``, retrying failed chunks.

    ``gs://`` paths are rewritten to a ``/gcs/`` FUSE mount before writing;
    ``fs`` is accepted for signature parity with :func:`write_zarr` but is
    currently unused.
    """
    assert isinstance(out_ds, xr.Dataset), (
        "For consistency, do not write a DataArray. Instead use "
        "da.to_dataset(name='variable_name') to convert to a Dataset, and then "
        "assign metadata prior to writing"
    )
    # netCDF requires the parent directory to exist before writing
    parent_dir = "/".join(out_fp.split("/")[:-1]).replace("gs://", "/gcs/")
    os.makedirs(parent_dir, exist_ok=True)
    futures = out_ds.to_netcdf(out_fp.replace("gs://", "/gcs/"), compute=False)
    dask.compute(futures, retries=3)
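
# Example usage (a sketch; the path is illustrative and assumes a GCS FUSE
# mount at /gcs, which is where gs:// prefixes are redirected):
#
#   write_netcdf(ds, "gs://my-bucket/outputs/ds.nc")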