diff --git a/run_benchmarks.py b/run_benchmarks.py index 961357f..c3a583e 100644 --- a/run_benchmarks.py +++ b/run_benchmarks.py @@ -7,6 +7,7 @@ import os import timeit from collections.abc import Callable +from os import PathLike from pathlib import Path from tempfile import TemporaryFile as TempF from typing import Iterable, Union, cast @@ -50,14 +51,14 @@ def benchmark( shapeRecords = collections.defaultdict(list) -def open_shapefile_with_PyShp(target: Union[str, os.PathLike]): +def open_shapefile_with_PyShp(target: Union[str, PathLike]): with shapefile.Reader(target) as r: fields[target] = r.fields for shapeRecord in r.iterShapeRecords(): shapeRecords[target].append(shapeRecord) -def write_shapefile_with_PyShp(target: Union[str, os.PathLike]): +def write_shapefile_with_PyShp(target: Union[str, PathLike]): with TempF("wb") as shp, TempF("wb") as dbf, TempF("wb") as shx: with shapefile.Writer(shp=shp, dbf=dbf, shx=shx) as w: # type: ignore [arg-type] for field_info_tuple in fields[target]: diff --git a/src/shapefile.py b/src/shapefile.py index 5ec1077..1257b2b 100644 --- a/src/shapefile.py +++ b/src/shapefile.py @@ -20,6 +20,7 @@ import time import zipfile from datetime import date +from os import PathLike from struct import Struct, calcsize, error, pack, unpack from types import TracebackType from typing import ( @@ -159,7 +160,7 @@ def read(self, size: int = -1) -> bytes: ... # File name, file object or anything with a read() method that returns bytes. -BinaryFileT = Union[str, IO[bytes]] +BinaryFileT = Union[str, PathLike[Any], IO[bytes]] BinaryFileStreamT = Union[IO[bytes], io.BytesIO, WriteSeekableBinStream] FieldTypeT = Literal["C", "D", "F", "L", "M", "N"] @@ -341,11 +342,11 @@ class GeoJSONFeatureCollectionWithBBox(GeoJSONFeatureCollection): @overload -def fsdecode_if_pathlike(path: os.PathLike[Any]) -> str: ... +def fsdecode_if_pathlike(path: PathLike[Any]) -> str: ... @overload def fsdecode_if_pathlike(path: T) -> T: ... def fsdecode_if_pathlike(path: Any) -> Any: - if isinstance(path, os.PathLike): + if isinstance(path, PathLike): return os.fsdecode(path) # str return path @@ -2243,7 +2244,7 @@ def _assert_ext_is_supported(self, ext: str) -> None: def __init__( self, - shapefile_path: Union[str, os.PathLike[Any]] = "", + shapefile_path: Union[str, PathLike[Any]] = "", /, *, encoding: str = "utf-8", @@ -2411,7 +2412,7 @@ def __init__( return if shp is not _NO_SHP_SENTINEL: - shp = cast(Union[str, IO[bytes], None], shp) + shp = cast(Union[str, PathLike[Any], IO[bytes], None], shp) self.shp = self.__seek_0_on_file_obj_wrap_or_open_from_name("shp", shp) self.shx = self.__seek_0_on_file_obj_wrap_or_open_from_name("shx", shx) @@ -2432,7 +2433,7 @@ def __seek_0_on_file_obj_wrap_or_open_from_name( if file_ is None: return None - if isinstance(file_, str): + if isinstance(file_, (str, PathLike)): baseName, __ = os.path.splitext(file_) return self._load_constituent_file(baseName, ext) @@ -3235,7 +3236,7 @@ class Writer: def __init__( self, - target: Union[str, os.PathLike[Any], None] = None, + target: Union[str, PathLike[Any], None] = None, shapeType: Optional[int] = None, autoBalance: bool = False, *, diff --git a/test_shapefile.py b/test_shapefile.py index 2a10d3e..152be86 100644 --- a/test_shapefile.py +++ b/test_shapefile.py @@ -13,6 +13,8 @@ # our imports import shapefile +shapefiles_dir = Path(__file__).parent / "shapefiles" + # define various test shape tuples of (type, points, parts indexes, and expected geo interface output) geo_interface_tests = [ ( @@ -719,8 +721,7 @@ def test_reader_pathlike(): """ Assert that path-like objects can be read. """ - base = Path("shapefiles") - with shapefile.Reader(base / "blockgroups") as sf: + with shapefile.Reader(shapefiles_dir / "blockgroups") as sf: assert len(sf) == 663 @@ -736,6 +737,18 @@ def test_reader_dbf_only(): assert record[1:3] == ["060750601001", 4715] +def test_reader_dbf_only_from_Path(): + """ + Assert that specifying just the + dbf argument to the shapefile reader as a Path + reads just the dbf file. + """ + with shapefile.Reader(dbf=shapefiles_dir / "blockgroups.dbf") as sf: + assert len(sf) == 663 + record = sf.record(3) + assert record[1:3] == ["060750601001", 4715] + + def test_reader_shp_shx_only(): """ Assert that specifying just the @@ -750,6 +763,20 @@ def test_reader_shp_shx_only(): assert len(shape.points) == 173 +def test_reader_shp_shx_only_from_Paths(): + """ + Assert that specifying just the + shp and shx argument to the shapefile reader as Paths + reads just the shp and shx file. + """ + with shapefile.Reader( + shp=shapefiles_dir / "blockgroups.shp", shx=shapefiles_dir / "blockgroups.shx" + ) as sf: + assert len(sf) == 663 + shape = sf.shape(3) + assert len(shape.points) == 173 + + def test_reader_shp_dbf_only(): """ Assert that specifying just the @@ -766,6 +793,22 @@ def test_reader_shp_dbf_only(): assert record[1:3] == ["060750601001", 4715] +def test_reader_shp_dbf_only_from_Paths(): + """ + Assert that specifying just the + shp and shx argument to the shapefile reader as Paths + reads just the shp and dbf file. + """ + with shapefile.Reader( + shp=shapefiles_dir / "blockgroups.shp", dbf=shapefiles_dir / "blockgroups.dbf" + ) as sf: + assert len(sf) == 663 + shape = sf.shape(3) + assert len(shape.points) == 173 + record = sf.record(3) + assert record[1:3] == ["060750601001", 4715] + + def test_reader_shp_only(): """ Assert that specifying just the @@ -778,6 +821,18 @@ def test_reader_shp_only(): assert len(shape.points) == 173 +def test_reader_shp_only_from_Path(): + """ + Assert that specifying just the + shp argument to the shapefile reader as a Path + reads just the shp file (shx optional). + """ + with shapefile.Reader(shp=shapefiles_dir / "blockgroups.shp") as sf: + assert len(sf) == 663 + shape = sf.shape(3) + assert len(shape.points) == 173 + + def test_reader_filelike_dbf_only(): """ Assert that specifying just the