lib/callbacks.py (6 changes: 3 additions & 3 deletions)
@@ -24,8 +24,8 @@ def raise_http_exception(request: Request) -> Callable[[Exception | str], Awaita
"""Callback to raise an HTTPException with a specific status code."""

async def _raise_http_exception(error: Exception | str) -> None:
message = str(error) if isinstance(error, Exception) else error
code = error.status_code if isinstance(error, HTTPException) else 400
raise StreamTerminated(f"{code}: {message}") from error
message = f"{type(error).__name__}: {error}" if isinstance(error, Exception) else str(error)
code = error.status_code if isinstance(error, HTTPException) else 502
raise StreamTerminated(f"{code} - {message}") from error

return _raise_http_exception
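For review context, a minimal usage sketch of the reworked callback. The `lib.callbacks` and `lib.exceptions` import paths and the bare `Request` construction below are assumptions, not part of this diff; the point is that non-HTTP errors now surface as a 502 and keep the exception class name in the message.

```python
import asyncio

from starlette.requests import Request

from lib.callbacks import raise_http_exception  # path assumed from this PR
from lib.exceptions import StreamTerminated     # assumed location of StreamTerminated


async def demo() -> None:
    request = Request(scope={"type": "http", "method": "GET", "path": "/", "headers": []})
    callback = raise_http_exception(request)

    try:
        # A plain exception (not an HTTPException) now maps to 502 instead of 400.
        await callback(ValueError("connection dropped"))
    except StreamTerminated as exc:
        print(exc)  # "502 - ValueError: connection dropped"


asyncio.run(demo())
```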
lib/metadata.py (24 changes: 18 additions & 6 deletions)
@@ -1,10 +1,11 @@
 from starlette.datastructures import Headers
-from pydantic import BaseModel, Field, field_validator, ByteSize, StrictStr, ConfigDict, AliasChoices
-from typing import Optional, Self, Annotated
+from pydantic import BaseModel, ByteSize, ConfigDict, Field, field_validator, StrictStr
+from typing import Annotated, Optional, Self
+import re
 
 
 class FileMetadata(BaseModel):
-    name: StrictStr = Field(description="File name", min_length=2, max_length=255)
+    name: StrictStr = Field(description="File name", min_length=1, max_length=255)
     size: ByteSize = Field(description="Size in bytes", gt=0)
     type: StrictStr = Field(description="MIME type", default='application/octet-stream')
 
@@ -13,8 +14,19 @@ class FileMetadata(BaseModel):
     @field_validator('name')
     @classmethod
     def validate_name(cls, v: str) -> str:
-        safe_filename = str(v).translate(str.maketrans(':;|*@/\\', ' ')).strip()
-        return safe_filename.encode('latin-1', 'ignore').decode('utf-8', 'ignore')
+        if not v or not v.strip():
+            raise ValueError("Filename cannot be empty")
+
+        safe_filename = re.sub(r'[<>:"/\\|?*\x00-\x1f]', ' ', str(v)).strip()
+        if not safe_filename:
+            raise ValueError("Filename contains only invalid characters")
+
+        try:
+            safe_filename = safe_filename.encode('utf-8').decode('utf-8')
+        except UnicodeError:
+            safe_filename = safe_filename.encode('utf-8', 'ignore').decode('utf-8', 'ignore')
+
+        return safe_filename
 
     @classmethod
     def from_json(cls, data: str) -> Self:
@@ -29,7 +41,7 @@ def get_from_http_headers(cls, headers: Headers, filename: str) -> Self:
         return cls(
             name=filename,
             size=headers.get('content-length', '0'),
-            type=headers.get('content-type', '') or None
+            type=headers.get('content-type', '') # Must be a string
         )
 
     @classmethod
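A quick sketch of how the tightened `validate_name` behaves, assuming `FileMetadata` is importable from `lib.metadata` (path inferred from the file shown above):

```python
from pydantic import ValidationError

from lib.metadata import FileMetadata  # import path assumed from this diff

meta = FileMetadata(name='report<v2>.pdf', size=1024)
print(meta.name)  # 'report v2 .pdf': '<' and '>' replaced with spaces
print(meta.type)  # 'application/octet-stream' (default MIME type)

try:
    FileMetadata(name='***', size=1024)  # every character is stripped by the regex
except ValidationError as exc:
    print(exc.errors()[0]['msg'])  # reports that the name contains only invalid characters
```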
lib/store.py (137 changes: 50 additions & 87 deletions)
@@ -19,10 +19,11 @@ def __init__(self, transfer_id: str):
         self.transfer_id = transfer_id
         self.redis = self.get_redis()
 
-        self._k_queue = self.key('queue')
-        self._k_meta = self.key('metadata')
-        self._k_cleanup = f'cleanup:{transfer_id}'
-        self._k_receiver_connected = self.key('receiver_connected')
+        self._k_stream = self.key('stream')
+        self._k_metadata = self.key('metadata')
+        self._k_position = self.key('position')
+        self._k_progress = self.key('progress')
+        self._k_receiver_active = self.key('receiver_active')
 
     @classmethod
     def get_redis(cls) -> redis.Redis:
@@ -36,26 +37,25 @@ def key(self, name: str) -> str:
"""Get the Redis key for this transfer with the provided name."""
return f'transfer:{self.transfer_id}:{name}'

## Queue operations ##
async def add_chunk(self, data: bytes) -> None:
"""Add chunk to stream."""
await self.redis.xadd(self._k_stream, {'data': data})

async def _wait_for_queue_space(self, maxsize: int) -> None:
while await self.redis.llen(self._k_queue) >= maxsize:
await anyio.sleep(0.5)
async def stream_chunks(self, read_timeout: float = 20.0):
"""Stream chunks from last position."""
position = await self.redis.get(self._k_position)
last_id = position.decode() if position else '0'

async def put_in_queue(self, data: bytes, maxsize: int = 16, timeout: float = 20.0) -> None:
"""Add data to the transfer queue with backpressure control."""
with anyio.fail_after(timeout):
await self._wait_for_queue_space(maxsize)
await self.redis.lpush(self._k_queue, data)

async def get_from_queue(self, timeout: float = 20.0) -> bytes:
"""Get data from the transfer queue with timeout."""
result = await self.redis.brpop([self._k_queue], timeout=timeout)
if not result:
raise TimeoutError("Timeout waiting for data")
while True:
result = await self.redis.xread({self._k_stream: last_id}, block=int(read_timeout*1000))
if not result:
raise TimeoutError("Stream read timeout")

_, data = result
return data
_, messages = result[0]
for message_id, fields in messages:
last_id = message_id
await self.redis.set(self._k_position, last_id, ex=300)
yield fields[b'data']

## Event operations ##

@@ -99,80 +99,43 @@ async def wait_for_event(self, event_name: str, timeout: float = 300.0) -> None:
             await pubsub.unsubscribe(event_key)
             await pubsub.aclose()
 
     ## Metadata operations ##
 
     async def set_metadata(self, metadata: str) -> None:
         """Store transfer metadata."""
-        challenge = random.randbytes(8)
-        await self.redis.set(self._k_meta, challenge, nx=True)
-        if await self.redis.get(self._k_meta) == challenge:
-            await self.redis.set(self._k_meta, metadata, ex=300)
-        else:
-            raise KeyError("Metadata already set for this transfer.")
+        if not await self.redis.set(self._k_metadata, metadata, nx=True, ex=300):
+            raise KeyError("Transfer already exists")
 
     async def get_metadata(self) -> str | None:
-        """Retrieve transfer metadata."""
-        return await self.redis.get(self._k_meta)
-
-    ## Transfer state operations ##
-
-    async def set_receiver_connected(self) -> bool:
-        """
-        Mark that a receiver has connected for this transfer.
-        Returns True if the flag was set, False if it was already created.
-        """
-        return bool(await self.redis.set(self._k_receiver_connected, '1', ex=300, nx=True))
-
-    async def is_receiver_connected(self) -> bool:
-        """Check if a receiver has already connected."""
-        return await self.redis.exists(self._k_receiver_connected) > 0
-
-    async def set_completed(self) -> None:
-        """Mark the transfer as completed."""
-        await self.redis.set(f'completed:{self.transfer_id}', '1', ex=300, nx=True)
-
-    async def is_completed(self) -> bool:
-        """Check if the transfer is marked as completed."""
-        return await self.redis.exists(f'completed:{self.transfer_id}') > 0
-
-    async def set_interrupted(self) -> None:
-        """Mark the transfer as interrupted."""
-        await self.redis.set(f'interrupt:{self.transfer_id}', '1', ex=300, nx=True)
-        await self.redis.ltrim(self._k_queue, 0, 0)
-
-    async def is_interrupted(self) -> bool:
-        """Check if the transfer was interrupted."""
-        return await self.redis.exists(f'interrupt:{self.transfer_id}') > 0
-
-    ## Cleanup operations ##
-
-    async def cleanup_started(self) -> bool:
-        """
-        Check if cleanup has already been initiated for this transfer.
-        This uses a set/get pattern with challenge to avoid race conditions.
-        """
-        challenge = random.randbytes(8)
-        await self.redis.set(self._k_cleanup, challenge, ex=60, nx=True)
-        if await self.redis.get(self._k_cleanup) == challenge:
-            return False
-        return True
-
-    async def cleanup(self) -> int:
-        """Remove all keys related to this transfer."""
-        if await self.cleanup_started():
-            return 0
+        """Get transfer metadata."""
+        return await self.redis.get(self._k_metadata)
 
-        pattern = self.key('*')
-        keys_to_delete = set()
+    async def save_progress(self, bytes_downloaded: int) -> None:
+        """Save download progress."""
+        await self.redis.set(self._k_progress, str(bytes_downloaded), ex=300)
+
+    async def get_progress(self) -> int:
+        """Get download progress."""
+        progress = await self.redis.get(self._k_progress)
+        return int(progress) if progress else 0
+
+    async def set_receiver_active(self) -> None:
+        """Mark receiver as actively downloading with TTL."""
+        await self.redis.set(self._k_receiver_active, '1', ex=5)
+
+    async def is_receiver_active(self) -> bool:
+        """Check if receiver is actively downloading."""
+        return bool(await self.redis.exists(self._k_receiver_active))
+
+    async def cleanup(self) -> None:
+        """Delete all transfer data."""
+        pattern = self.key('*')
+        cursor = 0
+        keys = []
 
         while True:
-            cursor, keys = await self.redis.scan(cursor, match=pattern)
-            keys_to_delete |= set(keys)
+            cursor, batch = await self.redis.scan(cursor, match=pattern)
+            keys.extend(batch)
             if cursor == 0:
                 break
 
-        if keys_to_delete:
-            self.debug(f"- Cleaning up {len(keys_to_delete)} keys")
-            return await self.redis.delete(*keys_to_delete)
-        return 0
+        if keys:
+            await self.redis.delete(*keys)
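To see how the pieces fit together, here is a rough consumer-side sketch of the new stream-based API. The `TransferStore` class name and `lib.store` import path are assumptions inferred from this diff, not confirmed by it.

```python
import anyio

from lib.store import TransferStore  # class name and module path assumed


async def download(transfer_id: str) -> None:
    store = TransferStore(transfer_id)
    bytes_downloaded = await store.get_progress()  # resume the byte counter if reconnecting

    try:
        # stream_chunks() resumes from the last saved stream position and blocks
        # on XREAD until new entries arrive or the read timeout expires.
        async for chunk in store.stream_chunks(read_timeout=20.0):
            bytes_downloaded += len(chunk)
            await store.save_progress(bytes_downloaded)
            await store.set_receiver_active()  # refresh the 5-second liveness flag
    except TimeoutError:
        pass  # sender stopped producing; fall through to cleanup
    finally:
        await store.cleanup()


anyio.run(download, "demo-transfer")
```

Compared with the old list-based queue, a receiver that drops mid-transfer can reconnect and continue from the persisted position key instead of losing in-flight chunks.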