Source code for watershed_workflow.sources.manager_raster
"""Basic manager for interacting with raster files.
"""
from typing import Tuple, List, Optional, Iterable
import os
import xarray as xr
import shapely
import rioxarray
import cftime
import logging
import watershed_workflow.crs
from watershed_workflow.crs import CRS
from . import manager_dataset
[docs]
class ManagerRaster(manager_dataset.ManagerDataset):
    """A simple class for reading rasters."""
    def __init__(self,
                 filename : str,
                 url : Optional[str] = None,
                 native_resolution : Optional[float] = None,
                 native_crs : Optional[CRS] = None,
                 bands : Optional[Iterable[str] | int] = None,
                 ):
        """Initialize raster manager.
        
        Parameters
        ----------
        filename : str
            Path to the raster file.
        """
        self.filename = filename
        self.url = url
        # a flag to only preprocess once
        self._file_preprocessed = False
        
        # Use basename of file as name
        name = f'raster: "{os.path.basename(filename)}"'
        
        # Use absolute path as source for complete provenance
        source = os.path.abspath(filename)
        if bands is None:
            valid_variables = None
            default_variables = None
        elif isinstance(bands, int):
            valid_variables = [f'band {i+1}' for i in range(num_bands)]
            default_variables = [valid_variables[0],]
        else:
            valid_variables = bands
            default_variables = [valid_variables[0],]
        # Initialize base class
        super().__init__(
            name, source,
            native_resolution, native_crs, native_crs,
            None, None, valid_variables, default_variables
        )
    def _prerequestDataset(self) -> None:
        # first download -- this is done here and not in _request so
        # that we can set the resolution and CRS for input geometry
        # manipulation.
        if not os.path.isfile(self.filename) and self.url is not None:
            self._download()
        if not self._file_preprocessed:
            # Inspect raster to get native properties
            with rioxarray.open_rasterio(self.filename) as temp_ds:
                # Get native CRS
                self.native_crs_in = temp_ds.rio.crs
                self.native_crs_out = temp_ds.rio.crs
                # Get native resolution (approximate from first pixel)
                if len(temp_ds.coords['x']) > 1 and len(temp_ds.coords['y']) > 1:
                    x_res = abs(float(temp_ds.coords['x'][1] - temp_ds.coords['x'][0]))
                    y_res = abs(float(temp_ds.coords['y'][1] - temp_ds.coords['y'][0]))
                    self.native_resolution = max(x_res, y_res)
                else:
                    self.native_resolution = 1.0  # fallback
                # Create variable names for each band
                if self.valid_variables is None:
                    if hasattr(temp_ds, 'band'):
                        # pull from bands
                        self.valid_variables = [f'band_{i}' for i in temp_ds.band.values]
                        # First band as default
                        self.default_variables = [self.valid_variables[0],]
                    elif len(d.values.shape) == 3:
                        num_bands = d.values.shape[0]
                        self.valid_variables = [f'band_{i}' for i in range(num_bands)]
                        self.default_variables = [self.valid_variables[0],]
            # only do this work once
            self._file_preprocessed = True
    def _requestDataset(self, request : manager_dataset.ManagerDataset.Request
                        ) -> manager_dataset.ManagerDataset.Request:
        """Request the data -- ready upon request."""
        request.is_ready = True
        return request
    def _fetchDataset(self, request : manager_dataset.ManagerDataset.Request) -> xr.Dataset:
        """Fetch the data."""
        bounds = request.geometry.bounds
        
        # Open raster and clip to bounds
        if not self.filename.lower().endswith('.tif'):
            dataset = rioxarray.open_rasterio(self.filename, chunk='auto')
        else:
            dataset = rioxarray.open_rasterio(self.filename, cache=False)
            
        # Clip to bounds
        dataset = dataset.rio.clip_box(*bounds, crs=watershed_workflow.crs.to_rasterio(self.native_crs_out))
        
        # Convert to Dataset with band variables
        result_dataset = xr.Dataset()
        
        if request.variables is None:
            # single-variable case
            if len(dataset.shape) > 2:
                result_dataset['raster'] = dataset[0, :, :]  # Take first band
            else:
                result_dataset['raster'] = dataset
        else:
            for var in request.variables:
                assert var.startswith('band_')
                band_idx = int(var.split('_')[1]) - 1  # Convert to 0-indexed
                if len(dataset.shape) > 2 and band_idx < dataset.shape[0]:
                    band_data = dataset[band_idx, :, :]
                    band_data = band_data.drop_vars('band', errors='ignore')
                    result_dataset[var] = band_data
                elif len(dataset.shape) == 2:  # Single band raster
                    if band_idx == 0:
                        result_dataset[var] = dataset
                else:
                    raise ValueError(f"Band {band_idx + 1} not available in raster")
        return result_dataset
  
    def _download(self, force : bool = False):
        """A default download implementation."""
        os.makedirs(os.path.dirname(self.filename), exist_ok=True)
        if not os.path.exists(self.filename) or force:
            source_utils.download(self.url, self.filename, force)