Source code for watershed_workflow.sources.manager_raster

"""Basic manager for interacting with raster files.
"""
from typing import Tuple, List, Optional, Iterable

import os
import xarray as xr
import shapely
import rioxarray
import cftime
import logging

import watershed_workflow.crs
from watershed_workflow.crs import CRS

from . import manager_dataset


class ManagerRaster(manager_dataset.ManagerDataset):
    """A simple class for reading rasters.

    Bands are exposed as variables.  Names of the form ``band_<N>`` refer to
    band ``N`` counted from 1.  Explicit user-supplied band names are mapped,
    in order, onto the raster's bands.
    """

    def __init__(self,
                 filename : str,
                 url : Optional[str] = None,
                 native_resolution : Optional[float] = None,
                 native_crs : Optional[CRS] = None,
                 bands : Optional[Iterable[str] | int] = None,
                 ):
        """Initialize raster manager.

        Parameters
        ----------
        filename : str
            Path to the raster file.
        url : str, optional
            If given and ``filename`` does not exist, the file is downloaded
            from this URL the first time the dataset is requested.
        native_resolution : float, optional
            Initial native resolution; overwritten by the value read from the
            file when it is first inspected.
        native_crs : CRS, optional
            Initial native CRS; overwritten by the CRS read from the file when
            it is first inspected.
        bands : int or iterable of str, optional
            Either the number of bands (variables are then named ``band_1``
            ... ``band_N``) or explicit band names, mapped in order onto the
            raster's bands.  If not given, band names are discovered from the
            file.

        Raises
        ------
        ValueError
            If ``bands`` describes zero bands.
        """
        self.filename = filename
        self.url = url

        # a flag to only preprocess (inspect) the file once
        self._file_preprocessed = False

        # Use basename of file as name
        name = f'raster: "{os.path.basename(filename)}"'

        # Use absolute path as source for complete provenance
        source = os.path.abspath(filename)

        if bands is None:
            # discovered from the file in _prerequestDataset
            valid_variables = None
            default_variables = None
        else:
            if isinstance(bands, int):
                # 1-based names, consistent with the band coordinate read from
                # the file and with the parsing in _bandIndex.
                valid_variables = [f'band_{i}' for i in range(1, bands + 1)]
            elif isinstance(bands, str):
                # a lone string is one band name, not an iterable of characters
                valid_variables = [bands,]
            else:
                # materialize so that generators/sets can be indexed
                valid_variables = list(bands)

            if len(valid_variables) == 0:
                raise ValueError(f'ManagerRaster: "bands" must describe at least one band, got {bands!r}')
            default_variables = [valid_variables[0],]

        # Initialize base class
        super().__init__(
            name, source,
            native_resolution, native_crs, native_crs,
            None, None,
            valid_variables, default_variables
        )

    def _prerequestDataset(self) -> None:
        """Ensure the raster file exists, then inspect it once for native properties.

        Downloads the file if it is missing and a URL is known.  On the first
        call it reads the native CRS and resolution and, if band variables
        were not supplied at construction, derives ``band_<N>`` names.
        """
        # first download -- this is done here and not in _request so
        # that we can set the resolution and CRS for input geometry
        # manipulation.
        if not os.path.isfile(self.filename) and self.url is not None:
            self._download()

        if self._file_preprocessed:
            return

        # Inspect raster to get native properties
        with rioxarray.open_rasterio(self.filename) as temp_ds:
            # Get native CRS
            # NOTE(review): this stores a rasterio CRS, while __init__ takes a
            # watershed_workflow CRS -- confirm the base class accepts both.
            self.native_crs_in = temp_ds.rio.crs
            self.native_crs_out = temp_ds.rio.crs

            # Get native resolution (approximated from the first pixel spacing)
            x_coords = temp_ds.coords['x']
            y_coords = temp_ds.coords['y']
            if len(x_coords) > 1 and len(y_coords) > 1:
                x_res = abs(float(x_coords[1] - x_coords[0]))
                y_res = abs(float(y_coords[1] - y_coords[0]))
                self.native_resolution = max(x_res, y_res)
            else:
                self.native_resolution = 1.0  # fallback

            # Create variable names for each band
            if self.valid_variables is None:
                if 'band' in temp_ds.coords:
                    # pull names from the band coordinate values
                    self.valid_variables = [f'band_{i}' for i in temp_ds.band.values]
                elif temp_ds.ndim == 3:
                    # no band coordinate: number bands from 1 so that the
                    # names agree with _bandIndex (band_N -> index N-1).
                    # Uses .shape, not .values, to avoid loading the raster.
                    num_bands = temp_ds.shape[0]
                    self.valid_variables = [f'band_{i}' for i in range(1, num_bands + 1)]

                # First band as default
                if self.valid_variables:
                    self.default_variables = [self.valid_variables[0],]

        # only do this work once
        self._file_preprocessed = True

    def _requestDataset(self, request : manager_dataset.ManagerDataset.Request
                        ) -> manager_dataset.ManagerDataset.Request:
        """Request the data -- a local raster is ready upon request."""
        request.is_ready = True
        return request

    def _bandIndex(self, var : str) -> int:
        """Convert a variable name into a 0-based band index.

        ``band_<N>`` maps to ``N - 1``.  Any other name is looked up by its
        position in ``self.valid_variables``, which supports the user-supplied
        band names accepted by ``__init__``.

        Raises
        ------
        ValueError
            If the name is neither ``band_<N>`` nor a known variable.
        """
        prefix, _, number = var.partition('_')
        if prefix == 'band' and number.isdigit():
            return int(number) - 1
        if self.valid_variables is not None and var in self.valid_variables:
            return list(self.valid_variables).index(var)
        raise ValueError(f'Unrecognized raster variable "{var}"; expected "band_<N>" '
                         f'or one of {self.valid_variables}')

    def _fetchDataset(self, request : manager_dataset.ManagerDataset.Request) -> xr.Dataset:
        """Fetch the data, clipped to the bounding box of the request geometry.

        If ``request.variables`` is None, the first band is returned as the
        variable ``'raster'``.  Otherwise each requested variable becomes one
        variable of the returned Dataset.

        Raises
        ------
        ValueError
            If a requested variable names a band the raster does not have.
        """
        bounds = request.geometry.bounds

        # Open raster
        if not self.filename.lower().endswith('.tif'):
            # the rioxarray keyword is ``chunks``; ``chunk`` was forwarded to
            # rasterio/GDAL as an unsupported open option and ignored.
            dataset = rioxarray.open_rasterio(self.filename, chunks='auto')
        else:
            dataset = rioxarray.open_rasterio(self.filename, cache=False)

        # Clip to bounds
        dataset = dataset.rio.clip_box(*bounds,
                                       crs=watershed_workflow.crs.to_rasterio(self.native_crs_out))

        # Convert to Dataset with band variables
        result_dataset = xr.Dataset()
        if request.variables is None:
            # single-variable case
            if dataset.ndim > 2:
                result_dataset['raster'] = dataset[0, :, :]  # Take first band
            else:
                result_dataset['raster'] = dataset
            return result_dataset

        # a 2-D raster is a single band
        num_bands = dataset.shape[0] if dataset.ndim > 2 else 1
        for var in request.variables:
            band_idx = self._bandIndex(var)
            # reject negative indices too, which would otherwise silently wrap
            # around to the last band
            if not 0 <= band_idx < num_bands:
                raise ValueError(f"Band {band_idx + 1} not available in raster")

            if dataset.ndim > 2:
                band_data = dataset[band_idx, :, :]
                result_dataset[var] = band_data.drop_vars('band', errors='ignore')
            else:
                # Single band raster
                result_dataset[var] = dataset
        return result_dataset

    def _download(self, force : bool = False):
        """Download the raster from ``self.url`` to ``self.filename``.

        Nothing is downloaded if the file already exists, unless ``force`` is
        True.
        """
        # source_utils is not provided by the module-level imports
        from . import utils as source_utils

        dirname = os.path.dirname(self.filename)
        if dirname:
            # a bare filename has no directory, and os.makedirs('') raises
            os.makedirs(dirname, exist_ok=True)
        if force or not os.path.exists(self.filename):
            source_utils.download(self.url, self.filename, force)