Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ and the bounding box area the shapefile covers:
>>> len(sf)
663
>>> sf.bbox
[-122.515048, 37.652916, -122.327622, 37.863433]
(-122.515048, 37.652916, -122.327622, 37.863433)

Finally, if you would prefer to work with the entire shapefile in a different
format, you can convert all of it to a GeoJSON dictionary, although you may lose
Expand Down
114 changes: 56 additions & 58 deletions src/shapefile.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from typing import (
IO,
Any,
Collection,
Container,
Generic,
Iterable,
Expand All @@ -34,7 +33,6 @@
Protocol,
Reversible,
Sequence,
TypedDict,
TypeVar,
Union,
overload,
Expand Down Expand Up @@ -109,12 +107,16 @@

T = TypeVar("T")
Point2D = tuple[float, float]
PointZ = tuple[float, float, float]
PointZM = tuple[float, float, float, float]
Point3D = tuple[float, float, float]
PointM = tuple[float, float, Optional[float]]
PointZ = tuple[float, float, float, Optional[float]]

Coord = Union[Point2D, PointZ, PointZM]
Coord = Union[Point2D, Point2D, Point3D]
Coords = list[Coord]

Point = Union[Point2D, PointM, PointZ]
Points = list[Point]

BBox = tuple[float, float, float, float]


Expand All @@ -131,19 +133,12 @@ def tell(self): ...
BinaryFileT = Union[str, IO[bytes]]
BinaryFileStreamT = Union[IO[bytes], io.BytesIO, BinaryWritableSeekable]

FieldTuple = tuple[str, str, int, bool]
FieldTuple = tuple[str, str, int, int]
RecordValue = Union[
bool, int, float, str, date
] # A Possible value in a Shapefile record, e.g. L, N, F, C, D types


class GeoJsonShapeT(TypedDict):
type: str
coordinates: Union[
tuple[()], Point2D, PointZ, PointZM, Coords, list[Coords], list[list[Coords]]
]


class HasGeoInterface(Protocol):
@property
def __geo_interface__(self) -> Any: ...
Expand Down Expand Up @@ -242,7 +237,7 @@ def is_cw(coords: Coords) -> bool:
return area2 < 0


def rewind(coords: Reversible[Coord]) -> list[Coord]:
def rewind(coords: Reversible[Coord]) -> Coords:
"""Returns the input coords in reversed order."""
return list(reversed(coords))

Expand All @@ -254,23 +249,23 @@ def ring_bbox(coords: Coords) -> BBox:
return bbox


def bbox_overlap(bbox1: BBox, bbox2: Collection[float]) -> bool:
def bbox_overlap(bbox1: BBox, bbox2: BBox) -> bool:
"""Tests whether two bounding boxes overlap."""
xmin1, ymin1, xmax1, ymax1 = bbox1
xmin2, ymin2, xmax2, ymax2 = bbox2
overlap = xmin1 <= xmax2 and xmax1 >= xmin2 and ymin1 <= ymax2 and ymax1 >= ymin2
overlap = xmin1 <= xmax2 and xmin2 <= xmax1 and ymin1 <= ymax2 and ymin2 <= ymax1
return overlap


def bbox_contains(bbox1: BBox, bbox2: BBox) -> bool:
"""Tests whether bbox1 fully contains bbox2."""
xmin1, ymin1, xmax1, ymax1 = bbox1
xmin2, ymin2, xmax2, ymax2 = bbox2
contains = xmin1 < xmin2 and xmax1 > xmax2 and ymin1 < ymin2 and ymax1 > ymax2
contains = xmin1 < xmin2 and xmax2 < xmax1 and ymin1 < ymin2 and ymax2 < ymax1
return contains


def ring_contains_point(coords: list[Coord], p: Point2D) -> bool:
def ring_contains_point(coords: Coords, p: Point2D) -> bool:
"""Fast point-in-polygon crossings algorithm, MacMartin optimization.

Adapted from code by Eric Haynes
Expand Down Expand Up @@ -319,7 +314,7 @@ class RingSamplingError(Exception):
pass


def ring_sample(coords: list[Coord], ccw: bool = False) -> Point2D:
def ring_sample(coords: Coords, ccw: bool = False) -> Point2D:
"""Return a sample point guaranteed to be within a ring, by efficiently
finding the first centroid of a coordinate triplet whose orientation
matches the orientation of the ring and passes the point-in-ring test.
Expand Down Expand Up @@ -369,14 +364,14 @@ def itercoords():
)


def ring_contains_ring(coords1: list[Coord], coords2: list[Point2D]) -> bool:
def ring_contains_ring(coords1: Coords, coords2: list[Point2D]) -> bool:
"""Returns True if all vertexes in coords2 are fully inside coords1."""
return all(ring_contains_point(coords1, p2) for p2 in coords2)


def organize_polygon_rings(
rings: Iterable[list[Coord]], return_errors: Optional[dict[str, int]] = None
) -> list[list[list[Coord]]]:
rings: Iterable[Coords], return_errors: Optional[dict[str, int]] = None
) -> list[list[Coords]]:
"""Organize a list of coordinate rings into one or more polygons with holes.
Returns a list of polygons, where each polygon is composed of a single exterior
ring, and one or more interior holes. If a return_errors dict is provided (optional),
Expand Down Expand Up @@ -510,7 +505,7 @@ class Shape:
def __init__(
self,
shapeType: int = NULL,
points: Optional[list[Coord]] = None,
points: Optional[Points] = None,
parts: Optional[Sequence[int]] = None,
partTypes: Optional[Sequence[int]] = None,
oid: Optional[int] = None,
Expand Down Expand Up @@ -546,7 +541,7 @@ def __init__(
# self.bbox: Optional[_Array[float]] = None

@property
def __geo_interface__(self) -> GeoJsonShapeT:
def __geo_interface__(self):
if self.shapeType in [POINT, POINTM, POINTZ]:
# point
if len(self.points) == 0:
Expand Down Expand Up @@ -1434,7 +1429,7 @@ def __shpHeader(self):
shp.seek(32)
self.shapeType = unpack("<i", shp.read(4))[0]
# The shapefile's bounding box (lower left, upper right)
self.bbox = _Array("d", unpack("<4d", shp.read(32)))
self.bbox: BBox = tuple(_Array("d", unpack("<4d", shp.read(32))))
# Elevation
self.zbox = _Array("d", unpack("<2d", shp.read(16)))
# Measure
Expand Down Expand Up @@ -1469,7 +1464,7 @@ def __shape(
record.points = []
# All shape types capable of having a bounding box
elif shapeType in (3, 13, 23, 5, 15, 25, 8, 18, 28, 31):
record.bbox = _Array[float]("d", unpack("<4d", f.read(32))) # type: ignore [attr-defined]
record.bbox = tuple(_Array[float]("d", unpack("<4d", f.read(32)))) # type: ignore [attr-defined]
# if bbox specified and no overlap, skip this shape
if bbox is not None and not bbox_overlap(bbox, record.bbox): # type: ignore [attr-defined]
# because we stop parsing this shape, skip to beginning of
Expand Down Expand Up @@ -1524,14 +1519,13 @@ def __shape(

# Read a single point
if shapeType in (1, 11, 21):
array_2D = _Array[float]("d", unpack("<2d", f.read(16)))
x, y = _Array[float]("d", unpack("<2d", f.read(16)))

record.points = [tuple(array_2D)]
record.points = [(x, y)]
if bbox is not None:
# create bounding box for Point by duplicating coordinates
point_bbox = list(record.points[0] + record.points[0])
# skip shape if no overlap with bounding box
if not bbox_overlap(bbox, point_bbox):
if not bbox_overlap(bbox, (x, y, x, y)):
f.seek(next_shape)
return None

Expand Down Expand Up @@ -2782,21 +2776,21 @@ def point(self, x: float, y: float):
pointShape.points.append((x, y))
self.shape(pointShape)

def pointm(self, x, y, m=None):
def pointm(self, x: float, y: float, m: Optional[float] = None):
"""Creates a POINTM shape.
If the m (measure) value is not set, it defaults to NoData."""
shapeType = POINTM
pointShape = Shape(shapeType)
pointShape.points.append([x, y, m])
pointShape.points.append((x, y, m))
self.shape(pointShape)

def pointz(self, x, y, z=0, m=None):
def pointz(self, x: float, y: float, z: float = 0.0, m: Optional[float] = None):
"""Creates a POINTZ shape.
If the z (elevation) value is not set, it defaults to 0.
If the m (measure) value is not set, it defaults to NoData."""
shapeType = POINTZ
pointShape = Shape(shapeType)
pointShape.points.append([x, y, z, m])
pointShape.points.append((x, y, z, m))
self.shape(pointShape)

def multipoint(self, points: Coords):
Expand All @@ -2806,57 +2800,53 @@ def multipoint(self, points: Coords):
# nest the points inside a list to be compatible with the generic shapeparts method
self._shapeparts(parts=[points], shapeType=shapeType)

def multipointm(self, points):
def multipointm(self, points: list[PointM]):
"""Creates a MULTIPOINTM shape.
Points is a list of xym values.
If the m (measure) value is not included, it defaults to None (NoData)."""
shapeType = MULTIPOINTM
points = [
points
] # nest the points inside a list to be compatible with the generic shapeparts method
self._shapeparts(parts=points, shapeType=shapeType)
# nest the points inside a list to be compatible with the generic shapeparts method
self._shapeparts(parts=[points], shapeType=shapeType)

def multipointz(self, points):
"""Creates a MULTIPOINTZ shape.
Points is a list of xyzm values.
If the z (elevation) value is not included, it defaults to 0.
If the m (measure) value is not included, it defaults to None (NoData)."""
shapeType = MULTIPOINTZ
points = [
points
] # nest the points inside a list to be compatible with the generic shapeparts method
self._shapeparts(parts=points, shapeType=shapeType)
# nest the points inside a list to be compatible with the generic shapeparts method
self._shapeparts(parts=[points], shapeType=shapeType)

def line(self, lines):
def line(self, lines: list[Coords]):
"""Creates a POLYLINE shape.
Lines is a collection of lines, each made up of a list of xy values."""
shapeType = POLYLINE
self._shapeparts(parts=lines, shapeType=shapeType)

def linem(self, lines):
def linem(self, lines: list[Points]):
"""Creates a POLYLINEM shape.
Lines is a collection of lines, each made up of a list of xym values.
If the m (measure) value is not included, it defaults to None (NoData)."""
shapeType = POLYLINEM
self._shapeparts(parts=lines, shapeType=shapeType)

def linez(self, lines):
def linez(self, lines: list[Points]):
"""Creates a POLYLINEZ shape.
Lines is a collection of lines, each made up of a list of xyzm values.
If the z (elevation) value is not included, it defaults to 0.
If the m (measure) value is not included, it defaults to None (NoData)."""
shapeType = POLYLINEZ
self._shapeparts(parts=lines, shapeType=shapeType)

def poly(self, polys):
def poly(self, polys: list[Coords]):
"""Creates a POLYGON shape.
Polys is a collection of polygons, each made up of a list of xy values.
Note that for ordinary polygons the coordinates must run in a clockwise direction.
If some of the polygons are holes, these must run in a counterclockwise direction."""
shapeType = POLYGON
self._shapeparts(parts=polys, shapeType=shapeType)

def polym(self, polys):
def polym(self, polys: list[Points]):
"""Creates a POLYGONM shape.
Polys is a collection of polygons, each made up of a list of xym values.
Note that for ordinary polygons the coordinates must run in a clockwise direction.
Expand All @@ -2865,7 +2855,7 @@ def polym(self, polys):
shapeType = POLYGONM
self._shapeparts(parts=polys, shapeType=shapeType)

def polyz(self, polys):
def polyz(self, polys: list[Points]):
"""Creates a POLYGONZ shape.
Polys is a collection of polygons, each made up of a list of xyzm values.
Note that for ordinary polygons the coordinates must run in a clockwise direction.
Expand All @@ -2875,7 +2865,7 @@ def polyz(self, polys):
shapeType = POLYGONZ
self._shapeparts(parts=polys, shapeType=shapeType)

def multipatch(self, parts, partTypes):
def multipatch(self, parts: list[list[PointZ]], partTypes: list[int]):
"""Creates a MULTIPATCH shape.
Parts is a collection of 3D surface patches, each made up of a list of xyzm values.
PartTypes is a list of types that define each of the surface patches.
Expand All @@ -2891,11 +2881,12 @@ def multipatch(self, parts, partTypes):
# set part index position
polyShape.parts.append(len(polyShape.points))
# add points
for point in part:
# Ensure point is list
if not isinstance(point, list):
point = list(point)
polyShape.points.append(point)
# for point in part:
# # Ensure point is list
# if not isinstance(point, list):
# point = list(point)
# polyShape.points.append(point)
polyShape.points.extend(part)
polyShape.partTypes = partTypes
# write the shape
self.shape(polyShape)
Expand Down Expand Up @@ -2924,13 +2915,20 @@ def _shapeparts(self, parts, shapeType):
# write the shape
self.shape(polyShape)

def field(self, name, fieldType="C", size="50", decimal=0):
def field(
# Types of args should match *FieldTuple
self,
name: str,
fieldType: str = "C",
size: int = 50,
decimal: int = 0,
):
"""Adds a dbf field descriptor to the shapefile."""
if fieldType == "D":
size = "8"
size = 8
decimal = 0
elif fieldType == "L":
size = "1"
size = 1
decimal = 0
if len(self.fields) >= 2046:
raise ShapefileException(
Expand Down
Loading