Source code for rhg_compute_tools.io_tools.readers

import fsspec
import os
import re
import yaml
import tempfile
import rasterio
import numpy as np
import pandas as pd
import xarray as xr
import geopandas as gpd
from zarr.errors import GroupNotFoundError
from contextlib import contextmanager
import contextlib

try:
    from gcsfs.retry import HttpError as GCSFSHttpError
except ImportError:
    from gcsfs.utils import HttpError as GCSFSHttpError


def get_fs(fp) -> fsspec.AbstractFileSystem:
    """Infer an fsspec filesystem from the protocol prefix of ``fp``."""
    filesystem = "" if "://" not in fp else fp.split("://")[0]
    fs = fsspec.filesystem(
        filesystem,
        cache_timeout=-1,
        timeout=120,
        requests_timeout=120,
        read_timeout=120,
        conn_timeout=120,
    )
    return fs
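
# A minimal usage sketch for get_fs (the bucket name below is hypothetical):
#
#   >>> fs = get_fs("gs://my-bucket/data.zarr")  # doctest: +SKIP
#   >>> fs.ls("my-bucket")                       # doctest: +SKIP
#
# A path with no "://" prefix yields an empty protocol string, which fsspec
# resolves to its default (local) filesystem.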

@contextmanager
def report_if_not_found(fp):
    """Re-raise missing-file/-store errors as a FileNotFoundError naming ``fp``."""
    try:
        yield
    except (FileNotFoundError, GroupNotFoundError):
        raise FileNotFoundError(fp)
    except KeyError as e:
        # zarr stores opened with consolidated metadata surface a missing
        # store as a KeyError on ".zmetadata"
        if ".zmetadata" not in str(e.args[0]):
            raise
        raise FileNotFoundError(fp)
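
# Usage sketch: wrap a read so a missing zarr store surfaces as a
# FileNotFoundError naming the requested path (path below is hypothetical):
#
#   >>> with report_if_not_found("gs://my-bucket/missing.zarr"):  # doctest: +SKIP
#   ...     ds = xr.open_zarr(get_fs("gs://my-bucket/missing.zarr").get_mapper("gs://my-bucket/missing.zarr"))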

def read_zarr(
    fp,
    fs=None,
    mapper_kwargs=None,
    isel_dict: dict = None,
    load_data: bool = False,
    **kwargs,
) -> xr.Dataset:
    """Open a zarr store from a local or gs:// location"""
    if fs is None:
        fs = get_fs(fp)
    if mapper_kwargs is None:
        mapper_kwargs = {}
    with report_if_not_found(fp):
        mapper = fs.get_mapper(fp, **mapper_kwargs)
        ds = xr.open_zarr(mapper, **kwargs)
        if isel_dict is not None:
            assert isinstance(isel_dict, dict), "`isel_dict` should be a dict"
            assert all(
                x in ds.dims for x in isel_dict
            ), f"some keys of `isel_dict` are not in ds.dims: {list(isel_dict)} vs {list(ds.dims)}"
            ds = ds.isel(isel_dict)
        if load_data:
            ds = ds.load()
        return ds
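
# Usage sketch for read_zarr (store path and dimension name hypothetical):
#
#   >>> ds = read_zarr(
#   ...     "gs://my-bucket/data.zarr",
#   ...     isel_dict={"time": slice(0, 12)},
#   ...     load_data=False,
#   ... )  # doctest: +SKIP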

def load_netcdf(blob, fsspec_kwargs=None, *args, retries=5, **kwargs) -> xr.Dataset:
    """Read a netcdf file from a local or gs:// location and load it into memory"""
    if fsspec_kwargs is None:
        fsspec_kwargs = {}
    for i in range(retries + 1):
        try:
            if os.path.exists(blob):
                with xr.open_dataset(blob, *args, **kwargs) as ds:
                    return ds.load()
            elif "://" in str(blob):
                # download to a temporary local copy, then open and fully load
                # it before the copy is deleted
                with tempfile.NamedTemporaryFile(suffix=".nc") as tmp_fp:
                    fp = str(tmp_fp.name)
                    protocol = re.match(r"(\w+)://", blob, re.I).group(1)
                    fs = fsspec.filesystem(protocol, **fsspec_kwargs)
                    fs.get(blob, fp)
                    with xr.open_dataset(fp, *args, **kwargs) as ds:
                        return ds.load()
            else:
                raise ValueError(f"file protocol not recognized: {blob}")
        except (IOError, GCSFSHttpError) as e:
            if i >= retries:
                raise OSError(f"read aborted after {i} retry attempts: {e}")
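
# Usage sketch for load_netcdf; unlike read_netcdf below, the returned dataset
# is fully loaded into memory, so it remains usable after the temporary local
# copy is deleted (path hypothetical):
#
#   >>> ds = load_netcdf("gs://my-bucket/data.nc", retries=3)  # doctest: +SKIP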

def read_dataset(fp, engine=None, **kwargs) -> xr.Dataset:
    """Read a zarr store or netcdf file, inferring the engine from ``fp``"""
    if engine is None:
        if fp.endswith(".zarr"):
            engine = "zarr"
        elif fp.endswith(".nc") or fp.endswith(".nc4"):
            engine = "netcdf4"
        else:
            raise IOError(f"engine could not be auto-determined from fp: {fp}")
    if engine == "zarr":
        return read_zarr(fp, **kwargs)
    elif engine == "netcdf4":
        return load_netcdf(fp, **kwargs)
    else:
        raise IOError(
            f"engine not recognized: {engine}. Choose one of {{'zarr', 'netcdf4'}}."
        )
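
# Usage sketch: the engine is inferred from the extension, or can be forced
# (paths hypothetical):
#
#   >>> ds = read_dataset("gs://my-bucket/data.zarr")                  # doctest: +SKIP
#   >>> ds = read_dataset("gs://my-bucket/data.nc4", engine="netcdf4") # doctest: +SKIP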

def read_shapefile(fp):
    """Read a vector file from a local or gs:// location into a GeoDataFrame"""
    with fsspec.open(fp) as f:
        return gpd.read_file(f)
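
# Usage sketch (hypothetical path; e.g. a zipped shapefile, since gpd.read_file
# on a single open file handle cannot see .shp sidecar files):
#
#   >>> gdf = read_shapefile("gs://my-bucket/regions.zip")  # doctest: +SKIP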

@contextlib.contextmanager
def read_rasterio(blob, fsspec_kwargs=None, *args, retries=5, **kwargs) -> xr.DataArray:
    """Read a geotiff or raster file from a local or gs:// location"""
    if fsspec_kwargs is None:
        fsspec_kwargs = {}
    for i in range(retries + 1):
        try:
            if os.path.exists(blob):
                with xr.open_rasterio(blob, *args, **kwargs) as ds:
                    yield ds
                break
            elif "://" in str(blob):
                with tempfile.TemporaryDirectory() as tmpdir:
                    f = os.path.join(tmpdir, os.path.basename(blob))
                    fs = fsspec.filesystem(blob.split("://")[0], **fsspec_kwargs)
                    fs.get(blob.replace("/gcs/", "gs://"), f)
                    with rasterio.open(f) as data:
                        with xr.open_rasterio(data, *args, **kwargs) as ds:
                            yield ds
                break
            else:
                raise ValueError(f"file protocol not recognized: {blob}")
        except (IOError, GCSFSHttpError) as e:
            if i >= retries:
                raise OSError(f"read aborted after {i} retry attempts: {e}")
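
# Usage sketch: read_rasterio is a context manager, so the file handle (and any
# temporary local copy) is released on exit (path hypothetical):
#
#   >>> with read_rasterio("gs://my-bucket/tile.tif") as da:  # doctest: +SKIP
#   ...     arr = da.values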

@contextlib.contextmanager
def read_netcdf(blob, fsspec_kwargs=None, *args, retries=5, **kwargs) -> xr.Dataset:
    """Read a netcdf file from a local or gs:// location.

    Very similar to load_netcdf, but yields the open dataset without loading
    it into memory.
    """
    if fsspec_kwargs is None:
        fsspec_kwargs = {}
    for i in range(retries + 1):
        try:
            if os.path.exists(blob):
                with xr.open_dataset(blob, *args, **kwargs) as ds:
                    yield ds
                break
            elif "://" in str(blob):
                with tempfile.NamedTemporaryFile(suffix=".nc") as tmp_fp:
                    fp = str(tmp_fp.name)
                    protocol = re.match(r"(\w+)://", blob, re.I).group(1)
                    fs = fsspec.filesystem(protocol, **fsspec_kwargs)
                    fs.get(blob, fp)
                    with xr.open_dataset(fp, *args, **kwargs) as ds:
                        yield ds
                break
            else:
                raise ValueError(f"file protocol not recognized: {blob}")
        except (IOError, GCSFSHttpError) as e:
            if i >= retries:
                raise OSError(f"read aborted after {i} retry attempts: {e}")
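
# Usage sketch: any lazy values must be read inside the with block, because the
# temporary local copy is deleted on exit (path hypothetical):
#
#   >>> with read_netcdf("gs://my-bucket/data.nc") as ds:  # doctest: +SKIP
#   ...     ds = ds.load()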

def read_csv(blob, **fsspec_kwargs) -> pd.DataFrame:
    """Read a csv file from a local or gs:// location"""
    if "://" not in str(blob):
        fsspec_kwargs = {}
    return pd.read_csv(blob, **fsspec_kwargs)

def read_parquet(blob, **fsspec_kwargs) -> pd.DataFrame:
    """Read a parquet file from a local or gs:// location"""
    if "://" not in str(blob):
        fsspec_kwargs = {}
    return pd.read_parquet(blob, **fsspec_kwargs)

def read_dataframe(blob, **fsspec_kwargs) -> pd.DataFrame:
    """Read a CSV or parquet file from a local or gs:// location"""
    if blob.endswith(".csv") or blob.endswith(".txt"):
        return read_csv(blob, **fsspec_kwargs)
    elif blob.endswith(".parquet"):
        return read_parquet(blob, **fsspec_kwargs)
    else:
        parts = os.path.basename(blob).split(".")
        if len(parts) == 1:
            raise ValueError(
                "No extension could be inferred for file: {}".format(blob)
            )
        ext = ".".join(parts[1:])
        raise ValueError(
            "File type could not be inferred from extension: {}".format(ext)
        )
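
# Usage sketch: the reader is dispatched on the file extension (paths
# hypothetical):
#
#   >>> df = read_dataframe("gs://my-bucket/table.csv")      # doctest: +SKIP
#   >>> df = read_dataframe("gs://my-bucket/table.parquet")  # doctest: +SKIP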

def read_csvv_response(fp) -> xr.Dataset:
    """Parse a CSVV response-function file into an xarray Dataset"""
    with fsspec.open(fp, "r") as f:
        firstline = f.readline().strip()
        assert firstline == "---", firstline
        header, data = f.read().split("...\n")

    header = yaml.safe_load(header)

    fields = [
        "observations",
        "prednames",
        "covarnames",
        "obsnames",
        "gamma",
        "gammavcv",
        "residvcv",
    ]

    parsed_data = {}
    current_field = None
    current = []

    for line in data.split("\n"):
        if line.strip() in fields:
            # a new field header: flush the previous field's buffered lines
            if current_field is not None:
                if "names" in current_field:
                    parsed_data[current_field] = np.array(
                        [c.strip().rstrip(",").split(",") for c in current]
                    )
                else:
                    parsed_data[current_field] = np.loadtxt(current, delimiter=",")
            current = []
            current_field = line.strip()
        else:
            current.append(line.strip().rstrip(",").replace(", ", ","))

    # flush the final field
    if "names" in current_field:
        parsed_data[current_field] = np.array(
            [c.strip().rstrip(",").split(",") for c in current]
        )
    else:
        parsed_data[current_field] = np.loadtxt(current, delimiter=",")

    if "obsnames" not in parsed_data:
        parsed_data["obsnames"] = np.array(
            ["outcome"] * parsed_data["prednames"].size
        )

    X = pd.MultiIndex.from_arrays(
        [
            parsed_data["prednames"].flat,
            parsed_data["covarnames"].flat,
            parsed_data["obsnames"].flat,
        ],
        names=["predictor", "covariate", "outcome"],
    )

    Y = pd.MultiIndex.from_arrays(
        [
            parsed_data["prednames"].flat,
            parsed_data["covarnames"].flat,
            parsed_data["obsnames"].flat,
        ],
        names=["predictor_y", "covariate_y", "outcome_y"],
    )

    ds = xr.Dataset(
        {
            "gamma": xr.DataArray(parsed_data["gamma"], [X], ["X"]),
            "vcv": xr.DataArray(parsed_data["gammavcv"], [X, Y], ["X", "Y"]),
            "residvcv": xr.DataArray(parsed_data["residvcv"], [], []),
        },
        attrs={k: v for k, v in header.items() if not isinstance(v, dict)},
    )

    # flatten nested header dicts into prefixed attrs
    dict_keys = [k for k, v in header.items() if isinstance(v, dict)]
    for dk in dict_keys:
        ds.attrs.update({f"{dk}_{k}": v for k, v in header[dk].items()})

    return ds
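
# Usage sketch: the parsed CSVV becomes a Dataset with "gamma" indexed by a
# (predictor, covariate, outcome) MultiIndex and "vcv" indexed by that index
# crossed with itself (path and predictor name hypothetical):
#
#   >>> csvv = read_csvv_response("gs://my-bucket/response.csvv")  # doctest: +SKIP
#   >>> csvv["gamma"].sel(predictor="tas")                         # doctest: +SKIP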