# Source code for watershed_workflow.sources.manager_shapefile
"""Basic manager for interacting with shapefiles.
"""
from typing import Optional, List
import os
import pyogrio
import geopandas as gpd
from shapely.geometry.base import BaseGeometry
import watershed_workflow.utils
import watershed_workflow.crs
import watershed_workflow.warp
from watershed_workflow.crs import CRS
from . import manager_shapes
from . import standard_names as names
# [docs]
class ManagerShapefile(manager_shapes.ManagerShapes):
    """A simple manager for reading shapes from a local shapefile.

    Parameters
    ----------
    filename : str
        Path to the shapefile.
    url : str, optional
        URL from which to download the file if it is not present locally.
    id_name : str, optional
        Name of the ID field in the shapefile.  If not provided, shapes
        are addressed by their integer row index.
    """
    def __init__(self,
                 filename: str,
                 url: Optional[str] = None,
                 id_name: Optional[str] = None):
        """Initialize shapefile manager.

        Parameters
        ----------
        filename : str
            Path to the shapefile.
        url : str, optional
            URL from which to download the file.
        id_name : str, optional
            Name of the ID field in the shapefile.
        """
        self.filename = filename
        self.url = url
        self.id_name = id_name

        # flag to indicate that we have the file and we have processed
        # it for metadata (native CRS and a resolution estimate)
        self._file_preprocessed = False

        # Use basename of file as name
        name = f'shapefile: "{os.path.basename(filename)}"'

        # Use id_name or 'ID' as native_id_field
        native_id_field = id_name if id_name is not None else 'ID'

        # Provenance: prefer the download URL; otherwise use the absolute
        # local path so the record is complete.
        if url is not None:
            source = url
        else:
            source = os.path.abspath(filename)

        # Initialize base class
        super().__init__(name, source, None, None, native_id_field)

    def _prerequestDataset(self) -> None:
        """Download (if needed) and extract metadata before any request.

        This is done here and not in _request so that the native CRS and
        an estimated resolution are available for input geometry
        manipulation.
        """
        # first download if the file is missing and a URL was provided;
        # _download is assumed to be provided by the base class
        if not os.path.isfile(self.filename) and self.url is not None:
            self._download()

        if not self._file_preprocessed:
            # Get file info to determine native CRS.
            # NOTE(review): info['crs'] may be None for shapefiles lacking
            # a .prj file -- from_string would then fail; confirm upstream.
            info = pyogrio.read_info(self.filename)
            self.native_crs_in = watershed_workflow.crs.from_string(info['crs'])

            # Estimate resolution from bounds (simple heuristic):
            # 1/1000th of the smallest dimension of the total bounds.
            xmin, ymin, xmax, ymax = info['total_bounds']
            self.native_resolution = min(xmax - xmin, ymax - ymin) / 1000.0

            # only do this work once
            self._file_preprocessed = True

    def _getShapes(self) -> gpd.GeoDataFrame:
        """Read the file and get all shapes.

        Returns
        -------
        gpd.GeoDataFrame
            All shapes from the shapefile.
        """
        return gpd.read_file(self.filename)

    def _getShapesByGeometry(self, geometry_gdf: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
        """Fetch shapes for the given geometry.

        Parameters
        ----------
        geometry_gdf : gpd.GeoDataFrame
            GeoDataFrame with geometries in native_crs_in to search for shapes.

        Returns
        -------
        gpd.GeoDataFrame
            Raw GeoDataFrame with native column names and CRS properly set.
        """
        # Use bbox filtering only -- the exact intersection test is
        # handled by the base class.
        union_geometry = geometry_gdf.union_all()
        return gpd.read_file(self.filename, bbox=union_geometry.bounds)

    def _getShapesByID(self, ids: List[str]) -> gpd.GeoDataFrame:
        """Fetch shapes by ID list.

        Parameters
        ----------
        ids : List[str]
            List of IDs to retrieve.  Matched against the id_name field
            if one was provided, otherwise interpreted as row indices.

        Returns
        -------
        gpd.GeoDataFrame
            Raw GeoDataFrame with native column names and CRS properly set.

        Raises
        ------
        ValueError
            If ids cannot be converted to the ID field's type (or to int
            for row-based access), if the ID field is missing, or if a
            row index is out of range.
        """
        if self.id_name is not None:
            # Read full file and filter by the specified ID field.
            df = gpd.read_file(self.filename)
            if self.id_name not in df.columns:
                raise ValueError(f"ID field '{self.id_name}' not found in shapefile columns: {list(df.columns)}")
            id_column = df[self.id_name]
            if len(id_column) > 0:
                # Coerce the requested ids to the column's element type so
                # that, e.g., string ids match an integer column.
                target_type = type(id_column.iloc[0])
                try:
                    converted_ids = [target_type(id_val) for id_val in ids]
                except (ValueError, TypeError) as e:
                    raise ValueError(f"Cannot convert IDs {ids} to type {target_type} for field '{self.id_name}': {e}") from e
                # Filter outside the try so a filtering failure is not
                # misreported as a conversion failure.
                df = df[id_column.isin(converted_ids)]
        else:
            # No ID field specified -- interpret ids as row indices.
            try:
                int_ids = [int(id_val) for id_val in ids]
            except (ValueError, TypeError) as e:
                raise ValueError(f"Cannot convert IDs {ids} to integers for row-based access: {e}") from e

            # Validate indices first against the file's row count.
            info = pyogrio.read_info(self.filename)
            total_rows = info['features']
            valid_indices = [i for i in int_ids if 0 <= i < total_rows]
            if len(valid_indices) != len(int_ids):
                invalid_indices = [i for i in int_ids if i < 0 or i >= total_rows]
                raise ValueError(f"Invalid row indices {invalid_indices}. File has {total_rows} rows (0-{total_rows-1})")

            if len(valid_indices) == 1:
                # Optimize for the single-row case: read just that row.
                index = valid_indices[0]
                df = gpd.read_file(self.filename, rows=slice(index, index + 1))
            else:
                # Read full file and select specific rows.
                # Note: gpd.read_file(rows=list) is not supported, so we
                # read all and filter.
                df = gpd.read_file(self.filename)
                df = df.iloc[valid_indices]
        return df

    def _addStandardNames(self, df: gpd.GeoDataFrame) -> gpd.GeoDataFrame:
        """Convert native column names to standard names.

        Parameters
        ----------
        df : gpd.GeoDataFrame
            GeoDataFrame with native column names.

        Returns
        -------
        gpd.GeoDataFrame
            A new GeoDataFrame with standard column names added; the
            input frame is not modified.
        """
        # Work on a copy: the input may be an .iloc slice (see
        # _getShapesByID), and assigning a column to a slice both mutates
        # the caller's data and can trigger SettingWithCopyWarning.
        df = df.copy()

        # Map ID field if it exists, otherwise create row-based IDs.
        if self.id_name is not None and self.id_name in df.columns:
            df[names.ID] = df[self.id_name]
        else:
            # For row-based access or when the ID field doesn't exist,
            # use row positions as IDs.
            df[names.ID] = range(len(df))

        # No other standard name mappings for generic shapefiles.
        return df