Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 148 additions & 4 deletions lance_ray/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,155 @@ def _read_fragments(
) -> Iterator[pa.Table]:
"""Read Lance fragments in batches.

This enhanced reader detects Lance blob-encoded columns and reconstructs
raw bytes using the :meth:`LanceDataset.take_blobs` API, returning
:class:`pyarrow.LargeBinaryArray` columns instead of the default
struct-based descriptors.

Row ordering is preserved by using per-batch row IDs.

NOTE: Use fragment ids, instead of fragments as parameter, because pickling
LanceFragment is expensive.
:class:`lance.LanceFragment` is expensive.
"""
# Resolve fragments
fragments = [lance_ds.get_fragment(id) for id in fragment_ids]
scanner_options["fragments"] = fragments
scanner = lance_ds.scanner(**scanner_options)

# Copy scanner options so we can safely mutate
scan_opts: dict[str, Any] = dict(scanner_options)
scan_opts["fragments"] = fragments

# Detect blob columns from the dataset schema and requested projection
ds_schema: pa.Schema = lance_ds.schema
requested_columns = scan_opts.get("columns")
# Map column name -> blob kind ("legacy" or "v2")
blob_columns: dict[str, str] = {}

def _is_blob_field(f: pa.Field) -> Optional[str]:
    """Classify a schema field as a Lance blob column.

    Args:
        f: Arrow field taken from the Lance dataset schema.

    Returns:
        ``"v2"`` for blob v2 extension columns (Arrow extension name
        ``lance.blob.v2``), ``"legacy"`` for LargeBinary columns tagged
        with field metadata ``{"lance-encoding:blob": "true"}``, or
        ``None`` when the field is not blob-encoded.
    """
    field_type = f.type

    # Blob v2: identified by the Arrow extension name `lance.blob.v2`.
    if isinstance(field_type, pa.ExtensionType):
        if getattr(field_type, "extension_name", None) == "lance.blob.v2":
            return "v2"

    # Legacy blobs are plain LargeBinary fields. Use the idiomatic
    # pa.types predicate instead of an equality compare wrapped in a
    # broad try/except — the predicate cannot raise.
    if not pa.types.is_large_binary(field_type):
        return None

    meta = f.metadata
    if meta is None:
        return None

    # pyarrow normally surfaces field metadata with bytes keys/values,
    # but accept str keys too in case the schema was built by hand.
    if (meta.get("lance-encoding:blob") == "true") or (
        meta.get(b"lance-encoding:blob") == b"true"
    ):
        return "legacy"

    return None

# Build list of blob columns to reconstruct, honoring column projection
ds_field_names = ds_schema.names
for idx, name in enumerate(ds_field_names):
field = ds_schema.field(idx)
kind = _is_blob_field(field)
if kind is None:
continue
if requested_columns is None:
blob_columns[name] = kind
elif isinstance(requested_columns, list):
if name in requested_columns:
blob_columns[name] = kind
elif isinstance(requested_columns, dict) and name in requested_columns:
# If columns are SQL expressions, only reconstruct if explicitly requested
blob_columns[name] = kind

# If blob columns are present, ensure row IDs are included for reconstruction
if blob_columns:
scan_opts["with_row_id"] = True

scanner = lance_ds.scanner(**scan_opts)

for batch in scanner.to_reader():
yield pa.Table.from_batches([batch])
# Fast path: no blob columns requested in this scan
if not blob_columns:
yield pa.Table.from_batches([batch])
continue

# Build a table so we can manipulate columns easily
table = pa.Table.from_batches([batch])

# Extract row IDs used to reconstruct bytes in the same order
if "_rowid" not in table.column_names:
# Safety: if row ids are missing for any reason, fall back to original
yield table
continue
row_ids = table.column("_rowid").to_pylist()

# For each blob column, reconstruct a LargeBinary array
for col, kind in blob_columns.items():
if col not in table.column_names:
# Column not projected in this batch
continue

# The scanned representation may be a struct descriptor or extension-backed
# array. We only rely on the null bitmap: None -> null bytes; non-null ->
# fetch bytes via LanceDataset.take_blobs.
desc_py = table.column(col).to_pylist()

# Fetch BlobFile handles in batch order
blob_files = lance_ds.take_blobs(col, ids=row_ids)

# Convert BlobFile -> bytes, respecting nulls
values: list[Optional[bytes]] = []
for i, desc in enumerate(desc_py):
if desc is None:
values.append(None)
continue
# Backward compatibility: older legacy blob layouts may encode
# nulls as a sentinel struct {position: 1, size: 0} instead of
# using the Arrow null bitmap. Treat this sentinel as null for
# metadata-based blob columns only.
if kind == "legacy" and isinstance(desc, dict):
pos = desc.get("position")
size = desc.get("size")
if pos == 1 and size == 0:
values.append(None)
continue
with blob_files[i] as bf:
values.append(bf.read())

# Construct LargeBinary array for Ray, preserving legacy metadata only
# for metadata-based blob columns. Blob v2 extension columns are exposed
# as plain LargeBinary bytes.
new_arr = pa.array(values, type=pa.large_binary())
ds_field_index = ds_schema.get_field_index(col)
ds_field = ds_schema.field(ds_field_index)
nullable = ds_field.nullable
metadata = ds_field.metadata if kind == "legacy" else None
new_field = pa.field(
col, pa.large_binary(), nullable=nullable, metadata=metadata
)
table = table.set_column(
table.schema.get_field_index(col),
new_field,
pa.chunked_array([new_arr]),
)

# Drop helper row ID column before returning
if "_rowid" in table.column_names:
table = table.drop_columns(["_rowid"])

yield table
16 changes: 14 additions & 2 deletions lance_ray/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,20 @@ def __call__(self, batch: Union[pa.Table, "pd.DataFrame", dict]) -> pa.Table:
# Convert dict/numpy arrays to pyarrow table if needed
if isinstance(batch, dict):
batch = pa.Table.from_pydict(batch)
elif hasattr(batch, "__dataframe__"): # pandas DataFrame
batch = pa.Table.from_pandas(batch)
else:
# Only convert when the input is an actual pandas DataFrame.
# Some objects (including pyarrow.Table) may implement the
# dataframe interchange protocol `__dataframe__`, but they are
# not pandas DataFrames. Using `hasattr(..., "__dataframe__")`
# incorrectly routes them through `Table.from_pandas` and causes
# errors. Perform a strict isinstance check instead.
try:
from pandas import DataFrame as _PandasDataFrame # type: ignore
except Exception:
_PandasDataFrame = None # type: ignore

if _PandasDataFrame is not None and isinstance(batch, _PandasDataFrame):
batch = pa.Table.from_pandas(batch)

transformed = self.transform(batch)
if not isinstance(transformed, Generator):
Expand Down
180 changes: 167 additions & 13 deletions lance_ray/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ def write_lance(
namespace_properties: Optional[dict[str, str]] = None,
ray_remote_args: Optional[dict[str, Any]] = None,
concurrency: Optional[int] = None,
# Streaming parameters (only effective when stream=True)
stream: bool = False,
batch_size: Optional[int] = None,
resume_rows: int = 0,
) -> None:
"""Write the dataset to a Lance dataset.

Expand Down Expand Up @@ -188,27 +192,177 @@ def write_lance(
Used together with namespace_properties and table_id.
namespace_properties: Properties for connecting to the namespace.
Used together with namespace_impl and table_id.
stream: Enable incremental batch streaming write. Default False.
batch_size: Batch size when streaming. If None, defaults to 1024.
resume_rows: Number of leading rows to skip when streaming (for resume).
"""
_validate_write_args(uri, namespace_impl, table_id, mode)

datasink = LanceDatasink(
uri,
table_id=table_id,
schema=schema,
mode=mode,
min_rows_per_file=min_rows_per_file,
# Fast path: non-streaming write using the Datasink API.
if not stream:
datasink = LanceDatasink(
uri,
table_id=table_id,
schema=schema,
mode=mode,
min_rows_per_file=min_rows_per_file,
max_rows_per_file=max_rows_per_file,
data_storage_version=data_storage_version,
storage_options=storage_options,
namespace_impl=namespace_impl,
namespace_properties=namespace_properties,
)

ds.write_datasink(
datasink,
ray_remote_args=ray_remote_args or {},
concurrency=concurrency,
)
return

# Streaming path: commit one fragment per batch to minimize memory usage.
import lance

if (namespace_impl is not None or namespace_properties is not None) and table_id:
raise ValueError(
"Streaming write with 'namespace_impl' + 'table_id' is not supported; "
"use non-stream mode or provide a direct 'uri'.",
)

if uri is None:
raise ValueError(
"Streaming write requires 'uri' to be provided when no namespace is used.",
)

dest_uri: str = uri
dest_exists = False
dest_version: Optional[int] = None

try:
_dest = lance.LanceDataset(dest_uri, storage_options=storage_options)
dest_exists = True
dest_version = _dest.version
except Exception:
dest_exists = False
dest_version = None

# Enforce mode semantics.
if mode == "create" and dest_exists:
raise ValueError("Destination exists but mode='create' was specified.")
if mode == "append" and not dest_exists:
raise ValueError("Destination does not exist but mode='append' was specified.")

from .fragment import LanceFragmentWriter

effective_batch_size = batch_size if batch_size is not None else 1024

writer = LanceFragmentWriter(
uri=dest_uri,
schema=schema, # if None, writer infers from first batch (preserves Arrow metadata)
max_rows_per_file=max_rows_per_file,
max_rows_per_group=min_rows_per_file, # keep naming aligned with v1 semantics
data_storage_version=data_storage_version,
storage_options=storage_options,
namespace_impl=namespace_impl,
namespace_properties=namespace_properties,
namespace_impl=None,
namespace_properties=None,
table_id=None,
)

ds.write_datasink(
datasink,
ray_remote_args=ray_remote_args or {},
concurrency=concurrency,
)
rows_seen = 0
first_commit_done = False

for batch in ds.iter_batches(
batch_size=effective_batch_size, batch_format="pyarrow"
):
# Convert to pyarrow.Table if needed.
tbl = batch if isinstance(batch, pa.Table) else pa.Table.from_pydict(batch)

# Apply resume_rows skipping across batches.
if resume_rows > rows_seen:
to_skip = min(resume_rows - rows_seen, tbl.num_rows)
rows_seen += to_skip
if to_skip >= tbl.num_rows:
# Whole batch skipped.
continue
tbl = tbl.slice(to_skip)

# Skip empty batches (possible after slicing).
if tbl.num_rows == 0:
continue

# Write this batch as one fragment and collect metadata.
frag_tbl = writer(tbl)
fragments: list[Any] = []
schema_obj: Optional[pa.Schema] = None
frag_col = frag_tbl.column("fragment").to_pylist()
sch_col = frag_tbl.column("schema").to_pylist()
for frag_bytes, schema_bytes in zip(frag_col, sch_col, strict=False):
fragment = pickle.loads(frag_bytes)
fragments.append(fragment)
schema_obj = pickle.loads(schema_bytes)

# Commit after each batch.
if not first_commit_done:
# First commit: respect mode.
if mode in ("create", "overwrite") or not dest_exists:
op = LanceOperation.Overwrite(schema_obj, fragments)
LanceDataset.commit(
dest_uri,
op,
read_version=None,
storage_options=storage_options,
)
first_commit_done = True
dest_exists = True
try:
_dest = lance.LanceDataset(
dest_uri, storage_options=storage_options
)
dest_version = _dest.version
except Exception:
dest_version = None
elif mode == "append":
op = LanceOperation.Append(fragments)
LanceDataset.commit(
dest_uri,
op,
read_version=dest_version,
storage_options=storage_options,
)
first_commit_done = True
try:
_dest = lance.LanceDataset(
dest_uri, storage_options=storage_options
)
dest_version = _dest.version
except Exception:
pass
else:
# Fallback: overwrite.
op = LanceOperation.Overwrite(schema_obj, fragments)
LanceDataset.commit(
dest_uri,
op,
read_version=None,
storage_options=storage_options,
)
first_commit_done = True
else:
# Subsequent commits always append.
op = LanceOperation.Append(fragments)
LanceDataset.commit(
dest_uri,
op,
read_version=dest_version,
storage_options=storage_options,
)
try:
_dest = lance.LanceDataset(dest_uri, storage_options=storage_options)
dest_version = _dest.version
except Exception:
pass

rows_seen += tbl.num_rows


def _handle_fragment(
Expand Down
Loading