Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions kloppy/_providers/sportscode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from kloppy.infra.serializers.code.sportscode import (
SportsCodeDeserializer,
SportsCodeInputs,
SportsCodeOutputs,
SportsCodeSerializer,
)
from kloppy.io import FileLike, open_as_file
Expand Down Expand Up @@ -31,6 +32,6 @@ def save(dataset: CodeDataset, output_filename: str) -> None:
dataset: The SportsCode dataset to save.
output_filename: The output filename.
"""
with open(output_filename, "wb") as fp:
with open_as_file(output_filename, "wb") as data_fp:
serializer = SportsCodeSerializer()
fp.write(serializer.serialize(dataset))
serializer.serialize(dataset, outputs=SportsCodeOutputs(data=data_fp))
48 changes: 41 additions & 7 deletions kloppy/infra/io/adapters/adapter.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,61 @@
from abc import ABC, abstractmethod
from typing import BinaryIO

from kloppy.infra.io.buffered_stream import BufferedStream


class Adapter(ABC):
    """Abstract base class for I/O adapters that read from / write to URLs.

    Concrete adapters implement the read-side methods; ``write_from_stream``
    has a default implementation that raises, so read-only adapters need not
    override it.
    """

    @abstractmethod
    def supports(self, url: str) -> bool:
        """Returns True if this adapter supports the given URL, False otherwise."""

    @abstractmethod
    def is_directory(self, url: str) -> bool:
        """Returns True if the given URL points to a directory, False otherwise."""

    @abstractmethod
    def is_file(self, url: str) -> bool:
        """Returns True if the given URL points to a file, False otherwise."""

    @abstractmethod
    def read_to_stream(self, url: str, output: BufferedStream):
        """Read content from the given URL into the BufferedStream.

        Args:
            url: The source URL
            output: BufferedStream to write to
        """

    def write_from_stream(self, url: str, input: BufferedStream, mode: str):  # noqa: A002
        """Write content from BufferedStream to the given URL.

        Args:
            url: The destination URL
            input: BufferedStream to read from
            mode: Write mode ('wb' for write/overwrite or 'ab' for append)

        Raises:
            NotImplementedError: If write operations are not supported by this adapter
        """
        # Deliberately not abstract: read-only adapters inherit this failure.
        raise NotImplementedError(
            f"Write operations not supported for {url}. "
            f"Adapter {self.__class__.__name__} does not implement write_from_stream."
        )

    @abstractmethod
    def list_directory(self, url: str, recursive: bool = True) -> list[str]:
        """Lists the contents of a directory.

        Args:
            url: The directory URL
            recursive: Whether to list contents recursively

        Returns:
            A list of files in the directory

        Example:
            >>> adapter.list_directory("s3://my-bucket/data/", recursive=False)
            ['s3://my-bucket/data/file1.csv', 's3://my-bucket/data/file2.csv']
        """


__all__ = ["Adapter"]
43 changes: 35 additions & 8 deletions kloppy/infra/io/adapters/fsspec.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from abc import ABC, abstractmethod
import os
import re
from typing import BinaryIO, Optional
from typing import Optional

import fsspec

from kloppy.config import get_config
from kloppy.exceptions import InputNotFoundError
from kloppy.infra.io.buffered_stream import BufferedStream

from .adapter import Adapter

Expand All @@ -28,7 +30,6 @@ def _get_filesystem(
Get the appropriate fsspec filesystem for the given URL, with caching enabled.
"""
protocol = self._infer_protocol(url)

if no_cache:
return fsspec.filesystem(protocol)

Expand All @@ -38,6 +39,16 @@ def _get_filesystem(
cache_storage=get_config("cache"),
)

def _get_filesystem_for_reading(self, url: str) -> fsspec.AbstractFileSystem:
    """Return the filesystem used for reads: caching stays enabled so
    repeated reads of remote files hit the local cache."""
    filesystem = self._get_filesystem(url, no_cache=False)
    return filesystem

def _get_filesystem_for_writing(self, url: str) -> fsspec.AbstractFileSystem:
    """Return the filesystem used for writes: caching is bypassed so
    written data goes straight to the destination."""
    filesystem = self._get_filesystem(url, no_cache=True)
    return filesystem

def _detect_compression(self, url: str) -> Optional[str]:
"""
Detect the compression type based on the file extension.
Expand All @@ -60,20 +71,36 @@ def supports(self, url: str) -> bool:
Check if the adapter can handle the URL.
"""

def read_to_stream(self, url: str, output: BufferedStream):
    """
    Reads content from the given URL and writes it to the provided BufferedStream.
    Uses caching for remote files. Copies data in chunks.

    Args:
        url: The source URL
        output: BufferedStream to write to

    Raises:
        InputNotFoundError: If no file exists at the given URL
    """
    fs = self._get_filesystem_for_reading(url)
    # Transparent decompression based on the file extension (e.g. .gz, .xz).
    compression = self._detect_compression(url)

    try:
        with fs.open(url, "rb", compression=compression) as source_file:
            # Chunked copy; read_from rewinds the buffer afterwards.
            output.read_from(source_file)
    except FileNotFoundError as e:
        raise InputNotFoundError(f"Input file not found: {url}") from e

def write_from_stream(self, url: str, input: BufferedStream, mode: str):  # noqa: A002
    """
    Write the full contents of a BufferedStream out to the given URL.

    Caching is deliberately bypassed for writes; data is copied in chunks.

    Args:
        url: The destination URL
        input: BufferedStream to read from
        mode: Write mode ('wb' for write/overwrite or 'ab' for append)
    """
    target_fs = self._get_filesystem_for_writing(url)
    target_compression = self._detect_compression(url)

    with target_fs.open(url, mode, compression=target_compression) as sink:
        input.write_to(sink)

def list_directory(self, url: str, recursive: bool = True) -> list[str]:
"""
Lists the contents of a directory.
Expand All @@ -87,7 +114,7 @@ def list_directory(self, url: str, recursive: bool = True) -> list[str]:
return [
f"{protocol}://{fp}"
if protocol != "file" and not fp.startswith(protocol)
else fp
else os.fspath(fp)
for fp in files
]

Expand Down
25 changes: 19 additions & 6 deletions kloppy/infra/io/adapters/zip.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,32 @@ class ZipAdapter(FSSpecAdapter):
def supports(self, url: str) -> bool:
    """A URL is handled by this adapter iff it uses the ``zip://`` scheme."""
    zip_scheme = "zip://"
    return url.startswith(zip_scheme)

def _get_filesystem_for_reading(self, url: str) -> fsspec.AbstractFileSystem:
    """Open the configured zip archive read-only."""
    return fsspec.filesystem(protocol="zip", fo=self._require_fo(), mode="r")

def _get_filesystem_for_writing(self, url: str) -> fsspec.AbstractFileSystem:
    """Open the configured zip archive in append mode for writes."""
    return fsspec.filesystem(protocol="zip", fo=self._require_fo(), mode="a")

def _get_filesystem(
    self, url: str, no_cache: bool = False
) -> fsspec.AbstractFileSystem:
    # Kept for the base-class contract; the zip adapter never caches,
    # so this is simply the read-mode filesystem.
    return self._get_filesystem_for_reading(url)

@staticmethod
def _require_fo():
    """Fetch the zip archive from config, raising AdapterError if absent."""
    fo = get_config("adapters.zip.fo")
    if fo is None:
        raise AdapterError(
            "No zip archive provided for the zip adapter."
            " Please provide one using the 'adapters.zip.fo' config."
        )
    return fo

def list_directory(self, url: str, recursive: bool = True) -> list[str]:
"""
Expand Down
76 changes: 76 additions & 0 deletions kloppy/infra/io/buffered_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Buffered stream utilities for efficient I/O operations."""

import shutil
import tempfile
from typing import BinaryIO, Protocol

DEFAULT_BUFFER_SIZE = 5 * 1024 * 1024  # 5MB before spilling to disk


class SupportsWrite(Protocol):
    """Protocol for objects that support write operations."""

    def write(self, data: bytes) -> int: ...


class SupportsRead(Protocol):
    """Protocol for objects that support read operations."""

    def read(self, n: int) -> bytes: ...


class BufferedStream(tempfile.SpooledTemporaryFile):
    """A spooled temporary file that can efficiently copy from streams in chunks.

    Data is held in memory up to ``max_size`` bytes and transparently
    spilled to a temporary file on disk beyond that.
    """

    def __init__(self, max_size: int = DEFAULT_BUFFER_SIZE, mode: str = "w+b"):
        super().__init__(max_size=max_size, mode=mode)

    def write(self, data: bytes) -> int:  # make it clearly bytes-only
        return super().write(data)

    def read(self, n: int = -1) -> bytes:  # make it clearly bytes-only
        return super().read(n)

    @classmethod
    def from_stream(
        cls,
        source: BinaryIO,
        max_size: int = DEFAULT_BUFFER_SIZE,
        chunk_size: int = 0,
    ) -> "BufferedStream":
        """
        Create a BufferedStream by copying data from source stream in chunks.

        Args:
            source: The source binary stream to read from
            max_size: Maximum size to keep in memory before spilling to disk
            chunk_size: Size of chunks to copy at a time (0 uses the
                shutil default)

        Returns:
            A BufferedStream containing the copied data, rewound to offset 0
        """
        buffer = cls(max_size=max_size)
        buffer.read_from(source, chunk_size)
        return buffer

    def read_from(self, source: SupportsRead, chunk_size: int = 0):
        """
        Read data from source into this BufferedStream in chunks.

        After copying, the stream is rewound to offset 0 so it is ready
        to be read back.

        Args:
            source: The source that supports read() method
            chunk_size: Size of chunks to copy at a time (0 uses the
                shutil default)
        """
        shutil.copyfileobj(source, self, chunk_size)
        self.seek(0)

    def write_to(self, output: SupportsWrite, chunk_size: int = 0) -> None:
        """
        Write all contents of this BufferedStream to the output in chunks.

        The stream is rewound to offset 0 before copying, so the full
        contents are written regardless of the current position.

        Args:
            output: The destination that supports write() method
            chunk_size: Size of chunks to copy at a time (0 uses the
                shutil default)
        """
        self.seek(0)
        shutil.copyfileobj(self, output, chunk_size)
11 changes: 6 additions & 5 deletions kloppy/infra/serializers/code/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@

from kloppy.domain import CodeDataset

T = TypeVar("T")
T_I = TypeVar("T_I")
T_O = TypeVar("T_O")


class CodeDataDeserializer(ABC, Generic[T]):
class CodeDataDeserializer(ABC, Generic[T_I]):
@abstractmethod
def deserialize(self, inputs: T) -> CodeDataset:
def deserialize(self, inputs: T_I) -> CodeDataset:
raise NotImplementedError


class CodeDataSerializer(ABC):
class CodeDataSerializer(ABC, Generic[T_O]):
@abstractmethod
def serialize(self, dataset: CodeDataset) -> bytes:
def serialize(self, dataset: CodeDataset, outputs: T_O) -> bool:
raise NotImplementedError
24 changes: 16 additions & 8 deletions kloppy/infra/serializers/code/sportscode.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class SportsCodeInputs(NamedTuple):
data: IO[bytes]


class SportsCodeOutputs(NamedTuple):
data: IO[bytes]


class SportsCodeDeserializer(CodeDataDeserializer[SportsCodeInputs]):
def deserialize(self, inputs: SportsCodeInputs) -> CodeDataset:
all_instances = objectify.fromstring(inputs.data.read())
Expand Down Expand Up @@ -88,8 +92,10 @@ def deserialize(self, inputs: SportsCodeInputs) -> CodeDataset:
)


class SportsCodeSerializer(CodeDataSerializer):
def serialize(self, dataset: CodeDataset) -> bytes:
class SportsCodeSerializer(CodeDataSerializer[SportsCodeOutputs]):
def serialize(
self, dataset: CodeDataset, outputs: SportsCodeOutputs
) -> bool:
root = etree.Element("file")
all_instances = etree.SubElement(root, "ALL_INSTANCES")
for i, code in enumerate(dataset.codes):
Expand Down Expand Up @@ -138,10 +144,12 @@ def serialize(self, dataset: CodeDataset) -> bytes:
text_ = etree.SubElement(label, "text")
text_.text = str(text)

return etree.tostring(
root,
pretty_print=True,
xml_declaration=True,
encoding="utf-8", # This might not work with some tools because they expected 'ascii'.
method="xml",
outputs.data.write(
etree.tostring(
root,
pretty_print=True,
xml_declaration=True,
encoding="utf-8", # This might not work with some tools because they expected 'ascii'.
method="xml",
)
)
Loading
Loading