diff --git a/.github/actions/test/action.yml b/.github/actions/test/action.yml index c6ca65a..0184dfe 100644 --- a/.github/actions/test/action.yml +++ b/.github/actions/test/action.yml @@ -87,7 +87,7 @@ runs: working-directory: ${{ inputs.pyshp_repo_directory }} run: | python -m pip install --upgrade pip - pip install -r requirements.test.txt + pip install -e .[test] - name: Pytest shell: bash diff --git a/.github/workflows/run_tests_hooks_and_tools.yml b/.github/workflows/run_tests_hooks_and_tools.yml index 468b2e2..42c981e 100644 --- a/.github/workflows/run_tests_hooks_and_tools.yml +++ b/.github/workflows/run_tests_hooks_and_tools.yml @@ -30,41 +30,6 @@ jobs: run: | pylint --disable=R,C test_shapefile.py - test_on_EOL_Pythons: - strategy: - fail-fast: false - matrix: - python-version: [ - "2.7", - "3.5", - "3.6", - "3.7", - "3.8", - ] - - runs-on: ubuntu-latest - container: - image: python:${{ matrix.python-version }} - - steps: - - uses: actions/checkout@v4 - with: - path: ./Pyshp - - - name: Non-network tests - uses: ./Pyshp/.github/actions/test - with: - pyshp_repo_directory: ./Pyshp - python-version: ${{ matrix.python-version }} - - - name: Network tests - uses: ./Pyshp/.github/actions/test - with: - extra_args: '-m network' - replace_remote_urls_with_localhost: 'yes' - pyshp_repo_directory: ./Pyshp - python-version: ${{ matrix.python-version }} - test_on_supported_Pythons: strategy: fail-fast: false diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ffe59bf..3849c55 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,8 +8,13 @@ repos: hooks: - id: isort name: isort (python) + args: ["--profile", "black"] - repo: https://github.com/pre-commit/pre-commit-hooks rev: v2.3.0 hooks: - id: check-yaml - id: trailing-whitespace +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.17.0 + hooks: + - id: mypy \ No newline at end of file diff --git a/README.md b/README.md index c55e204..caf5f33 100644 --- a/README.md +++ 
b/README.md @@ -74,8 +74,6 @@ Both the Esri and XBase file-formats are very simple in design and memory efficient which is part of the reason the shapefile format remains popular despite the numerous ways to store and exchange GIS data available today. -Pyshp is compatible with Python 2.7-3.x. - This document provides examples for using PyShp to read and write shapefiles. However many more examples are continually added to the blog [http://GeospatialPython.com](http://GeospatialPython.com), and by searching for PyShp on [https://gis.stackexchange.com](https://gis.stackexchange.com). diff --git a/pyproject.toml b/pyproject.toml index df8e737..945c86c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,6 +2,39 @@ requires = ["setuptools"] build-backend = "setuptools.build_meta" +[project] +name = "pyshp" +authors = [ + {name = "Joel Lawhead", email = "jlawhead@geospatialpython.com"}, +] +maintainers = [ + {name = "Karim Bahgat", email = "karim.bahgat.norway@gmail.com"} +] +readme = "README.md" +keywords = ["gis", "geospatial", "geographic", "shapefile", "shapefiles"] +description = "Pure Python read/write support for ESRI Shapefile format" +license = "MIT" +license-files = ["LICENSE.TXT"] +dynamic = ["version"] +requires-python = ">=3.9" +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Topic :: Scientific/Engineering :: GIS", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", +] + +[project.optional-dependencies] +test = ["pytest"] + +[project.urls] +Repository = "https://github.com/GeospatialPython/pyshp" + +[tool.setuptools.dynamic] +version = {attr = "shapefile.__version__"} [tool.ruff] # Exclude a variety of commonly ignored directories. 
@@ -39,7 +72,7 @@ line-length = 88 indent-width = 4 # Assume Python 3.9 -target-version = "py37" +target-version = "py39" [tool.ruff.lint] # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. @@ -67,7 +100,6 @@ skip-magic-trailing-comma = false line-ending = "auto" - [tool.pylint.MASTER] load-plugins=[ "pylint_per_file_ignores", @@ -86,4 +118,4 @@ load-plugins=[ per-file-ignores = """ shapefile.py:W0212 test_shapefile.py:W0212 -""" \ No newline at end of file +""" diff --git a/requirements.test.txt b/requirements.test.txt deleted file mode 100644 index 1114173..0000000 --- a/requirements.test.txt +++ /dev/null @@ -1,2 +0,0 @@ -pytest >= 3.7 -setuptools diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 906abd3..0000000 --- a/setup.cfg +++ /dev/null @@ -1,30 +0,0 @@ -[metadata] -name = pyshp -version = attr: shapefile.__version__ -description = Pure Python read/write support for ESRI Shapefile format -long_description = file: README.md -long_description_content_type = text/markdown -author = Joel Lawhead -author_email = jlawhead@geospatialpython.com -maintainer = Karim Bahgat -maintainer_email = karim.bahgat.norway@gmail.com -url = https://github.com/GeospatialPython/pyshp -download_url = https://pypi.org/project/pyshp/ -license = MIT -license_files = LICENSE.TXT -keywords = gis, geospatial, geographic, shapefile, shapefiles -classifiers = - Development Status :: 5 - Production/Stable - Programming Language :: Python - Programming Language :: Python :: 2.7 - Programming Language :: Python :: 3 - Topic :: Scientific/Engineering :: GIS - Topic :: Software Development :: Libraries - Topic :: Software Development :: Libraries :: Python Modules - -[options] -py_modules = shapefile -python_requires = >=2.7 - -[bdist_wheel] -universal=1 diff --git a/setup.py b/setup.py deleted file mode 100644 index 6068493..0000000 --- a/setup.py +++ /dev/null @@ -1,3 +0,0 @@ -from setuptools import setup - -setup() diff --git a/shapefile.py 
b/shapefile.py index 211fd48..e563298 100644 --- a/shapefile.py +++ b/shapefile.py @@ -3,12 +3,13 @@ Provides read and write support for ESRI Shapefiles. authors: jlawheadgeospatialpython.com maintainer: karim.bahgat.norwaygmail.com -Compatible with Python versions 2.7-3.x +Compatible with Python versions >=3.9 """ __version__ = "2.4.0" import array +import doctest import io import logging import os @@ -16,12 +17,19 @@ import tempfile import time import zipfile +from collections.abc import Collection from datetime import date from struct import Struct, calcsize, error, pack, unpack +from typing import IO, Any, Iterable, Iterator, Optional, Reversible, TypedDict, Union +from urllib.error import HTTPError +from urllib.parse import urlparse, urlunparse +from urllib.request import Request, urlopen # Create named logger logger = logging.getLogger(__name__) +doctest.NORMALIZE_WHITESPACE = 1 + # Module settings VERBOSE = True @@ -79,24 +87,23 @@ 5: "RING", } +# Custom type variables -# Python 2-3 handling - -PYTHON3 = sys.version_info[0] == 3 +Point2D = tuple[float, float] +PointZ = tuple[float, float, float] +PointZM = tuple[float, float, float, float] -if PYTHON3: - xrange = range - izip = zip +Coord = Union[Point2D, PointZ, PointZM] +Coords = list[Coord] - from urllib.error import HTTPError - from urllib.parse import urlparse, urlunparse - from urllib.request import Request, urlopen +BBox = tuple[float, float, float, float] -else: - from itertools import izip - from urllib2 import HTTPError, Request, urlopen - from urlparse import urlparse, urlunparse +class GeoJSONT(TypedDict): + type: str + coordinates: Union[ + tuple[()], Point2D, PointZ, PointZM, Coords, list[Coords], list[list[Coords]] + ] # Helpers @@ -104,92 +111,50 @@ MISSING = [None, ""] NODATA = -10e38 # as per the ESRI shapefile spec, only used for m-values. -if PYTHON3: - - def b(v, encoding="utf-8", encodingErrors="strict"): - if isinstance(v, str): - # For python 3 encode str to bytes. 
- return v.encode(encoding, encodingErrors) - elif isinstance(v, bytes): - # Already bytes. - return v - elif v is None: - # Since we're dealing with text, interpret None as "" - return b"" - else: - # Force string representation. - return str(v).encode(encoding, encodingErrors) - - def u(v, encoding="utf-8", encodingErrors="strict"): - if isinstance(v, bytes): - # For python 3 decode bytes to str. - return v.decode(encoding, encodingErrors) - elif isinstance(v, str): - # Already str. - return v - elif v is None: - # Since we're dealing with text, interpret None as "" - return "" - else: - # Force string representation. - return bytes(v).decode(encoding, encodingErrors) - - def is_string(v): - return isinstance(v, str) - -else: - - def b(v, encoding="utf-8", encodingErrors="strict"): - if isinstance(v, unicode): - # For python 2 encode unicode to bytes. - return v.encode(encoding, encodingErrors) - elif isinstance(v, bytes): - # Already bytes. - return v - elif v is None: - # Since we're dealing with text, interpret None as "" - return "" - else: - # Force string representation. - return unicode(v).encode(encoding, encodingErrors) - - def u(v, encoding="utf-8", encodingErrors="strict"): - if isinstance(v, bytes): - # For python 2 decode bytes to unicode. - return v.decode(encoding, encodingErrors) - elif isinstance(v, unicode): - # Already unicode. - return v - elif v is None: - # Since we're dealing with text, interpret None as "" - return "" - else: - # Force string representation. - return bytes(v).decode(encoding, encodingErrors) - def is_string(v): - return isinstance(v, basestring) +def b( + v: Union[str, bytes], encoding: str = "utf-8", encodingErrors: str = "strict" +) -> bytes: + if isinstance(v, str): + # For python 3 encode str to bytes. + return v.encode(encoding, encodingErrors) + elif isinstance(v, bytes): + # Already bytes. 
+ return v + elif v is None: + # Since we're dealing with text, interpret None as "" + return b"" + else: + # Force string representation. + return str(v).encode(encoding, encodingErrors) + + +def u( + v: Union[str, bytes], encoding: str = "utf-8", encodingErrors: str = "strict" +) -> str: + if isinstance(v, bytes): + # For python 3 decode bytes to str. + return v.decode(encoding, encodingErrors) + elif isinstance(v, str): + # Already str. + return v + elif v is None: + # Since we're dealing with text, interpret None as "" + return "" + else: + # Force string representation. + return bytes(v).decode(encoding, encodingErrors) -if sys.version_info[0:2] >= (3, 6): +def is_string(v: Any) -> bool: + return isinstance(v, str) - def pathlike_obj(path): - if isinstance(path, os.PathLike): - return os.fsdecode(path) - else: - return path -else: - - def pathlike_obj(path): - if is_string(path): - return path - elif hasattr(path, "__fspath__"): - return path.__fspath__() - else: - try: - return str(path) - except: - return path + +def pathlike_obj(path: Any) -> Any: + if isinstance(path, os.PathLike): + return os.fsdecode(path) + else: + return path # Begin @@ -203,7 +168,10 @@ def __repr__(self): return str(self.tolist()) -def signed_area(coords, fast=False): +def signed_area( + coords: Coords, + fast: bool = False, +) -> float: """Return the signed area enclosed by a ring using the linear time algorithm. A value >= 0 indicates a counter-clockwise oriented ring. A faster version is possible by setting 'fast' to True, which returns @@ -219,7 +187,7 @@ def signed_area(coords, fast=False): return area2 / 2.0 -def is_cw(coords): +def is_cw(coords: Coords) -> bool: """Returns True if a polygon ring has clockwise orientation, determined by a negatively signed area. 
""" @@ -227,35 +195,35 @@ def is_cw(coords): return area2 < 0 -def rewind(coords): +def rewind(coords: Reversible[Coord]) -> list[Coord]: """Returns the input coords in reversed order.""" return list(reversed(coords)) -def ring_bbox(coords): +def ring_bbox(coords: Coords) -> BBox: """Calculates and returns the bounding box of a ring.""" xs, ys = zip(*coords) bbox = min(xs), min(ys), max(xs), max(ys) return bbox -def bbox_overlap(bbox1, bbox2): - """Tests whether two bounding boxes overlap, returning a boolean""" +def bbox_overlap(bbox1: BBox, bbox2: BBox) -> bool: + """Tests whether two bounding boxes overlap.""" xmin1, ymin1, xmax1, ymax1 = bbox1 xmin2, ymin2, xmax2, ymax2 = bbox2 overlap = xmin1 <= xmax2 and xmax1 >= xmin2 and ymin1 <= ymax2 and ymax1 >= ymin2 return overlap -def bbox_contains(bbox1, bbox2): - """Tests whether bbox1 fully contains bbox2, returning a boolean""" +def bbox_contains(bbox1: BBox, bbox2: BBox) -> bool: + """Tests whether bbox1 fully contains bbox2.""" xmin1, ymin1, xmax1, ymax1 = bbox1 xmin2, ymin2, xmax2, ymax2 = bbox2 contains = xmin1 < xmin2 and xmax1 > xmax2 and ymin1 < ymin2 and ymax1 > ymax2 return contains -def ring_contains_point(coords, p): +def ring_contains_point(coords: list[Coord], p: Point2D) -> bool: """Fast point-in-polygon crossings algorithm, MacMartin optimization. Adapted from code by Eric Haynes @@ -300,7 +268,7 @@ def ring_contains_point(coords, p): return inside_flag -def ring_sample(coords, ccw=False): +def ring_sample(coords: list[Coord], ccw: bool = False) -> Point2D: """Return a sample point guaranteed to be within a ring, by efficiently finding the first centroid of a coordinate triplet whose orientation matches the orientation of the ring and passes the point-in-ring test. 
@@ -311,8 +279,7 @@ def ring_sample(coords, ccw=False): def itercoords(): # iterate full closed ring - for p in coords: - yield p + yield from coords # finally, yield the second coordinate to the end to allow checking the last triplet yield coords[1] @@ -348,12 +315,14 @@ def itercoords(): raise Exception("Unexpected error: Unable to find a ring sample point.") -def ring_contains_ring(coords1, coords2): +def ring_contains_ring(coords1: list[Coord], coords2: list[Point2D]) -> bool: """Returns True if all vertexes in coords2 are fully inside coords1.""" - return all((ring_contains_point(coords1, p2) for p2 in coords2)) + return all(ring_contains_point(coords1, p2) for p2 in coords2) -def organize_polygon_rings(rings, return_errors=None): +def organize_polygon_rings( + rings: Iterable[list[Coord]], return_errors: Optional[dict[str, int]] = None +) -> list[list[list[Coord]]]: """Organize a list of coordinate rings into one or more polygons with holes. Returns a list of polygons, where each polygon is composed of a single exterior ring, and one or more interior holes. 
If a return_errors dict is provided (optional), @@ -398,7 +367,9 @@ def organize_polygon_rings(rings, return_errors=None): return polys # first determine each hole's candidate exteriors based on simple bbox contains test - hole_exteriors = dict([(hole_i, []) for hole_i in xrange(len(holes))]) + hole_exteriors: dict[int, list[int]] = { + hole_i: [] for hole_i in range(len(holes)) + } exterior_bboxes = [ring_bbox(ring) for ring in exteriors] for hole_i in hole_exteriors.keys(): hole_bbox = ring_bbox(holes[hole_i]) @@ -478,9 +449,14 @@ def organize_polygon_rings(rings, return_errors=None): return polys -class Shape(object): +class Shape: def __init__( - self, shapeType=NULL, points=None, parts=None, partTypes=None, oid=None + self, + shapeType: int = NULL, + points: Optional[list[Coord]] = None, + parts: Optional[list[int]] = None, + partTypes: Optional[list[int]] = None, + oid: Optional[int] = None, ): """Stores the geometry of the different shape types specified in the Shapefile spec. Shape types are @@ -500,7 +476,7 @@ def __init__( self.partTypes = partTypes # and a dict to silently record any errors encountered - self._errors = {} + self._errors: dict[str, int] = {} # add oid if oid is not None: @@ -509,16 +485,18 @@ def __init__( self.__oid = -1 @property - def __geo_interface__(self): + def __geo_interface__(self) -> GeoJSONT: if self.shapeType in [POINT, POINTM, POINTZ]: # point if len(self.points) == 0: # the shape has no coordinate information, i.e. 
is 'empty' # the geojson spec does not define a proper null-geometry type # however, it does allow geometry types with 'empty' coordinates to be interpreted as null-geometries - return {"type": "Point", "coordinates": tuple()} + return {"type": "Point", "coordinates": ()} + # return {"type": "Point", "coordinates": tuple()} #type: ignore else: - return {"type": "Point", "coordinates": tuple(self.points[0])} + return {"type": "Point", "coordinates": self.points[0]} + # return {"type": "Point", "coordinates": tuple(self.points[0])} # type: ignore elif self.shapeType in [MULTIPOINT, MULTIPOINTM, MULTIPOINTZ]: if len(self.points) == 0: # the shape has no coordinate information, i.e. is 'empty' @@ -529,7 +507,8 @@ def __geo_interface__(self): # multipoint return { "type": "MultiPoint", - "coordinates": [tuple(p) for p in self.points], + "coordinates": self.points, + # "coordinates": [tuple(p) for p in self.points], #type: ignore } elif self.shapeType in [POLYLINE, POLYLINEM, POLYLINEZ]: if len(self.parts) == 0: @@ -541,7 +520,8 @@ def __geo_interface__(self): # linestring return { "type": "LineString", - "coordinates": [tuple(p) for p in self.points], + "coordinates": self.points, + # "coordinates": [tuple(p) for p in self.points], #type: ignore } else: # multilinestring @@ -552,10 +532,12 @@ def __geo_interface__(self): ps = part continue else: - coordinates.append([tuple(p) for p in self.points[ps:part]]) + # coordinates.append([tuple(p) for p in self.points[ps:part]]) + coordinates.append([p for p in self.points[ps:part]]) ps = part else: - coordinates.append([tuple(p) for p in self.points[part:]]) + # coordinates.append([tuple(p) for p in self.points[part:]]) + coordinates.append([p for p in self.points[part:]]) return {"type": "MultiLineString", "coordinates": coordinates} elif self.shapeType in [POLYGON, POLYGONM, POLYGONZ]: if len(self.parts) == 0: @@ -566,7 +548,7 @@ def __geo_interface__(self): else: # get all polygon rings rings = [] - for i in 
xrange(len(self.parts)): + for i in range(len(self.parts)): # get indexes of start and end points of the ring start = self.parts[i] try: @@ -575,7 +557,8 @@ def __geo_interface__(self): end = len(self.points) # extract the points that make up the ring - ring = [tuple(p) for p in self.points[start:end]] + # ring = [tuple(p) for p in self.points[start:end]] + ring = [p for p in self.points[start:end]] rings.append(ring) # organize rings into list of polygons, where each polygon is defined as list of rings. @@ -703,16 +686,16 @@ def _from_geojson(geoj): return shape @property - def oid(self): + def oid(self) -> int: """The index position of the shape in the original shapefile""" return self.__oid @property - def shapeTypeName(self): + def shapeTypeName(self) -> str: return SHAPETYPE_LOOKUP[self.shapeType] def __repr__(self): - return "Shape #{}: {}".format(self.__oid, self.shapeTypeName) + return f"Shape #{self.__oid}: {self.shapeTypeName}" class _Record(list): @@ -763,10 +746,10 @@ def __getattr__(self, item): index = self.__field_positions[item] return list.__getitem__(self, index) except KeyError: - raise AttributeError("{} is not a field name".format(item)) + raise AttributeError(f"{item} is not a field name") except IndexError: raise IndexError( - "{} found as a field but not enough values available.".format(item) + f"{item} found as a field but not enough values available." 
) def __setattr__(self, key, value): @@ -783,7 +766,7 @@ def __setattr__(self, key, value): index = self.__field_positions[key] return list.__setitem__(self, index, value) except KeyError: - raise AttributeError("{} is not a field name".format(key)) + raise AttributeError(f"{key} is not a field name") def __getitem__(self, item): """ @@ -804,7 +787,7 @@ def __getitem__(self, item): if index is not None: return list.__getitem__(self, index) else: - raise IndexError('"{}" is not a field name and not an int'.format(item)) + raise IndexError(f'"{item}" is not a field name and not an int') def __setitem__(self, key, value): """ @@ -822,7 +805,7 @@ def __setitem__(self, key, value): if index is not None: return list.__setitem__(self, index, value) else: - raise IndexError("{} is not a field name and not an int".format(key)) + raise IndexError(f"{key} is not a field name and not an int") @property def oid(self): @@ -834,15 +817,15 @@ def as_dict(self, date_strings=False): Returns this Record as a dictionary using the field names as keys :return: dict """ - dct = dict((f, self[i]) for f, i in self.__field_positions.items()) + dct = {f: self[i] for f, i in self.__field_positions.items()} if date_strings: for k, v in dct.items(): if isinstance(v, date): - dct[k] = "{:04d}{:02d}{:02d}".format(v.year, v.month, v.day) + dct[k] = f"{v.year:04d}{v.month:02d}{v.day:02d}" return dct def __repr__(self): - return "Record #{}: {}".format(self.__oid, list(self)) + return f"Record #{self.__oid}: {list(self)}" def __dir__(self): """ @@ -866,7 +849,7 @@ def __eq__(self, other): return list.__eq__(self, other) -class ShapeRecord(object): +class ShapeRecord: """A ShapeRecord object containing a shape along with its attributes. 
Provides the GeoJSON __geo_interface__ to return a Feature dictionary.""" @@ -892,7 +875,7 @@ class Shapes(list): to return a GeometryCollection dictionary.""" def __repr__(self): - return "Shapes: {}".format(list(self)) + return f"Shapes: {list(self)}" @property def __geo_interface__(self): @@ -912,7 +895,7 @@ class ShapeRecords(list): to return a FeatureCollection dictionary.""" def __repr__(self): - return "ShapeRecords: {}".format(list(self)) + return f"ShapeRecords: {list(self)}" @property def __geo_interface__(self): @@ -929,7 +912,17 @@ class ShapefileException(Exception): pass -class Reader(object): +class _NoShpSentinel(object): + """For use as a default value for shp to preserve the + behaviour (from when all keyword args were gathered + in the **kwargs dict) in case someone explicitly + called Reader(shp=None) to load self.shx. + """ + + pass + + +class Reader: """Reads the three files of a shapefile as a unit or separately. If one of the three files (.shp, .shx, .dbf) is missing no exception is thrown until you try @@ -950,24 +943,40 @@ class Reader(object): but they can be. 
""" - def __init__(self, *args, **kwargs): + CONSTITUENT_FILE_EXTS = ["shp", "shx", "dbf"] + assert all(ext.islower() for ext in CONSTITUENT_FILE_EXTS) + + def _assert_ext_is_supported(self, ext: str): + assert ext in self.CONSTITUENT_FILE_EXTS + + def __init__( + self, + shapefile_path: str = "", + *, + encoding="utf-8", + encodingErrors="strict", + shp=_NoShpSentinel, + shx=None, + dbf=None, + **kwargs, + ): self.shp = None self.shx = None self.dbf = None - self._files_to_close = [] + self._files_to_close: list[IO[bytes]] = [] self.shapeName = "Not specified" - self._offsets = [] + self._offsets: list[int] = [] self.shpLength = None self.numRecords = None self.numShapes = None - self.fields = [] + self.fields: list[list[str]] = [] self.__dbfHdrLength = 0 - self.__fieldLookup = {} - self.encoding = kwargs.pop("encoding", "utf-8") - self.encodingErrors = kwargs.pop("encodingErrors", "strict") + self.__fieldLookup: dict[str, int] = {} + self.encoding = encoding + self.encodingErrors = encodingErrors # See if a shapefile name was passed as the first argument - if len(args) > 0: - path = pathlike_obj(args[0]) + if shapefile_path: + path = pathlike_obj(shapefile_path) if is_string(path): if ".zip" in path: # Shapefile is inside a zipfile @@ -984,6 +993,8 @@ def __init__(self, *args, **kwargs): else: zpath = path[: path.find(".zip") + 4] shapefile = path[path.find(".zip") + 4 + 1 :] + + zipfileobj: Union[tempfile._TemporaryFileWrapper, io.BufferedReader] # Create a zip file handle if zpath.startswith("http"): # Zipfile is from a url @@ -1031,19 +1042,20 @@ def __init__(self, *args, **kwargs): shapefile = os.path.splitext(shapefile)[ 0 ] # root shapefile name - for ext in ["SHP", "SHX", "DBF", "shp", "shx", "dbf"]: - try: - member = archive.open(shapefile + "." 
+ ext) - # write zipfile member data to a read+write tempfile and use as source, gets deleted on close() - fileobj = tempfile.NamedTemporaryFile( - mode="w+b", delete=True - ) - fileobj.write(member.read()) - fileobj.seek(0) - setattr(self, ext.lower(), fileobj) - self._files_to_close.append(fileobj) - except: - pass + for lower_ext in self.CONSTITUENT_FILE_EXTS: + for cased_ext in [lower_ext, lower_ext.upper()]: + try: + member = archive.open(f"{shapefile}.{cased_ext}") + # write zipfile member data to a read+write tempfile and use as source, gets deleted on close() + fileobj = tempfile.NamedTemporaryFile( + mode="w+b", delete=True + ) + fileobj.write(member.read()) + fileobj.seek(0) + setattr(self, lower_ext, fileobj) + self._files_to_close.append(fileobj) + except: + pass # Close and delete the temporary zipfile try: zipfileobj.close() @@ -1103,46 +1115,44 @@ def __init__(self, *args, **kwargs): self.load(path) return - # Otherwise, load from separate shp/shx/dbf args (must be path or file-like) - if "shp" in kwargs: - if hasattr(kwargs["shp"], "read"): - self.shp = kwargs["shp"] - # Copy if required - try: - self.shp.seek(0) - except (NameError, io.UnsupportedOperation): - self.shp = io.BytesIO(self.shp.read()) - else: - (baseName, ext) = os.path.splitext(kwargs["shp"]) - self.load_shp(baseName) + if shp is not _NoShpSentinel: + self.shp = self._seek_0_on_file_obj_wrap_or_open_from_name("shp", shp) + self.shx = self._seek_0_on_file_obj_wrap_or_open_from_name("shx", shx) - if "shx" in kwargs: - if hasattr(kwargs["shx"], "read"): - self.shx = kwargs["shx"] - # Copy if required - try: - self.shx.seek(0) - except (NameError, io.UnsupportedOperation): - self.shx = io.BytesIO(self.shx.read()) - else: - (baseName, ext) = os.path.splitext(kwargs["shx"]) - self.load_shx(baseName) - - if "dbf" in kwargs: - if hasattr(kwargs["dbf"], "read"): - self.dbf = kwargs["dbf"] - # Copy if required - try: - self.dbf.seek(0) - except (NameError, io.UnsupportedOperation): - self.dbf 
= io.BytesIO(self.dbf.read()) - else: - (baseName, ext) = os.path.splitext(kwargs["dbf"]) - self.load_dbf(baseName) + self.dbf = self._seek_0_on_file_obj_wrap_or_open_from_name("dbf", dbf) # Load the files if self.shp or self.dbf: - self.load() + self._try_to_set_constituent_file_headers() + + def _seek_0_on_file_obj_wrap_or_open_from_name( + self, + ext: str, + # File name, file object or anything with a read() method that returns bytes. + # TODO: Create simple Protocol with a read() method + file_: Optional[Union[str, IO[bytes]]], + ) -> Union[None, io.BytesIO, IO[bytes]]: + # assert ext in {'shp', 'dbf', 'shx'} + self._assert_ext_is_supported(ext) + + if file_ is None: + return None + + if isinstance(file_, str): + baseName, __ = os.path.splitext(file_) + return self._load_constituent_file(baseName, ext) + + if hasattr(file_, "read"): + # Copy if required + try: + file_.seek(0) # type: ignore + return file_ + except (NameError, io.UnsupportedOperation): + return io.BytesIO(file_.read()) + + raise ShapefileException( + f"Could not load shapefile constituent file from: {file_}" + ) def __str__(self): """ @@ -1156,9 +1166,7 @@ def __str__(self): ) ) if self.dbf: - info.append( - " {} records ({} fields)".format(len(self), len(self.fields)) - ) + info.append(f" {len(self)} records ({len(self.fields)} fields)") return "\n".join(info) def __enter__(self): @@ -1224,8 +1232,7 @@ def __len__(self): def __iter__(self): """Iterates through the shapes/records in the shapefile.""" - for shaperec in self.iterShapeRecords(): - yield shaperec + yield from self.iterShapeRecords() @property def __geo_interface__(self): @@ -1250,8 +1257,11 @@ def load(self, shapefile=None): self.load_dbf(shapeName) if not (self.shp or self.dbf): raise ShapefileException( - "Unable to open %s.dbf or %s.shp." % (shapeName, shapeName) + f"Unable to open {shapeName}.dbf or {shapeName}.shp." 
) + self._try_to_set_constituent_file_headers() + + def _try_to_set_constituent_file_headers(self): if self.shp: self.__shpHeader() if self.dbf: @@ -1259,50 +1269,61 @@ self.__dbfHeader() if self.shx: self.__shxHeader() - def load_shp(self, shapefile_name): + def _try_get_open_constituent_file( + self, + shapefile_name: str, + ext: str, + ) -> Union[IO[bytes], None]: """ - Attempts to load file with .shp extension as both lower and upper case + Attempts to open a .shp, .dbf or .shx file, + with both lower case and upper case file extensions, + and return it. If it was not possible to open the file, None is returned. """ - shp_ext = "shp" + # typing.LiteralString is only available from Python 3.11 onwards. + # https://docs.python.org/3/library/typing.html#typing.LiteralString + # assert ext in {'shp', 'dbf', 'shx'} + self._assert_ext_is_supported(ext) + try: - self.shp = open("%s.%s" % (shapefile_name, shp_ext), "rb") - self._files_to_close.append(self.shp) - except IOError: + return open(f"{shapefile_name}.{ext}", "rb") + except OSError: try: - self.shp = open("%s.%s" % (shapefile_name, shp_ext.upper()), "rb") - self._files_to_close.append(self.shp) - except IOError: - pass + return open(f"{shapefile_name}.{ext.upper()}", "rb") + except OSError: + return None + + def _load_constituent_file( + self, + shapefile_name: str, + ext: str, + ) -> Union[IO[bytes], None]: + """ + Attempts to open a .shp, .dbf or .shx file, with the extension + as both lower and upper case, and if successful append it to + self._files_to_close. 
+ """ + shp_dbf_or_dhx_file = self._try_get_open_constituent_file(shapefile_name, ext) + if shp_dbf_or_dhx_file is not None: + self._files_to_close.append(shp_dbf_or_dhx_file) + return shp_dbf_or_dhx_file + + def load_shp(self, shapefile_name): + """ + Attempts to load file with .shp extension as both lower and upper case + """ + self.shp = self._load_constituent_file(shapefile_name, "shp") def load_shx(self, shapefile_name): """ Attempts to load file with .shx extension as both lower and upper case """ - shx_ext = "shx" - try: - self.shx = open("%s.%s" % (shapefile_name, shx_ext), "rb") - self._files_to_close.append(self.shx) - except IOError: - try: - self.shx = open("%s.%s" % (shapefile_name, shx_ext.upper()), "rb") - self._files_to_close.append(self.shx) - except IOError: - pass + self.shx = self._load_constituent_file(shapefile_name, "shx") def load_dbf(self, shapefile_name): """ Attempts to load file with .dbf extension as both lower and upper case """ - dbf_ext = "dbf" - try: - self.dbf = open("%s.%s" % (shapefile_name, dbf_ext), "rb") - self._files_to_close.append(self.dbf) - except IOError: - try: - self.dbf = open("%s.%s" % (shapefile_name, dbf_ext.upper()), "rb") - self._files_to_close.append(self.dbf) - except IOError: - pass + self.dbf = self._load_constituent_file(shapefile_name, "dbf") def __del__(self): self.close() @@ -1313,7 +1334,7 @@ def close(self): if hasattr(attribute, "close"): try: attribute.close() - except IOError: + except OSError: pass self._files_to_close = [] @@ -1330,14 +1351,14 @@ def __getFileObj(self, f): self.load() return f - def __restrictIndex(self, i): + def __restrictIndex(self, i: int) -> int: """Provides list-like handling of a record index with a clearer error message if the index is out of bounds.""" if self.numRecords: rmax = self.numRecords - 1 if abs(i) > rmax: raise IndexError( - "Shape or Record index: %s out of range. Max index: %s" % (i, rmax) + f"Shape or Record index: {i} out of range. 
Max index: {rmax}" ) if i < 0: i = range(self.numRecords)[i] @@ -1406,7 +1427,7 @@ def __shape(self, oid=None, bbox=None): # Read points - produces a list of [x,y] values if nPoints: flat = unpack("<%sd" % (2 * nPoints), f.read(16 * nPoints)) - record.points = list(izip(*(iter(flat),) * 2)) + record.points = list(zip(*(iter(flat),) * 2)) # Read z extremes and values if shapeType in (13, 15, 18, 31): (zmin, zmax) = unpack("<2d", f.read(16)) @@ -1561,7 +1582,7 @@ def iterShapes(self, bbox=None): if self.numShapes: # Iterate exactly the number of shapes from shx header - for i in xrange(self.numShapes): + for i in range(self.numShapes): # MAYBE: check if more left of file or exit early? shape = self.__shape(oid=i, bbox=bbox) if shape: @@ -1624,7 +1645,7 @@ def __dbfHeader(self): # store all field positions for easy lookups # note: fieldLookup gives the index position of a field inside Reader.fields - self.__fieldLookup = dict((f[0], i) for i, f in enumerate(self.fields)) + self.__fieldLookup = {f[0]: i for i, f in enumerate(self.fields)} # by default, read all fields except the deletion flag, hence "[1:]" # note: recLookup gives the index position of a field inside a _Record list @@ -1676,7 +1697,7 @@ def __recordFields(self, fields=None): # make sure the given fieldnames exist for name in fields: if name not in self.__fieldLookup or name == "DeletionFlag": - raise ValueError('"{}" is not a valid field name'.format(name)) + raise ValueError(f'"{name}" is not a valid field name') # fetch relevant field info tuples fieldTuples = [] for fieldinfo in self.fields[1:]: @@ -1684,7 +1705,7 @@ def __recordFields(self, fields=None): if name in fields: fieldTuples.append(fieldinfo) # store the field positions - recLookup = dict((f[0], i) for i, f in enumerate(fieldTuples)) + recLookup = {f[0]: i for i, f in enumerate(fieldTuples)} else: # use all the dbf fields fieldTuples = self.fields[1:] # sans deletion flag @@ -1850,7 +1871,7 @@ def iterRecords(self, fields=None, start=0, 
stop=None): recSize = self.__recordLength f.seek(self.__dbfHdrLength + (start * recSize)) fieldTuples, recLookup, recStruct = self.__recordFields(fields) - for i in xrange(start, stop): + for i in range(start, stop): r = self.__record( oid=i, fieldTuples=fieldTuples, recLookup=recLookup, recStruct=recStruct ) @@ -1891,7 +1912,7 @@ def iterShapeRecords(self, fields=None, bbox=None): """ if bbox is None: # iterate through all shapes and records - for shape, record in izip( + for shape, record in zip( self.iterShapes(), self.iterRecords(fields=fields) ): yield ShapeRecord(shape=shape, record=record) @@ -1908,10 +1929,22 @@ def iterShapeRecords(self, fields=None, bbox=None): yield ShapeRecord(shape=shape, record=record) -class Writer(object): +class Writer: """Provides write support for ESRI Shapefiles.""" - def __init__(self, target=None, shapeType=None, autoBalance=False, **kwargs): + def __init__( + self, + target=None, + shapeType=None, + autoBalance=False, + encoding="utf-8", + encodingErrors="strict", + *, + shp=None, + shx=None, + dbf=None, + **kwargs, + ): self.target = target self.autoBalance = autoBalance self.fields = [] @@ -1929,8 +1962,7 @@ def __init__(self, target=None, shapeType=None, autoBalance=False, **kwargs): self.shp = self.__getFileObj(os.path.splitext(target)[0] + ".shp") self.shx = self.__getFileObj(os.path.splitext(target)[0] + ".shx") self.dbf = self.__getFileObj(os.path.splitext(target)[0] + ".dbf") - elif kwargs.get("shp") or kwargs.get("shx") or kwargs.get("dbf"): - shp, shx, dbf = kwargs.get("shp"), kwargs.get("shx"), kwargs.get("dbf") + elif shp or shx or dbf: if shp: self.shp = self.__getFileObj(shp) if shx: @@ -1955,8 +1987,8 @@ def __init__(self, target=None, shapeType=None, autoBalance=False, **kwargs): # Use deletion flags in dbf? Default is false (0). Note: Currently has no effect, records should NOT contain deletion flags. 
self.deletionFlag = 0 # Encoding - self.encoding = kwargs.pop("encoding", "utf-8") - self.encodingErrors = kwargs.pop("encodingErrors", "strict") + self.encoding = encoding + self.encodingErrors = encodingErrors def __len__(self): """Returns the current number of features written to the shapefile. @@ -2015,7 +2047,7 @@ def close(self): ): try: attribute.flush() - except IOError: + except OSError: pass # Close any files that the writer opened (but not those given by user) @@ -2023,17 +2055,15 @@ def close(self): if hasattr(attribute, "close"): try: attribute.close() - except IOError: + except OSError: pass self._files_to_close = [] - def __getFileObj(self, f): + def __getFileObj(self, f: Union[IO[bytes], str]) -> IO[bytes]: """Safety handler to verify file-like objects""" if not f: raise ShapefileException("No file-like object available.") - elif hasattr(f, "write"): - return f - else: + if isinstance(f, str): pth = os.path.split(f)[0] if pth and not os.path.exists(pth): os.makedirs(pth) @@ -2041,6 +2071,10 @@ def __getFileObj(self, f): self._files_to_close.append(fp) return fp + if hasattr(f, "write"): + return f + raise Exception(f"Unsupported file-like: {f}") + def __shpFileLength(self): """Calculates the file length of the shp file.""" # Remember starting position @@ -2494,7 +2528,7 @@ def record(self, *recordList, **recordDict): if self.autoBalance and self.recNum > self.shpNum: self.balance() - fieldCount = sum((1 for field in self.fields if field[0] != "DeletionFlag")) + fieldCount = sum(1 for field in self.fields if field[0] != "DeletionFlag") if recordList: record = list(recordList) while len(record) < fieldCount: @@ -2781,11 +2815,7 @@ def field(self, name, fieldType="C", size="50", decimal=0): # Begin Testing -def _get_doctests(): - import doctest - - doctest.NORMALIZE_WHITESPACE = 1 - +def _get_doctests() -> doctest.DocTest: # run tests with open("README.md", "rb") as fobj: tests = doctest.DocTestParser().get_doctest( @@ -2799,7 +2829,11 @@ def 
_get_doctests(): return tests -def _filter_network_doctests(examples, include_network=False, include_non_network=True): +def _filter_network_doctests( + examples: Iterable[doctest.Example], + include_network: bool = False, + include_non_network: bool = True, +) -> Iterator[doctest.Example]: globals_from_network_doctests = set() if not (include_network or include_non_network): @@ -2840,16 +2874,16 @@ def _filter_network_doctests(examples, include_network=False, include_non_networ def _replace_remote_url( - old_url, + old_url: str, # Default port of Python http.server and Python 2's SimpleHttpServer - port=8000, - scheme="http", - netloc="localhost", - path=None, - params="", - query="", - fragment="", -): + port: int = 8000, + scheme: str = "http", + netloc: str = "localhost", + path: Optional[str] = None, + params: str = "", + query: str = "", + fragment: str = "", +) -> str: old_parsed = urlparse(old_url) # Strip subpaths, so an artefacts @@ -2869,19 +2903,16 @@ def _replace_remote_url( fragment=fragment, ) - new_url = urlunparse(new_parsed) if PYTHON3 else urlunparse(list(new_parsed)) + new_url = urlunparse(new_parsed) return new_url -def _test(args=sys.argv[1:], verbosity=0): +def _test(args: list[str] = sys.argv[1:], verbosity: bool = False) -> int: if verbosity == 0: print("Getting doctests...") - import doctest import re - doctest.NORMALIZE_WHITESPACE = 1 - tests = _get_doctests() if len(args) >= 2 and args[0] == "-m": @@ -2909,9 +2940,6 @@ def _test(args=sys.argv[1:], verbosity=0): class Py23DocChecker(doctest.OutputChecker): def check_output(self, want, got, optionflags): - if sys.version_info[0] == 2: - got = re.sub("u'(.*?)'", "'\\1'", got) - got = re.sub('u"(.*?)"', '"\\1"', got) res = doctest.OutputChecker.check_output(self, want, got, optionflags) return res diff --git a/test_shapefile.py b/test_shapefile.py index 1b7182f..5f9b855 100644 --- a/test_shapefile.py +++ b/test_shapefile.py @@ -5,12 +5,7 @@ import datetime import json import os.path - -try: 
- from pathlib import Path -except ImportError: - # pathlib2 is a dependency of pytest >= 3.7 - from pathlib2 import Path +from pathlib import Path # third party imports import pytest