From 28365d6cbb5b5a5e4d45c293b3a69aaa0b6b48f1 Mon Sep 17 00:00:00 2001 From: Ben Schmidt Date: Thu, 27 Mar 2025 14:14:49 -0400 Subject: [PATCH] fix quadtree schema bug --- pyproject.toml | 2 +- quadfeather/demo.py | 17 +++++++++-------- quadfeather/tiler.py | 28 ++++++++++++++++++++++++---- tests/test_tiler.py | 28 +++++++++++++++++++++++++--- 4 files changed, 59 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2d24a92..900891d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "quadfeather" -version = "2.1.0" +version = "2.2.0" description = "Quadtree tiling from CSV/Apache Arrow for use with deepscatter in the browser." readme = "README.md" requires-python = ">=3.9" diff --git a/quadfeather/demo.py b/quadfeather/demo.py index db47bb9..b8fd749 100644 --- a/quadfeather/demo.py +++ b/quadfeather/demo.py @@ -31,18 +31,18 @@ def rbatch( if method == "lognormal": f = random.lognormal classes = [ - ("Banana", [5, 0.01], [5, 0.01]), - ("Strawberry", [3.5, 0.4], [3, 0.3]), - ("Apple", [4.6, 0.2], [3, 0.2]), - ("Mulberry", [5.6, 0.6], [6, 0.5]), + ("banana", [5, 0.01], [5, 0.01]), + ("strawberry", [3.5, 0.4], [3, 0.3]), + ("apple", [4.6, 0.2], [3, 0.2]), + ("mulberry", [5.6, 0.6], [6, 0.5]), ] else: f = random.normal classes = [ - ("Banana", [0, 0.2], [0, 0.3]), - ("Strawberry", [3, 0.05], [-3, 2]), - ("Apple", [-4.6, 0.1], [-5, 0.25]), - ("Mulberry", [5.6, 2.6], [6, 2.5]), + ("banana", [0, 0.2], [0, 0.3]), + ("strawberry", [3, 0.05], [-3, 2]), + ("apple", [-4.6, 0.1], [-5, 0.25]), + ("mulberry", [5.6, 2.6], [6, 2.5]), ] for c, xparam, yparam in classes: @@ -69,6 +69,7 @@ def rbatch( "y": y, "position": position, "class": [c] * len(x), + "cat": [c] * len(x), "quantity": random.random(len(x)), "date": date, } diff --git a/quadfeather/tiler.py b/quadfeather/tiler.py index 873d14c..8dd87f5 100644 --- a/quadfeather/tiler.py +++ b/quadfeather/tiler.py @@ -529,9 +529,17 @@ def schemas(self) -> Dict[str, pa.Schema]: if field.name in self.schema.names: # Check if the user requested a cast dtype = self.schema.field(field.name).type - if field in self.dictionaries: + if field.name in self.dictionaries: # dictionaries are written later. - dtype = pa.dictionary(pa.int16(), pa.utf8()) + itype = pa.int8() + if len(self.dictionaries[field.name]) >= 2**7: + itype = pa.int16() + if len(self.dictionaries[field.name]) >= 2**15: + itype = pa.int32() + if len(self.dictionaries[field.name]) >= 2**31: + itype = pa.int64() + + dtype = pa.dictionary(itype, pa.utf8()) fields[car].append(pa.field(field.name, dtype)) fields[self.sidecars.get("ix", "")].append(pa.field("ix", pa.uint64())) @@ -1529,7 +1537,11 @@ def insert_table(self, table: pa.Table): for sidecar, tb in self.keyed_batches(table).items(): for tb in rebatch(tb.to_batches(), 50e6): for batch in tb.to_batches(): - self.overflow_buffers[sidecar].write_batch(batch) + try: + self.overflow_buffers[sidecar].write_batch(batch) + except Exception as e: + logger.error(f"Error writing batch to {sidecar}: {e}") + raise e def keyed_batches(self, table: pa.Table): """ @@ -1569,7 +1581,15 @@ def remap_dictionary(chunk, new_order): # Switch a dictionary to use a pre-assigned set of keys. returns a new chunked dictionary array. new_indices = pc.index_in(chunk, new_order) - return pa.DictionaryArray.from_arrays(new_indices, new_order) + itype = pa.int8() + if len(new_order) >= 2**7: + itype = pa.int16() + if len(new_order) >= 2**15: + itype = pa.int32() + if len(new_order) >= 2**31: + itype = pa.int64() + indices = new_indices.cast(itype) + return pa.DictionaryArray.from_arrays(indices, new_order) def cli(): diff --git a/tests/test_tiler.py b/tests/test_tiler.py index a169a1c..4e72bf1 100644 --- a/tests/test_tiler.py +++ b/tests/test_tiler.py @@ -224,10 +224,28 @@ def test_big_parquet(self, tmp_path): assert ps[1] == 1 assert ps[-1] == 999 +class TestManyLittleFiles: + def test_small_chunks(self, tmp_path): + size = 100_000 + demo_parquet(tmp_path / "test.parquet", size=size) + qtree = main( + files=[tmp_path / "test.parquet"], + destination=tmp_path / "tiles", + tile_size=50, + first_tile_size=5, + sidecars={"class": "class_sidecar"}, + ) + manifest = qtree.manifest_table + assert pc.sum(manifest["nPoints"]).as_py() == size + tb = feather.read_table(tmp_path / "tiles" / "0/0/0.feather") + ps = tb["ix"].to_pylist() + assert ps[0] == 0 + assert ps[1] == 1 + assert ps[-1] == 4 class TestStreaming: def test_streaming_batches(self, tmp_path): - size = 5_000_000 + size = 2_000_000 demo_parquet( tmp_path / "t.parquet", size=size, extent=Rectangle(x=(0, 100), y=(0, 100)) ) @@ -238,9 +256,12 @@ def test_streaming_batches(self, tmp_path): basedir=tmp_path / "tiles", tile_size=9_000, dictionaries={ + # Force a dictionary to be written to a sidecar. "cat": pa.array(["apple", "banana", "strawberry", "mulberry"]), - "sidecars": {"cat": "cat"}, + # Same data, different dictionary. This one will be written to the main table. + "class": pa.array(["banana", "strawberry", "mulberry", "apple", "orange", "pineapple"]), }, + sidecars= {"cat": "cat"}, first_tile_size=1000, ) @@ -254,7 +275,8 @@ def test_streaming_batches(self, tmp_path): assert pc.sum(manifest["nPoints"]).as_py() == size tb = feather.read_table(tmp_path / "tiles" / "0/0/0.feather") - + assert "banana" in tb['class'].to_pylist() + assert "orange" not in tb['class'].to_pylist() class TestAppends: """