diff --git a/src/shapefile.py b/src/shapefile.py index 5ba3c61..41f6359 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -21,7 +21,24 @@ import zipfile from datetime import date from struct import Struct, calcsize, error, pack, unpack -from typing import IO, Any, Iterable, Iterator, Optional, Reversible, TypedDict, Union +from typing import ( + IO, + Any, + Collection, + Container, + Generic, + Iterable, + Iterator, + NoReturn, + Optional, + Protocol, + Reversible, + Sequence, + TypedDict, + TypeVar, + Union, + overload, +) from urllib.error import HTTPError from urllib.parse import urlparse, urlunparse from urllib.request import Request, urlopen @@ -88,8 +105,9 @@ 5: "RING", } -# Custom type variables +## Custom type variables +T = TypeVar("T") Point2D = tuple[float, float] PointZ = tuple[float, float, float] PointZM = tuple[float, float, float, float] @@ -99,12 +117,24 @@ BBox = tuple[float, float, float, float] + +class BinaryWritable(Protocol): + def write(self, data: bytes): ... + + +class BinaryWritableSeekable(BinaryWritable): + def seek(self, i: int): ... # pylint: disable=unused-argument + def tell(self): ... + + # File name, file object or anything with a read() method that returns bytes. -# TODO: Create simple Protocol with a read() method pylint: disable=fixme BinaryFileT = Union[str, IO[bytes]] -BinaryFileStreamT = Union[IO[bytes], io.BytesIO] +BinaryFileStreamT = Union[IO[bytes], io.BytesIO, BinaryWritableSeekable] -RecordValue = Union[float, str, date] +FieldTuple = tuple[str, str, int, bool] +RecordValue = Union[ + bool, int, float, str, date +] # A Possible value in a Shapefile record, e.g. L, N, F, C, D types class GeoJsonShapeT(TypedDict): @@ -114,6 +144,11 @@ class GeoJsonShapeT(TypedDict): ] +class HasGeoInterface(Protocol): + @property + def __geo_interface__(self) -> Any: ... + + # Helpers MISSING = [None, ""] @@ -158,9 +193,13 @@ def is_string(v: Any) -> bool: return isinstance(v, str) -def pathlike_obj(path: Any) -> Any: +@overload +def fsdecode_if_pathlike(path: os.PathLike) -> str: ... +@overload +def fsdecode_if_pathlike(path: T) -> T: ... +def fsdecode_if_pathlike(path): if isinstance(path, os.PathLike): - return os.fsdecode(path) + return os.fsdecode(path) # str return path @@ -168,7 +207,7 @@ def pathlike_obj(path: Any) -> Any: # Begin -class _Array(array.array): +class _Array(array.array, Generic[T]): """Converts python tuples to lists of the appropriate type. Used to unpack different shapefile header parts.""" @@ -215,7 +254,7 @@ def ring_bbox(coords: Coords) -> BBox: return bbox -def bbox_overlap(bbox1: BBox, bbox2: BBox) -> bool: +def bbox_overlap(bbox1: BBox, bbox2: Collection[float]) -> bool: """Tests whether two bounding boxes overlap.""" xmin1, ymin1, xmax1, ymax1 = bbox1 xmin2, ymin2, xmax2, ymax2 = bbox2 @@ -472,8 +511,8 @@ def __init__( self, shapeType: int = NULL, points: Optional[list[Coord]] = None, - parts: Optional[list[int]] = None, - partTypes: Optional[list[int]] = None, + parts: Optional[Sequence[int]] = None, + partTypes: Optional[Sequence[int]] = None, oid: Optional[int] = None, ): """Stores the geometry of the different shape types @@ -502,6 +541,10 @@ def __init__( else: self.__oid = -1 + # self.z: Optional[Union[list[Optional[float]], _Array[float]]] = None + # self.m: Optional[list[Optional[float]]] = None + # self.bbox: Optional[_Array[float]] = None + @property def __geo_interface__(self) -> GeoJsonShapeT: if self.shapeType in [POINT, POINTM, POINTZ]: @@ -973,7 +1016,7 @@ def _assert_ext_is_supported(self, ext: str): def __init__( self, - shapefile_path: str = "", + shapefile_path: Union[str, os.PathLike] = "", /, *, encoding: str = "utf-8", @@ -992,14 +1035,14 @@ def __init__( self.shpLength: Optional[int] = None self.numRecords: Optional[int] = None self.numShapes: Optional[int] = None - self.fields: list[list[str]] = [] + self.fields: list[FieldTuple] = [] self.__dbfHdrLength = 0 self.__fieldLookup: dict[str, int] = {} self.encoding = encoding self.encodingErrors = encodingErrors # See if a shapefile name was passed as the first argument if shapefile_path: - path = pathlike_obj(shapefile_path) + path = fsdecode_if_pathlike(shapefile_path) if is_string(path): if ".zip" in path: # Shapefile is inside a zipfile @@ -1349,7 +1392,7 @@ def close(self): pass self._files_to_close = [] - def __getFileObj(self, f): + def __getFileObj(self, f: Optional[T]) -> T: """Checks to see if the requested shapefile file object is available. If not a ShapefileException is raised.""" if not f: @@ -1405,14 +1448,17 @@ def __shpHeader(self): # pylint: enable=attribute-defined-outside-init - def __shape(self, oid=None, bbox=None): + def __shape( + self, oid: Optional[int] = None, bbox: Optional[BBox] = None + ) -> Optional[Shape]: """Returns the header info and geometry for a single shape.""" # pylint: disable=attribute-defined-outside-init f = self.__getFileObj(self.shp) record = Shape(oid=oid) - # Formerly we also set __zmin = __zmax = __mmin = __mmax = None - nParts = nPoints = None + # Previously, we also set __zmin = __zmax = __mmin = __mmax = None + nParts: Optional[int] = None + nPoints: Optional[int] = None (__recNum, recLength) = unpack(">2i", f.read(8)) # Determine the start of the next record next_shape = f.tell() + (2 * recLength) @@ -1422,51 +1468,65 @@ def __shape(self, oid=None, bbox=None): if shapeType == 0: record.points = [] # All shape types capable of having a bounding box - elif shapeType in (3, 5, 8, 13, 15, 18, 23, 25, 28, 31): - record.bbox = _Array("d", unpack("<4d", f.read(32))) + elif shapeType in (3, 13, 23, 5, 15, 25, 8, 18, 28, 31): + record.bbox = _Array[float]("d", unpack("<4d", f.read(32))) # type: ignore [attr-defined] # if bbox specified and no overlap, skip this shape - if bbox is not None and not bbox_overlap(bbox, record.bbox): + if bbox is not None and not bbox_overlap(bbox, record.bbox): # type: ignore [attr-defined] # because we stop parsing this shape, skip to beginning of # next shape before we return f.seek(next_shape) return None # Shape types with parts - if shapeType in (3, 5, 13, 15, 23, 25, 31): + if shapeType in (3, 13, 23, 5, 15, 25, 31): nParts = unpack("= 16: - __mmin, __mmax = unpack("<2d", f.read(16)) - # Measure values less than -10e38 are nodata values according to the spec - if next_shape - f.tell() >= nPoints * 8: - record.m = [] - for m in _Array("d", unpack(f"<{nPoints}d", f.read(nPoints * 8))): - if m > NODATA: - record.m.append(m) - else: - record.m.append(None) - else: - record.m = [None for _ in range(nPoints)] + + # Read z extremes and values + if shapeType in (13, 15, 18, 31): + __zmin, __zmax = unpack("<2d", f.read(16)) + record.z = _Array[float]( # type: ignore [attr-defined] + "d", unpack(f"<{nPoints}d", f.read(nPoints * 8)) + ) + + # Read m extremes and values + if shapeType in (13, 23, 15, 25, 18, 28, 31): + if next_shape - f.tell() >= 16: + __mmin, __mmax = unpack("<2d", f.read(16)) + # Measure values less than -10e38 are nodata values according to the spec + if next_shape - f.tell() >= nPoints * 8: + record.m = [] # type: ignore [attr-defined] + for m in _Array[float]( + "d", unpack(f"<{nPoints}d", f.read(nPoints * 8)) + ): + if m > NODATA: + record.m.append(m) # type: ignore [attr-defined] + else: + record.m.append(None) # type: ignore [attr-defined] + else: + record.m = [None for _ in range(nPoints)] # type: ignore [attr-defined] + # Read a single point if shapeType in (1, 11, 21): - record.points = [_Array("d", unpack("<2d", f.read(16)))] + array_2D = _Array[float]("d", unpack("<2d", f.read(16))) + + record.points = [tuple(array_2D)] if bbox is not None: # create bounding box for Point by duplicating coordinates point_bbox = list(record.points[0] + record.points[0]) @@ -1474,9 +1534,11 @@ def __shape(self, oid=None, bbox=None): if not bbox_overlap(bbox, point_bbox): f.seek(next_shape) return None + # Read a single Z value if shapeType == 11: - record.z = list(unpack("= 8: @@ -1485,14 +1547,17 @@ def __shape(self, oid=None, bbox=None): m = NODATA # Measure values less than -10e38 are nodata values according to the spec if m > NODATA: - record.m = [m] + record.m = [m] # type: ignore [attr-defined] else: - record.m = [None] + record.m = [None] # type: ignore [attr-defined] + # pylint: enable=attribute-defined-outside-init # Seek to the end of this record as defined by the record header because # the shapefile spec doesn't require the actual content to meet the header # definition. Probably allowed for lazy feature deletion. + f.seek(next_shape) + return record def __shxHeader(self): @@ -1517,12 +1582,12 @@ def __shxOffsets(self): # Jump to the first record. shx.seek(100) # Each index record consists of two nrs, we only want the first one - shxRecords = _Array("i", shx.read(2 * self.numShapes * 4)) + shxRecords = _Array[int]("i", shx.read(2 * self.numShapes * 4)) if sys.byteorder != "big": shxRecords.byteswap() - self._offsets = [2 * el for el in shxRecords[::2]] + self._offsets: list[int] = [2 * el for el in shxRecords[::2]] - def __shapeIndex(self, i=None): + def __shapeIndex(self, i: Optional[int] = None) -> Optional[int]: """Returns the offset in a .shp file for a shape based on information in the .shx index file.""" shx = self.shx @@ -1534,7 +1599,7 @@ def __shapeIndex(self, i=None): self.__shxOffsets() return self._offsets[i] - def shape(self, i=0, bbox=None): + def shape(self, i: int = 0, bbox: Optional[BBox] = None) -> Optional[Shape]: """Returns a shape object for a shape in the geometry record file. If the 'bbox' arg is given (list or tuple of xmin,ymin,xmax,ymax), @@ -1572,7 +1637,7 @@ def shape(self, i=0, bbox=None): shp.seek(offset) return self.__shape(oid=i, bbox=bbox) - def shapes(self, bbox=None): + def shapes(self, bbox: Optional[BBox] = None) -> Shapes: """Returns all shapes in a shapefile. To only read shapes within a given spatial region, specify the 'bbox' arg as a list or tuple of xmin,ymin,xmax,ymax. @@ -1581,7 +1646,7 @@ def shapes(self, bbox=None): shapes.extend(self.iterShapes(bbox=bbox)) return shapes - def iterShapes(self, bbox=None): + def iterShapes(self, bbox: Optional[BBox] = None) -> Iterator[Optional[Shape]]: """Returns a generator of shapes in a shapefile. Useful for handling large shapefiles. To only read shapes within a given spatial region, specify the 'bbox' @@ -1675,7 +1740,7 @@ def __dbfHeader(self): # pylint: enable=attribute-defined-outside-init - def __recordFmt(self, fields=None): + def __recordFmt(self, fields: Optional[Container[str]] = None) -> tuple[str, int]: """Calculates the format and size of a .dbf record. Optional 'fields' arg specifies which fieldnames to unpack and which to ignore. Note that this always includes the DeletionFlag at index 0, regardless of the 'fields' arg. @@ -1701,7 +1766,9 @@ def __recordFmt(self, fields=None): fmtSize += 1 return (fmt, fmtSize) - def __recordFields(self, fields=None): + def __recordFields( + self, fields: Optional[Iterable[str]] = None + ) -> tuple[list[FieldTuple], dict[str, int], Struct]: """Returns the necessary info required to unpack a record's fields, restricted to a subset of fieldnames 'fields' if specified. Returns a list of field info tuples, a name-index lookup dict, @@ -1711,19 +1778,19 @@ def __recordFields(self, fields=None): if fields is not None: # restrict info to the specified fields # first ignore repeated field names (order doesn't matter) - fields = list(set(fields)) + unique_fields = list(set(fields)) # get the struct - fmt, __fmtSize = self.__recordFmt(fields=fields) + fmt, __fmtSize = self.__recordFmt(fields=unique_fields) recStruct = Struct(fmt) # make sure the given fieldnames exist - for name in fields: + for name in unique_fields: if name not in self.__fieldLookup or name == "DeletionFlag": raise ValueError(f'"{name}" is not a valid field name') # fetch relevant field info tuples fieldTuples = [] for fieldinfo in self.fields[1:]: name = fieldinfo[0] - if name in fields: + if name in unique_fields: fieldTuples.append(fieldinfo) # store the field positions recLookup = {f[0]: i for i, f in enumerate(fieldTuples)} @@ -1736,7 +1803,7 @@ def __recordFields(self, fields=None): def __record( self, - fieldTuples: list[tuple[str, str, int, bool]], + fieldTuples: list[FieldTuple], recLookup: dict[str, int], recStruct: Struct, oid: Optional[int] = None, @@ -1983,25 +2050,27 @@ class Writer: def __init__( self, - target=None, - shapeType=None, - autoBalance=False, - encoding="utf-8", - encodingErrors="strict", + target: Union[str, os.PathLike, None] = None, + shapeType: Optional[int] = None, + autoBalance: bool = False, *, - shp=None, - shx=None, - dbf=None, + encoding: str = "utf-8", + encodingErrors: str = "strict", + shp: Optional[BinaryWritableSeekable] = None, + shx: Optional[BinaryWritableSeekable] = None, + dbf: Optional[BinaryWritableSeekable] = None, **kwargs, # pylint: disable=unused-argument ): self.target = target self.autoBalance = autoBalance - self.fields = [] + self.fields: list[FieldTuple] = [] self.shapeType = shapeType - self.shp = self.shx = self.dbf = None - self._files_to_close = [] + self.shp: Optional[BinaryFileStreamT] = None + self.shx: Optional[BinaryFileStreamT] = None + self.dbf: Optional[BinaryFileStreamT] = None + self._files_to_close: list[BinaryFileStreamT] = [] if target: - target = pathlike_obj(target) + target = fsdecode_if_pathlike(target) if not is_string(target): raise TypeError( f"The target filepath {target!r} must be of type str/unicode or path-like, not {type(target)}." @@ -2106,7 +2175,15 @@ def close(self): pass self._files_to_close = [] - def __getFileObj(self, f: Union[IO[bytes], str]) -> IO[bytes]: + W = TypeVar("W", bound=BinaryWritableSeekable) + + @overload + def __getFileObj(self, f: str) -> IO[bytes]: ... + @overload + def __getFileObj(self, f: None) -> NoReturn: ... + @overload + def __getFileObj(self, f: W) -> W: ... + def __getFileObj(self, f): """Safety handler to verify file-like objects""" if not f: raise ShapefileException("No file-like object available.") @@ -2210,8 +2287,8 @@ def __mbox(self, s): return mbox @property - def shapeTypeName(self): - return SHAPETYPE_LOOKUP[self.shapeType] + def shapeTypeName(self) -> str: + return SHAPETYPE_LOOKUP[self.shapeType or 0] def bbox(self): """Returns the current bounding box for the shapefile which is @@ -2227,7 +2304,11 @@ def mbox(self): """Returns the current m extremes for the shapefile.""" return self._mbox - def __shapefileHeader(self, fileObj, headerType="shp"): + def __shapefileHeader( + self, + fileObj: Optional[BinaryWritableSeekable], + headerType: str = "shp", + ): """Writes the specified header type to the specified file-like object. Several of the shapefile formats are so similar that a single generic method to read or write them is warranted.""" @@ -2339,14 +2420,17 @@ def __dbfHeader(self): # Terminator f.write(b"\r") - def shape(self, s): + def shape( + self, + s: Union[Shape, HasGeoInterface, dict], + ): # Balance if already not balanced if self.autoBalance and self.recNum < self.shpNum: self.balance() # Check is shape or import from geojson if not isinstance(s, Shape): if hasattr(s, "__geo_interface__"): - s = s.__geo_interface__ + s = s.__geo_interface__ # type: ignore [assignment] if isinstance(s, dict): s = Shape._from_geojson(s) else: @@ -2561,7 +2645,9 @@ def __shxRecord(self, offset, length): # pylint: enable=raise-missing-from - def record(self, *recordList, **recordDict): + def record( + self, *recordList: Iterable[RecordValue], **recordDict: dict[str, RecordValue] + ): """Creates a dbf attribute record. You can submit either a sequence of field values or keyword arguments of field names and values. Before adding records you must add fields for the record values using the @@ -2689,11 +2775,11 @@ def null(self): """Creates a null shape.""" self.shape(Shape(NULL)) - def point(self, x, y): + def point(self, x: float, y: float): """Creates a POINT shape.""" shapeType = POINT pointShape = Shape(shapeType) - pointShape.points.append([x, y]) + pointShape.points.append((x, y)) self.shape(pointShape) def pointm(self, x, y, m=None): @@ -2713,14 +2799,12 @@ def pointz(self, x, y, z=0, m=None): pointShape.points.append([x, y, z, m]) self.shape(pointShape) - def multipoint(self, points): + def multipoint(self, points: Coords): """Creates a MULTIPOINT shape. Points is a list of xy values.""" shapeType = MULTIPOINT - points = [ - points - ] # nest the points inside a list to be compatible with the generic shapeparts method - self._shapeparts(parts=points, shapeType=shapeType) + # nest the points inside a list to be compatible with the generic shapeparts method + self._shapeparts(parts=[points], shapeType=shapeType) def multipointm(self, points): """Creates a MULTIPOINTM shape. @@ -2835,9 +2919,8 @@ def _shapeparts(self, parts, shapeType): # add points for point in part: # Ensure point is list - if not isinstance(point, list): - point = list(point) - polyShape.points.append(point) + point_list = list(point) + polyShape.points.append(point_list) # write the shape self.shape(polyShape)