Source code for watershed_workflow.sources.manager_shapefile
"""Basic manager for interacting with shapefiles.
"""
from typing import Optional, List
import os
import pyogrio
import geopandas as gpd
from shapely.geometry.base import BaseGeometry
import watershed_workflow.utils
import watershed_workflow.crs
import watershed_workflow.warp
from watershed_workflow.crs import CRS
from . import manager_shapes
from . import standard_names as names
[docs]
class ManagerShapefile(manager_shapes.ManagerShapes):
    """A simple class for reading shapefiles.

    Parameters
    ----------
    filename : str
      Path to the shapefile.
    url : str, optional
      URL from which to download the file if it is not present locally.
    id_name : str, optional
      Name of the ID field in the shapefile.  If not given, shapes are
      identified by their 0-based row number in the file.
    """

    def __init__(self,
                 filename: str,
                 url: Optional[str] = None,
                 id_name: Optional[str] = None):
        """Initialize shapefile manager.

        Parameters
        ----------
        filename : str
            Path to the shapefile.
        url : str, optional
            URL from which to download the file.
        id_name : str, optional
            Name of the ID field in the shapefile.
        """
        self.filename = filename
        self.url = url
        self.id_name = id_name
        # flag to indicate that we have the file and we have processed
        # it for metadata
        self._file_preprocessed = False

        # Use basename of file as name
        name = f'shapefile: "{os.path.basename(filename)}"'
        # Use id_name or 'ID' as native_id_field
        native_id_field = id_name if id_name is not None else 'ID'
        # Provenance: the URL when the file is downloaded, otherwise the
        # absolute local path.
        source = url if url is not None else os.path.abspath(filename)
        # NOTE(review): the two None arguments look like the native CRS and
        # resolution, which _prerequestDataset fills in lazily -- confirm
        # against ManagerShapes.__init__.
        super().__init__(name, source, None, None, native_id_field)

    def _prerequestDataset(self):
        """Make sure the file is available locally and read its metadata.

        Downloads the file if it is missing and a URL was given.  Then, once
        per instance, sets ``native_crs_in`` from the file's CRS and a
        heuristic ``native_resolution`` from the file's bounds.
        """
        # first download -- this is done here and not in _request so
        # that we can set the resolution and CRS for input geometry
        # manipulation.
        if not os.path.isfile(self.filename) and self.url is not None:
            self._download()

        # only do this work once
        if self._file_preprocessed:
            return

        # force_total_bounds: some drivers report total_bounds as None
        # unless explicitly asked to compute them.
        info = pyogrio.read_info(self.filename, force_total_bounds=True)
        # NOTE(review): info['crs'] is None for a file without a CRS (e.g. a
        # shapefile with no .prj); confirm how crs.from_string handles that.
        self.native_crs_in = watershed_workflow.crs.from_string(info['crs'])

        # Estimate resolution from bounds (simple heuristic): 1/1000th of
        # the smallest extent of the data.
        # NOTE(review): this is 0 when the data has zero width or height
        # (a single point, or an axis-aligned line).
        xmin, ymin, xmax, ymax = info['total_bounds']
        self.native_resolution = min(xmax - xmin, ymax - ymin) / 1000.0
        self._file_preprocessed = True

    def _getShapes(self) -> gpd.GeoDataFrame:
        """Read the file and get all shapes.

        Returns
        -------
        gpd.GeoDataFrame
            All shapes from the shapefile, indexed by row number.
        """
        return gpd.read_file(self.filename)

    def _getShapesByGeometry(self, geometry_gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
        """Fetch shapes for the given geometry.

        Parameters
        ----------
        geometry_gdf : gpd.GeoDataFrame
            GeoDataFrame with geometries in native_crs_in to search for shapes.

        Returns
        -------
        gpd.GeoDataFrame
            Raw GeoDataFrame with native column names and CRS properly set.
            Only shapes overlapping the bounding box of ``geometry_gdf`` are
            read.
        """
        # Use bbox filtering - full intersection handled by base class.
        # total_bounds is the same box as the bounds of the union of all
        # geometries, without computing the (possibly expensive) union.
        bbox = tuple(geometry_gdf.total_bounds)
        return gpd.read_file(self.filename, bbox=bbox)

    def _getShapesByID(self, ids: List[str]) -> gpd.GeoDataFrame:
        """Fetch shapes by ID list.

        If ``id_name`` was given, IDs are matched against that field.
        Otherwise they are interpreted as 0-based row numbers in the file.

        Parameters
        ----------
        ids : List[str]
            List of IDs to retrieve.

        Returns
        -------
        gpd.GeoDataFrame
            Raw GeoDataFrame with native column names and CRS properly set.

        Raises
        ------
        ValueError
            If the ID field is missing, if the IDs cannot be converted to the
            field's type (or to ints for row access), or if a row index is
            out of range.
        """
        if self.id_name is not None:
            return self._getShapesByIDField(ids)
        return self._getShapesByRowIndex(ids)

    def _getShapesByIDField(self, ids: List[str]) -> gpd.GeoDataFrame:
        """Select the shapes whose ``id_name`` field matches one of ``ids``."""
        # Read full file and filter by specified ID field
        df = gpd.read_file(self.filename)
        if self.id_name not in df.columns:
            raise ValueError(f"ID field '{self.id_name}' not found in shapefile columns: {list(df.columns)}")

        # Convert the requested IDs to the Python type of the stored values
        # so isin() compares like with like.  Sample a non-null value: a
        # leading null would otherwise give NoneType.
        present = df[self.id_name].dropna()
        if present.empty:
            # No non-null ID values, so no requested ID can match.
            return df.iloc[0:0]
        target_type = type(present.iloc[0])
        try:
            converted_ids = [target_type(id_val) for id_val in ids]
        except (ValueError, TypeError) as e:
            raise ValueError(f"Cannot convert IDs {ids} to type {target_type} for field '{self.id_name}': {e}") from e
        return df[df[self.id_name].isin(converted_ids)]

    def _getShapesByRowIndex(self, ids: List[str]) -> gpd.GeoDataFrame:
        """Select shapes by 0-based row number; the result is indexed by row number."""
        try:
            int_ids = [int(id_val) for id_val in ids]
        except (ValueError, TypeError) as e:
            raise ValueError(f"Cannot convert IDs {ids} to integers for row-based access: {e}") from e

        # Validate indices first.  force_feature_count makes drivers that
        # would otherwise report -1 (count unknown) actually count.
        info = pyogrio.read_info(self.filename, force_feature_count=True)
        total_rows = info['features']
        invalid_indices = [i for i in int_ids if not 0 <= i < total_rows]
        if invalid_indices:
            raise ValueError(f"Invalid row indices {invalid_indices}. File has {total_rows} rows (0-{total_rows-1})")

        if not int_ids:
            # Nothing requested: return an empty frame that still carries
            # the file's columns, without reading the whole file.
            return gpd.read_file(self.filename, rows=slice(0, 1)).iloc[0:0]

        # gpd.read_file(rows=list) is not supported, so read only the
        # contiguous span covering the requested rows and select from it.
        first, last = min(int_ids), max(int_ids)
        df = gpd.read_file(self.filename, rows=slice(first, last + 1))
        # A sliced read is indexed from 0; shift it so the labels are file
        # row numbers, as they would be in a full read of the file.
        df.index = df.index + first
        return df.iloc[[i - first for i in int_ids]]

    def _addStandardNames(self, df: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
        """Convert native column names to standard names.

        Parameters
        ----------
        df : gpd.GeoDataFrame
            GeoDataFrame with native column names.

        Returns
        -------
        gpd.GeoDataFrame
            The same GeoDataFrame (modified in place) with standard column
            names added.
        """
        # Map ID field if it exists, otherwise create row-based IDs
        if self.id_name is not None and self.id_name in df.columns:
            df[names.ID] = df[self.id_name]
        else:
            # Use the index labels as IDs.  Full reads (_getShapes) and
            # row-based lookups (_getShapesByID) are indexed by file row
            # number, so these IDs agree with what _getShapesByID accepts.
            # NOTE(review): assumes the base class does not re-index before
            # calling this; bbox-filtered reads are indexed from 0, so there
            # the IDs are positions within the filtered result.
            df[names.ID] = df.index.to_numpy()

        # No other standard name mappings for generic shapefiles
        return df