diff --git a/lib/callbacks.py b/lib/callbacks.py index e07621b..ee601fe 100644 --- a/lib/callbacks.py +++ b/lib/callbacks.py @@ -24,8 +24,8 @@ def raise_http_exception(request: Request) -> Callable[[Exception | str], Awaita """Callback to raise an HTTPException with a specific status code.""" async def _raise_http_exception(error: Exception | str) -> None: - message = str(error) if isinstance(error, Exception) else error - code = error.status_code if isinstance(error, HTTPException) else 400 - raise StreamTerminated(f"{code}: {message}") from error + message = f"{type(error).__name__}: {error}" if isinstance(error, Exception) else str(error) + code = error.status_code if isinstance(error, HTTPException) else 502 + raise StreamTerminated(f"{code} - {message}") from error return _raise_http_exception diff --git a/lib/metadata.py b/lib/metadata.py index e829597..ba3dca5 100644 --- a/lib/metadata.py +++ b/lib/metadata.py @@ -1,10 +1,11 @@ from starlette.datastructures import Headers -from pydantic import BaseModel, Field, field_validator, ByteSize, StrictStr, ConfigDict, AliasChoices -from typing import Optional, Self, Annotated +from pydantic import BaseModel, ByteSize, ConfigDict, Field, field_validator, StrictStr +from typing import Annotated, Optional, Self +import re class FileMetadata(BaseModel): - name: StrictStr = Field(description="File name", min_length=2, max_length=255) + name: StrictStr = Field(description="File name", min_length=1, max_length=255) size: ByteSize = Field(description="Size in bytes", gt=0) type: StrictStr = Field(description="MIME type", default='application/octet-stream') @@ -13,8 +14,19 @@ class FileMetadata(BaseModel): @field_validator('name') @classmethod def validate_name(cls, v: str) -> str: - safe_filename = str(v).translate(str.maketrans(':;|*@/\\', ' ')).strip() - return safe_filename.encode('latin-1', 'ignore').decode('utf-8', 'ignore') + if not v or not v.strip(): + raise ValueError("Filename cannot be empty") + + safe_filename = re.sub(r'[<>:"/\\|?*\x00-\x1f]', ' ', str(v)).strip() + if not safe_filename: + raise ValueError("Filename contains only invalid characters") + + try: + safe_filename = safe_filename.encode('utf-8').decode('utf-8') + except UnicodeError: + safe_filename = safe_filename.encode('utf-8', 'ignore').decode('utf-8', 'ignore') + + return safe_filename @classmethod def from_json(cls, data: str) -> Self: @@ -29,7 +41,7 @@ def get_from_http_headers(cls, headers: Headers, filename: str) -> Self: return cls( name=filename, size=headers.get('content-length', '0'), - type=headers.get('content-type', '') or None + type=headers.get('content-type', '') # Must be a string ) @classmethod diff --git a/lib/store.py b/lib/store.py index a62d7c4..53ffbe9 100644 --- a/lib/store.py +++ b/lib/store.py @@ -19,10 +19,11 @@ def __init__(self, transfer_id: str): self.transfer_id = transfer_id self.redis = self.get_redis() - self._k_queue = self.key('queue') - self._k_meta = self.key('metadata') - self._k_cleanup = f'cleanup:{transfer_id}' - self._k_receiver_connected = self.key('receiver_connected') + self._k_stream = self.key('stream') + self._k_metadata = self.key('metadata') + self._k_position = self.key('position') + self._k_progress = self.key('progress') + self._k_receiver_active = self.key('receiver_active') @classmethod def get_redis(cls) -> redis.Redis: @@ -36,26 +37,25 @@ def key(self, name: str) -> str: """Get the Redis key for this transfer with the provided name.""" return f'transfer:{self.transfer_id}:{name}' - ## Queue operations ## + async 
def add_chunk(self, data: bytes) -> None: + """Add chunk to stream.""" + await self.redis.xadd(self._k_stream, {'data': data}) - async def _wait_for_queue_space(self, maxsize: int) -> None: - while await self.redis.llen(self._k_queue) >= maxsize: - await anyio.sleep(0.5) + async def stream_chunks(self, read_timeout: float = 20.0): + """Stream chunks from last position.""" + position = await self.redis.get(self._k_position) + last_id = position.decode() if position else '0' - async def put_in_queue(self, data: bytes, maxsize: int = 16, timeout: float = 20.0) -> None: - """Add data to the transfer queue with backpressure control.""" - with anyio.fail_after(timeout): - await self._wait_for_queue_space(maxsize) - await self.redis.lpush(self._k_queue, data) - - async def get_from_queue(self, timeout: float = 20.0) -> bytes: - """Get data from the transfer queue with timeout.""" - result = await self.redis.brpop([self._k_queue], timeout=timeout) - if not result: - raise TimeoutError("Timeout waiting for data") + while True: + result = await self.redis.xread({self._k_stream: last_id}, block=int(read_timeout*1000)) + if not result: + raise TimeoutError("Stream read timeout") - _, data = result - return data + _, messages = result[0] + for message_id, fields in messages: + last_id = message_id + await self.redis.set(self._k_position, last_id, ex=300) + yield fields[b'data'] ## Event operations ## @@ -99,80 +99,43 @@ async def wait_for_event(self, event_name: str, timeout: float = 300.0) -> None: await pubsub.unsubscribe(event_key) await pubsub.aclose() - ## Metadata operations ## - async def set_metadata(self, metadata: str) -> None: """Store transfer metadata.""" - challenge = random.randbytes(8) - await self.redis.set(self._k_meta, challenge, nx=True) - if await self.redis.get(self._k_meta) == challenge: - await self.redis.set(self._k_meta, metadata, ex=300) - else: - raise KeyError("Metadata already set for this transfer.") + if not await self.redis.set(self._k_metadata, metadata, nx=True, ex=300): + raise KeyError("Transfer already exists") async def get_metadata(self) -> str | None: - """Retrieve transfer metadata.""" - return await self.redis.get(self._k_meta) - - ## Transfer state operations ## - - async def set_receiver_connected(self) -> bool: - """ - Mark that a receiver has connected for this transfer. - Returns True if the flag was set, False if it was already created. - """ - return bool(await self.redis.set(self._k_receiver_connected, '1', ex=300, nx=True)) - - async def is_receiver_connected(self) -> bool: - """Check if a receiver has already connected.""" - return await self.redis.exists(self._k_receiver_connected) > 0 - - async def set_completed(self) -> None: - """Mark the transfer as completed.""" - await self.redis.set(f'completed:{self.transfer_id}', '1', ex=300, nx=True) - - async def is_completed(self) -> bool: - """Check if the transfer is marked as completed.""" - return await self.redis.exists(f'completed:{self.transfer_id}') > 0 - - async def set_interrupted(self) -> None: - """Mark the transfer as interrupted.""" - await self.redis.set(f'interrupt:{self.transfer_id}', '1', ex=300, nx=True) - await self.redis.ltrim(self._k_queue, 0, 0) - - async def is_interrupted(self) -> bool: - """Check if the transfer was interrupted.""" - return await self.redis.exists(f'interrupt:{self.transfer_id}') > 0 - - ## Cleanup operations ## - - async def cleanup_started(self) -> bool: - """ - Check if cleanup has already been initiated for this transfer. 
- This uses a set/get pattern with challenge to avoid race conditions. - """ - challenge = random.randbytes(8) - await self.redis.set(self._k_cleanup, challenge, ex=60, nx=True) - if await self.redis.get(self._k_cleanup) == challenge: - return False - return True - - async def cleanup(self) -> int: - """Remove all keys related to this transfer.""" - if await self.cleanup_started(): - return 0 + """Get transfer metadata.""" + return await self.redis.get(self._k_metadata) - pattern = self.key('*') - keys_to_delete = set() + async def save_progress(self, bytes_downloaded: int) -> None: + """Save download progress.""" + await self.redis.set(self._k_progress, str(bytes_downloaded), ex=300) + + async def get_progress(self) -> int: + """Get download progress.""" + progress = await self.redis.get(self._k_progress) + return int(progress) if progress else 0 + async def set_receiver_active(self) -> None: + """Mark receiver as actively downloading with TTL.""" + await self.redis.set(self._k_receiver_active, '1', ex=5) + + async def is_receiver_active(self) -> bool: + """Check if receiver is actively downloading.""" + return bool(await self.redis.exists(self._k_receiver_active)) + + async def cleanup(self) -> None: + """Delete all transfer data.""" + pattern = self.key('*') cursor = 0 + keys = [] + while True: - cursor, keys = await self.redis.scan(cursor, match=pattern) - keys_to_delete |= set(keys) + cursor, batch = await self.redis.scan(cursor, match=pattern) + keys.extend(batch) if cursor == 0: break - if keys_to_delete: - self.debug(f"- Cleaning up {len(keys_to_delete)} keys") - return await self.redis.delete(*keys_to_delete) - return 0 + if keys: + await self.redis.delete(*keys) diff --git a/lib/transfer.py b/lib/transfer.py index d8af9d8..e6db758 100644 --- a/lib/transfer.py +++ b/lib/transfer.py @@ -37,12 +37,14 @@ def __init__(self, uid: str, file: FileMetadata): @classmethod async def create(cls, uid: str, file: FileMetadata): + """Create a new transfer using the provided identifier and file metadata.""" transfer = cls(uid, file) await transfer.store.set_metadata(file.to_json()) return transfer @classmethod async def get(cls, uid: str): + """Fetch a transfer from the store using the provided identifier.""" store = Store(uid) metadata_json = await store.get_metadata() if not metadata_json: @@ -58,122 +60,89 @@ def _format_uid(uid: str): def get_file_info(self): return self.file.name, self.file.size, self.file.type - async def wait_for_event(self, event_name: str, timeout: float = 300.0): - await self.store.wait_for_event(event_name, timeout) - - async def set_client_connected(self): - self.debug(f"▼ Notifying sender that receiver is connected...") - await self.store.set_event('client_connected') - - async def wait_for_client_connected(self): - self.info(f"△ Waiting for client to connect...") - await self.wait_for_event('client_connected') - self.debug(f"△ Received client connected notification.") - - async def is_receiver_connected(self) -> bool: - return await self.store.is_receiver_connected() - - async def set_receiver_connected(self) -> bool: - return await self.store.set_receiver_connected() - - async def is_interrupted(self) -> bool: - return await self.store.is_interrupted() - - async def set_interrupted(self): - await self.store.set_interrupted() - - async def is_completed(self) -> bool: - return await self.store.is_completed() - - async def set_completed(self): - await self.store.set_completed() - - async def collect_upload(self, stream: AsyncIterator[bytes], on_error: 
Callable[[Exception | str], Awaitable[None]]) -> None: + @property + async def receiver_connected(self) -> bool: + """Check if a receiver is actively downloading.""" + return await self.store.is_receiver_active() + + async def notify_receiver_connected(self): + """Notify sender that receiver connected.""" + await self.store.set_event('receiver_connected') + + async def wait_for_receiver(self): + """Wait for receiver to connect.""" + self.info(f"△ Waiting for receiver...") + await self.store.wait_for_event('receiver_connected') + self.debug(f"△ Receiver connected") + + async def consume_upload(self, stream: AsyncIterator[bytes], on_error: Callable[[Exception | str], Awaitable[None]]) -> None: + """Consume upload stream and add chunks to Redis stream.""" self.bytes_uploaded = 0 try: async for chunk in stream: if not chunk: - self.debug(f"△ Empty chunk received, ending upload.") break - if await self.is_interrupted(): - raise TransferError("Transfer was interrupted by the receiver.", propagate=False) - - await self.store.put_in_queue(chunk) + await self.store.add_chunk(chunk) self.bytes_uploaded += len(chunk) if self.bytes_uploaded < self.file.size: - raise TransferError("Received less data than expected.", propagate=True) + raise TransferError("Incomplete upload", propagate=True) - self.debug(f"△ End of upload, sending done marker.") - await self.store.put_in_queue(self.DONE_FLAG) + await self.store.add_chunk(self.DONE_FLAG) + self.debug(f"△ All data chunks uploaded: {self.bytes_uploaded} bytes") - except (ClientDisconnect, WebSocketDisconnect) as e: - self.error(f"△ Unexpected upload error: {e}") - await self.store.put_in_queue(self.DEAD_FLAG) + except (ClientDisconnect, WebSocketDisconnect): + self.error(f"△ Sender disconnected") + await self.store.add_chunk(self.DEAD_FLAG) - except TimeoutError as e: - self.warning(f"△ Timeout during upload.", exc_info=True) - await on_error("Timeout during upload.") + except TimeoutError: + self.warning(f"△ Upload timeout") + await on_error("Upload timeout") except TransferError as e: - self.warning(f"△ Upload error: {e}") if e.propagate: - await self.store.put_in_queue(self.DEAD_FLAG) - else: - await on_error(e) + await self.store.add_chunk(self.DEAD_FLAG) + await on_error(e) - finally: - await anyio.sleep(1.0) + async def produce_download(self, on_error: Callable[[Exception | str], Awaitable[None]]) -> AsyncIterator[bytes]: + """Produce download stream from Redis stream.""" + self.bytes_downloaded = await self.store.get_progress() - async def supply_download(self, on_error: Callable[[Exception | str], Awaitable[None]]) -> AsyncIterator[bytes]: - self.bytes_downloaded = 0 + if self.bytes_downloaded > 0: + self.info(f"▼ Resuming from byte {self.bytes_downloaded}") try: - while True: - chunk = await self.store.get_from_queue() + await self.store.set_receiver_active() + async for chunk in self.store.stream_chunks(): if chunk == self.DEAD_FLAG: - raise TransferError("Sender disconnected.") + raise TransferError("Sender disconnected") - if chunk == self.DONE_FLAG and self.bytes_downloaded < self.file.size: - raise TransferError("Received less data than expected.") - - elif chunk == self.DONE_FLAG: - self.debug(f"▼ Done marker received, ending download.") + if chunk == self.DONE_FLAG: + if self.bytes_downloaded >= self.file.size: + self.debug(f"▼ All data chunks downloaded: {self.bytes_downloaded} bytes") break self.bytes_downloaded += len(chunk) + await self.store.save_progress(self.bytes_downloaded) + await self.store.set_receiver_active() yield chunk - 
except Exception as e: - self.error(f"▼ Unexpected download error!", exc_info=True) - self.debug("Debug info:", stack_info=True) - await on_error(e) - except TransferError as e: - self.warning(f"▼ Download error") + await on_error(e) + except Exception as e: + self.error(f"▼ Download error", exc_info=True) await on_error(e) async def cleanup(self): - try: - with anyio.fail_after(30.0): - await self.store.cleanup() - except TimeoutError: - self.warning(f"- Cleanup timed out.") - pass + """Clean up transfer data.""" + await self.store.cleanup() async def finalize_download(self): - # self.debug("▼ Finalizing download...") - if self.bytes_downloaded < self.file.size and not await self.is_interrupted(): - self.warning("▼ Client disconnected before download was complete.") - await self.set_interrupted() - - await self.cleanup() - # self.debug("▼ Finalizing download...") - if self.bytes_downloaded < self.file.size and not await self.is_interrupted(): - self.warning("▼ Client disconnected before download was complete.") - await self.set_interrupted() - - await self.cleanup() + """Finalize download and cleanup if complete.""" + if self.bytes_downloaded < self.file.size: + self.info(f"▼ Download paused at {self.bytes_downloaded}/{self.file.size} bytes") + else: + await self.cleanup() diff --git a/static/css/style.css b/static/css/style.css index 08ab353..1664621 100644 --- a/static/css/style.css +++ b/static/css/style.css @@ -430,13 +430,54 @@ code.inline-highlight { } /* Responsive Design */ + +/* Tablet and small desktop */ +@media (max-width: 768px) { + .container { + max-width: 100%; + padding: var(--space-md); + } + + .header { + padding: var(--space-xl) 0; + } + + .code-section { + padding: var(--space-md); + } + + .info-list { + padding: var(--space-md); + } +} + +/* Mobile devices */ @media (max-width: 600px) { + :root { + --space-xs: 0.25rem; + --space-sm: 0.375rem; + --space-md: 0.75rem; + --space-lg: 1rem; + --space-xl: 1.5rem; + --space-2xl: 2rem; + } + .container { padding: var(--space-sm); } + .header { + padding: var(--space-lg) 0; + margin-bottom: var(--space-lg); + } + .header h1 { - font-size: 2rem; + font-size: 1.75rem; + margin-bottom: var(--space-xs); + } + + .header p { + font-size: 0.95rem; } .beta-badge { @@ -444,13 +485,223 @@ code.inline-highlight { display: inline-block; margin-left: var(--space-sm); margin-top: var(--space-xs); + font-size: 0.65rem; + padding: 1px 4px; + } + + .beta-warning { + padding: var(--space-sm); + margin-top: var(--space-md); + font-size: 0.85rem; + } + + .section { + margin-bottom: var(--space-xl); + } + + .section h2 { + font-size: 1.1rem; + margin-bottom: var(--space-sm); + } + + .section p { + font-size: 0.95rem; + margin-bottom: var(--space-sm); } + /* Mobile-optimized transfer area */ .transfer-container { - min-height: 160px; + min-height: 180px; + margin-bottom: var(--space-lg); + } + + .drop-area { + border-radius: var(--radius-md); + padding: var(--space-lg); + min-height: 180px; + touch-action: none; + } + + .drop-area p { + font-size: 1rem; + padding: var(--space-md); + line-height: 1.4; + } + + /* Make file input area larger for mobile touch */ + .drop-area::after { + content: ''; + position: absolute; + top: -10px; + left: -10px; + right: -10px; + bottom: -10px; + z-index: -1; + } + + .share-link { + padding: var(--space-md); + } + + .share-link label { + font-size: 0.95rem; + } + + .share-link input { + padding: var(--space-md); + font-size: 0.85rem; + max-width: 100%; + } + + /* Progress bar mobile optimization */ + 
.upload-progress { + margin-top: var(--space-md); + } + + .progress-info { + margin-bottom: var(--space-xs); + } + + .status-text, + .progress-text { + font-size: 0.85rem; + } + + .progress-bar { + height: 10px; + } + + /* Download page mobile */ + .download-container { + padding: var(--space-md); + } + + .file-info p { + font-size: 0.95rem; + margin-bottom: var(--space-xs); + } + + .button-download { + padding: var(--space-md) var(--space-lg); + font-size: 1rem; + width: 100%; + display: block; + text-align: center; + touch-action: manipulation; + } + + /* Code sections mobile - hide cURL section */ + .code-section { + display: none; + } + + .code-section h3 { + font-size: 1rem; + margin-bottom: var(--space-sm); + } + + .code-section p { + font-size: 0.9rem; } .code-block { - font-size: 0.8rem; + padding: var(--space-sm); + font-size: 0.75rem; + overflow-x: auto; + -webkit-overflow-scrolling: touch; + } + + .code-block code { + font-size: 0.75rem; + } + + code.inline-highlight { + font-size: 0.85rem; + padding: 0.05rem 0.3rem; + } + + /* Info list mobile */ + .info-list { + padding: var(--space-md); + } + + .info-list h3 { + font-size: 1rem; + margin-bottom: var(--space-sm); + } + + .info-list li { + padding-left: var(--space-md); + margin-bottom: var(--space-sm); + font-size: 0.9rem; + } + + /* Footer mobile */ + .footer { + padding: var(--space-lg) 0; + font-size: 0.85rem; + } +} + +/* Extra small mobile devices */ +@media (max-width: 375px) { + .header h1 { + font-size: 1.5rem; + } + + .beta-badge { + display: block; + margin-left: 0; + margin-top: var(--space-sm); + width: fit-content; + margin-inline: auto; + } + + .code-block { + font-size: 0.7rem; + } +} + +/* Mobile landscape orientation */ +@media (max-width: 900px) and (orientation: landscape) { + .header { + padding: var(--space-md) 0; + } + + .header h1 { + font-size: 1.5rem; + } + + .transfer-container { + min-height: 140px; + } + + .drop-area { + min-height: 140px; + } +} + +/* Touch device optimizations */ +@media (pointer: coarse) { + .drop-area { + cursor: default; + } + + .button-download, + .drop-area { + -webkit-tap-highlight-color: transparent; + user-select: none; + } + + a { + -webkit-tap-highlight-color: rgba(31, 111, 235, 0.2); + } +} + +/* High DPI screens */ +@media (-webkit-min-device-pixel-ratio: 2), (min-resolution: 192dpi) { + .progress-bar { + transform: translateZ(0); + will-change: width; } } diff --git a/static/index.html b/static/index.html index 929c6fd..68f4837 100644 --- a/static/index.html +++ b/static/index.html @@ -19,7 +19,7 @@

Transit.sh Beta

Direct file transfer without intermediate storage

- Notice: This service is in beta. If you encounter any bug, please report it here. Sending files via mobile browsers is not supported yet.
+ Notice: This service is in beta. If you encounter any bug, please report it here.
@@ -29,14 +29,14 @@

Send a file

Drag and drop or select a file to generate a download link.

- Drag and drop your file here, or click to select a file
+ Drag and drop your file here, or tap to select a file
@@ -54,7 +54,7 @@

Send a file

Using cURL

- You can use the curl command to transfer from your terminal. 100 MiB maximum.
+ You can use the curl command to transfer from your terminal. 1 GiB maximum.

# Send

@@ -95,6 +95,6 @@

Important Information

- + diff --git a/static/js/file-transfer.js b/static/js/file-transfer.js index c0e270f..ae2bc71 100644 --- a/static/js/file-transfer.js +++ b/static/js/file-transfer.js @@ -1,36 +1,61 @@ -document.addEventListener('DOMContentLoaded', initFileTransfer); +const CHUNK_SIZE_MOBILE = 32 * 1024; // 32KiB for mobile devices +const CHUNK_SIZE_DESKTOP = 64 * 1024; // 64KiB for desktop devices +const BUFFER_THRESHOLD_MOBILE = CHUNK_SIZE_MOBILE * 16; // 512KiB buffer threshold for mobile +const BUFFER_THRESHOLD_DESKTOP = CHUNK_SIZE_DESKTOP * 16; // 1MiB buffer threshold for desktop +const BUFFER_CHECK_INTERVAL = 200; // 200ms interval for buffer checks +const SHARE_LINK_FOCUS_DELAY = 300; // 300ms delay before focusing share link +const TRANSFER_FINALIZE_DELAY = 1000; // 1000ms delay before finalizing transfer +const MOBILE_BREAKPOINT = 768; // 768px mobile breakpoint +const TRANSFER_ID_MAX_NUMBER = 1000; // Maximum number for transfer ID generation (0-999) + +const DEBUG_LOGS = true; +const log = { + debug: (...args) => DEBUG_LOGS && console.debug(...args), + info: (...args) => console.info(...args), + warn: (...args) => console.warn(...args), + error: (...args) => console.error(...args) +}; + +initFileTransfer(); function initFileTransfer() { + log.debug('Initializing file transfer interface'); const elements = { dropArea: document.getElementById('drop-area'), + dropAreaText: document.getElementById('drop-area-text'), fileInput: document.getElementById('file-input'), uploadProgress: document.getElementById('upload-progress'), progressBarFill: document.getElementById('progress-bar-fill'), progressText: document.getElementById('progress-text'), statusText: document.getElementById('status-text'), shareLink: document.getElementById('share-link'), - shareUrl: document.getElementById('share-url'), + shareUrl: document.getElementById('share-url') }; + if (isMobileDevice() && elements.dropAreaText) { + elements.dropAreaText.textContent = 'Tap here to select a file'; + log.debug('Updated UI text for mobile device'); + } + setupEventListeners(elements); } function setupEventListeners(elements) { const { dropArea, fileInput } = elements; - // Prevent default drag behaviors ['dragenter', 'dragover', 'dragleave', 'drop'].forEach(eventName => { dropArea.addEventListener(eventName, preventDefaults, false); document.body.addEventListener(eventName, preventDefaults, false); }); + ['dragenter', 'dragover'].forEach(eventName => { dropArea.addEventListener(eventName, () => highlight(dropArea), false); }); + ['dragleave', 'drop'].forEach(eventName => { dropArea.addEventListener(eventName, () => unhighlight(dropArea), false); }); - // Handle dropped files dropArea.addEventListener('drop', e => handleDrop(e, elements), false); dropArea.addEventListener('click', () => fileInput.click()); fileInput.addEventListener('change', () => { @@ -40,179 +65,184 @@ function setupEventListeners(elements) { }); } -// Event helpers function preventDefaults(e) { e.preventDefault(); e.stopPropagation(); } -function highlight(dropArea) { - dropArea.classList.add('highlight'); +function highlight(element) { + element.classList.add('highlight'); } -function unhighlight(dropArea) { - dropArea.classList.remove('highlight'); +function unhighlight(element) { + element.classList.remove('highlight'); } function handleDrop(e, elements) { - const dt = e.dataTransfer; - const files = dt.files; + const files = e.dataTransfer.files; handleFiles(files, elements); } function handleFiles(files, elements) { if (files.length > 0) { - uploadFile(files[0], 
elements); + const file = files[0]; + log.info('File selected:', { + name: file.name, + size: file.size, + type: file.type, + lastModified: new Date(file.lastModified).toISOString() + }); + uploadFile(file, elements); } } -// Transfer ID generation -function generateTransferId() { - // Generate a UUID to get a high-entropy random value. - const uuid = self.crypto.randomUUID(); - const hex = uuid.replace(/-/g, ''); // We only need the hex digits - - const consonants = 'bcdfghjklmnpqrstvwxyz'; - const vowels = 'aeiou'; - - // Function to create a pronounceable "word" from a hex string segment. - const createWord = (hexSegment) => { - let word = ''; - for (let i = 0; i < hexSegment.length; i++) { - const charCode = parseInt(hexSegment[i], 16); - if (i % 2 === 0) { // Consonant - word += consonants[charCode % consonants.length]; - } else { // Vowel - word += vowels[charCode % vowels.length]; - } - } - return word; - }; - - // Create two 6-letter words from the first 12 characters of the UUID hex. - const word1 = createWord(hex.substring(0, 6)); - const word2 = createWord(hex.substring(6, 12)); - - // Use the next 4 hex characters for a number between 0 and 999. - // This gives a larger range than the original 0-99. - const num = parseInt(hex.substring(12, 15), 16) % 1000; - - return `${word1}-${word2}-${num}`; -} - -// UI updates function showProgress(elements, message = 'Connecting...') { const { uploadProgress, statusText } = elements; uploadProgress.style.display = 'block'; statusText.textContent = message; - uploadProgress.setAttribute('aria-valuenow', '0'); // Add ARIA update + uploadProgress.setAttribute('aria-valuenow', '0'); } function updateProgress(elements, progress) { - const { progressBarFill, progressText, uploadProgress } = elements; // Add uploadProgress + const { progressBarFill, progressText, uploadProgress, statusText } = elements; const percentage = Math.min(100, Math.round(progress * 100)); progressBarFill.style.width = `${percentage}%`; progressText.textContent = `${percentage}%`; - uploadProgress.setAttribute('aria-valuenow', percentage); // Add ARIA update + uploadProgress.setAttribute('aria-valuenow', percentage); if (percentage === 100) { - elements.statusText.textContent = 'Completing transfer...'; + statusText.textContent = 'Completing transfer...'; } } function displayShareLink(elements, transferId) { const { shareUrl, shareLink, dropArea } = elements; - shareUrl.value = `https://transit.sh/${transferId}`; + shareUrl.value = `${window.location.origin}/${transferId}`; shareLink.style.display = 'flex'; dropArea.style.display = 'none'; - // Focus and select the share URL for easy copying setTimeout(() => { shareUrl.focus(); shareUrl.select(); - }, 300); + }, SHARE_LINK_FOCUS_DELAY); } function uploadFile(file, elements) { - const { statusText } = elements; const transferId = generateTransferId(); - const ws = new WebSocket(`wss://transit.sh/send/${transferId}`); - let abortController = new AbortController(); + + const wsProtocol = window.location.protocol === 'https:' ? 
'wss:' : 'ws:'; + const wsUrl = `${wsProtocol}//${window.location.host}/send/${transferId}`; + + log.info('Starting upload:', { transferId, fileName: file.name, fileSize: file.size, wsUrl }); + + const ws = new WebSocket(wsUrl); + const abortController = new AbortController(); + const uploadState = { + file: file, + transferId: transferId, + isUploading: false, + wakeLock: null + }; showProgress(elements); - // WebSocket event handlers - ws.onopen = () => handleWsOpen(ws, file, transferId, elements, abortController); - ws.onmessage = (event) => handleWsMessage(event, ws, file, elements, abortController); - ws.onerror = (error) => { - handleWsError(error, statusText); - cleanupTransfer(abortController); + ws.onopen = () => handleWsOpen(ws, file, transferId, elements); + ws.onmessage = (event) => handleWsMessage(event, ws, file, elements, abortController, uploadState); + ws.onerror = (error) => handleWsError(error, elements.statusText); + ws.onclose = (event) => { + log.info('WebSocket connection closed:', { code: event.code, reason: event.reason, wasClean: event.wasClean }); + if (uploadState.isUploading && !event.wasClean) { + elements.statusText.textContent = 'Connection lost. Please try uploading again.'; + elements.statusText.style.color = 'var(--error)'; + } + cleanupTransfer(abortController, uploadState); }; - ws.onclose = () => { - console.log('WebSocket connection closed'); - cleanupTransfer(abortController); + + const handleVisibilityChange = () => { + if (document.hidden && uploadState.isUploading) { + log.warn('App went to background during active upload'); + if (isMobileDevice()) { + elements.statusText.textContent = '⚠️ Keep app in foreground during upload'; + elements.statusText.style.color = 'var(--warning)'; + } + } else if (!document.hidden && uploadState.isUploading) { + log.info('App returned to foreground'); + if (ws.readyState !== WebSocket.OPEN) { + elements.statusText.textContent = 'Connection lost. Please try uploading again.'; + elements.statusText.style.color = 'var(--error)'; + uploadState.isUploading = false; + } + } }; + document.addEventListener('visibilitychange', handleVisibilityChange); - // Ensure cleanup on page unload - window.addEventListener('beforeunload', () => cleanupTransfer(abortController), { once: true }); -} + const handleBeforeUnload = (e) => { + if (uploadState.isUploading) { + e.preventDefault(); + e.returnValue = 'File upload in progress. Are you sure you want to leave?'; + return e.returnValue; + } + }; + window.addEventListener('beforeunload', handleBeforeUnload); -function handleWsOpen(ws, file, transferId, elements, abortController) { - const { statusText } = elements; + window.addEventListener('unload', () => { + document.removeEventListener('visibilitychange', handleVisibilityChange); + window.removeEventListener('beforeunload', handleBeforeUnload); + cleanupTransfer(abortController, uploadState); + }, { once: true }); + if (isMobileDevice() && 'wakeLock' in navigator) { + requestWakeLock(uploadState); + } +} + +function handleWsOpen(ws, file, transferId, elements) { + log.info('WebSocket connection opened'); const metadata = { file_name: file.name, file_size: file.size, file_type: file.type || 'application/octet-stream' }; - + log.info('Sending file metadata:', metadata); ws.send(JSON.stringify(metadata)); - statusText.textContent = 'Waiting for the receiver to start the download... (max. 5 minutes)'; + elements.statusText.textContent = 'Waiting for the receiver to start the download... (max. 
5 minutes)'; displayShareLink(elements, transferId); } -function handleWsMessage(event, ws, file, elements, abortController) { - const { statusText } = elements; +function handleWsMessage(event, ws, file, elements, abortController, uploadState) { + log.debug('WebSocket message received:', event.data); if (event.data === 'Go for file chunks') { - statusText.textContent = 'Peer connected. Transferring file...'; - sendFileInChunks(ws, file, elements, abortController); + log.info('Receiver connected, starting file transfer'); + elements.statusText.textContent = 'Peer connected. Transferring file...'; + uploadState.isUploading = true; + sendFileInChunks(ws, file, elements, abortController, uploadState); } else if (event.data.startsWith('Error')) { - statusText.textContent = event.data; - statusText.style.color = 'var(--error)'; - console.error('Server error:', event.data); - cleanupTransfer(abortController); + log.error('Server error:', event.data); + elements.statusText.textContent = event.data; + elements.statusText.style.color = 'var(--error)'; + cleanupTransfer(abortController, uploadState); } else { - console.log('Unexpected message:', event.data); + log.warn('Unexpected message:', event.data); } } function handleWsError(error, statusText) { + log.error('WebSocket error:', error); statusText.textContent = 'Error: ' + (error.message || 'Connection failed'); statusText.style.color = 'var(--error)'; - console.error('WebSocket Error:', error); } -function cleanupTransfer(abortController) { - if (abortController) { - abortController.abort(); - abortController = null; - } -} +async function sendFileInChunks(ws, file, elements, abortController, uploadState) { + const chunkSize = isMobileDevice() ? CHUNK_SIZE_MOBILE : CHUNK_SIZE_DESKTOP; + log.info('Starting chunked upload:', { chunkSize, fileSize: file.size, totalChunks: Math.ceil(file.size / chunkSize) }); -async function sendFileInChunks(ws, file, elements, abortController) { - const { statusText } = elements; - const chunkSize = 64 * 1024; // 64KiB - let offset = 0; const reader = new FileReader(); - + let offset = 0; const signal = abortController.signal; - if (signal.aborted) return; try { while (offset < file.size && !signal.aborted) { - // Wait until WebSocket buffer has room - await waitForWebSocketBuffer(ws); - + await waitForWebSocketBuffer(ws, signal); if (signal.aborted) break; const end = Math.min(offset + chunkSize, file.size); @@ -224,53 +254,35 @@ async function sendFileInChunks(ws, file, elements, abortController) { ws.send(chunk); offset += chunk.byteLength; - // Update progress - updateProgress(elements, offset / file.size); + const progress = offset / file.size; + log.debug('Chunk sent:', { offset, progress: `${Math.round(progress * 100)}%`, bufferedAmount: ws.bufferedAmount }); + updateProgress(elements, progress); } - // If we completed successfully (not aborted), finalize the transfer if (!signal.aborted && offset >= file.size) { - finalizeTransfer(ws, statusText); + log.info('Upload completed successfully'); + uploadState.isUploading = false; + finalizeTransfer(ws, elements.statusText, uploadState); } } catch (error) { if (!signal.aborted) { - statusText.textContent = `Error: ${error.message || 'Upload failed'}`; - console.error('Upload error:', error); + log.error('Upload failed:', error); + elements.statusText.textContent = `Error: ${error.message || 'Upload failed'}`; ws.close(); } } finally { - // Cleanup reader.onload = null; reader.onerror = null; } } -// Promise-based wait for WebSocket buffer to clear -function 
waitForWebSocketBuffer(ws) { - return new Promise(resolve => { - const checkBuffer = () => { - if (ws.bufferedAmount < 1024 * 1024) { // 1MiB threshold (16 chunks of 64KiB) - resolve(); - } else { - setTimeout(checkBuffer, 200); - } - }; - checkBuffer(); - }); -} - -// Promise-based file chunk reading function readChunkAsArrayBuffer(reader, blob, signal) { return new Promise((resolve, reject) => { - if (signal.aborted) { - resolve(null); - return; - } + if (signal.aborted) return resolve(null); reader.onload = e => resolve(e.target.result); - reader.onerror = e => reject(new Error('Error reading file')); + reader.onerror = () => reject(new Error('Error reading file')); - // Add abort handling signal.addEventListener('abort', () => { reader.abort(); resolve(null); @@ -280,12 +292,80 @@ function readChunkAsArrayBuffer(reader, blob, signal) { }); } -function finalizeTransfer(ws, statusText) { - // Send empty chunk to signal end of transfer +function waitForWebSocketBuffer(ws, signal) { + return new Promise(resolve => { + const threshold = isMobileDevice() ? BUFFER_THRESHOLD_MOBILE : BUFFER_THRESHOLD_DESKTOP; + const checkBuffer = () => { + if (signal.aborted || ws.bufferedAmount < threshold) { + resolve(); + } else { + setTimeout(checkBuffer, BUFFER_CHECK_INTERVAL); + } + }; + checkBuffer(); + }); +} + +function finalizeTransfer(ws, statusText, uploadState) { + log.info('Sending end-of-transfer signal'); ws.send(new ArrayBuffer(0)); setTimeout(() => { + log.info('Transfer finalized successfully'); statusText.textContent = '✓ Transfer complete!'; + if (uploadState.wakeLock) { + uploadState.wakeLock.release().catch(() => {}); + uploadState.wakeLock = null; + } ws.close(); - }, 500); + }, TRANSFER_FINALIZE_DELAY); +} + +function cleanupTransfer(abortController, uploadState) { + if (abortController) { + abortController.abort(); + } + if (uploadState && uploadState.wakeLock) { + uploadState.wakeLock.release().catch(() => {}); + uploadState.wakeLock = null; + } +} + +function isMobileDevice() { + return /Android|webOS|iPhone|iPad|iPod|BlackBerry|IEMobile|Opera Mini/i.test(navigator.userAgent) || + (window.matchMedia && window.matchMedia(`(max-width: ${MOBILE_BREAKPOINT}px)`).matches); +} + +function generateTransferId() { + const uuid = self.crypto.randomUUID(); + const hex = uuid.replace(/-/g, ''); + const consonants = 'bcdfghjklmnpqrstvwxyz'; + const vowels = 'aeiou'; + + const createWord = (hexSegment) => { + let word = ''; + for (let i = 0; i < hexSegment.length; i++) { + const charCode = parseInt(hexSegment[i], 16); + word += (i % 2 === 0) ? 
consonants[charCode % consonants.length] : vowels[charCode % vowels.length]; + } + return word; + }; + + const word1 = createWord(hex.substring(0, 6)); + const word2 = createWord(hex.substring(6, 12)); + const num = parseInt(hex.substring(12, 15), 16) % TRANSFER_ID_MAX_NUMBER; + + const transferId = `${word1}-${word2}-${num}`; + log.debug('Generated transfer ID:', transferId); + return transferId; +} + +async function requestWakeLock(uploadState) { + try { + uploadState.wakeLock = await navigator.wakeLock.request('screen'); + log.info('Wake lock acquired to prevent screen sleep'); + uploadState.wakeLock.addEventListener('release', () => log.debug('Wake lock released')); + } catch (err) { + log.warn('Wake lock request failed:', err.message); + } } diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 9495c26..6dd06f8 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -4,7 +4,7 @@ import httpx from fastapi import WebSocketDisconnect from starlette.responses import ClientDisconnect -from websockets.exceptions import ConnectionClosedError, InvalidStatus +from websockets.exceptions import InvalidStatus from tests.helpers import generate_test_file from tests.ws_client import WebSocketTestClient @@ -23,7 +23,7 @@ async def test_invalid_uid(websocket_client: WebSocketTestClient, test_client: h response_put = await test_client.put(f"/{uid}/test.txt") assert response_put.status_code == expected_status - with pytest.raises((ConnectionClosedError, InvalidStatus)): + with pytest.raises(InvalidStatus): async with websocket_client.websocket_connect(f"/send/{uid}") as _: # type: ignore pass @@ -88,55 +88,76 @@ async def test_transfer_id_already_used(websocket_client: WebSocketTestClient): @pytest.mark.anyio async def test_receiver_disconnects(test_client: httpx.AsyncClient, websocket_client: WebSocketTestClient): - """Tests that the sender is notified if the receiver disconnects mid-transfer.""" + """Tests that transfers can be resumed after receiver disconnects.""" uid = "receiver-disconnect" file_content, file_metadata = generate_test_file(size_in_kb=128) # Larger file + received_bytes = b"" async def sender(): - with pytest.raises(ConnectionClosedError, match="Transfer was interrupted by the receiver"): - async with websocket_client.websocket_connect(f"/send/{uid}") as ws: - await anyio.sleep(0.1) + async with websocket_client.websocket_connect(f"/send/{uid}") as ws: + await anyio.sleep(0.1) - await ws.send_json({ - 'file_name': file_metadata.name, - 'file_size': file_metadata.size, - 'file_type': file_metadata.type - }) - await anyio.sleep(1.0) # Allow receiver to connect + await ws.send_json({ + 'file_name': file_metadata.name, + 'file_size': file_metadata.size, + 'file_type': file_metadata.type + }) + await anyio.sleep(1.0) # Allow receiver to connect - response = await ws.recv() - await anyio.sleep(0.1) - assert response == "Go for file chunks" + response = await ws.recv() + await anyio.sleep(0.1) + assert response == "Go for file chunks" - chunks = [file_content[i:i + 4096] for i in range(0, len(file_content), 4096)] - for chunk in chunks: - await ws.send_bytes(chunk) - await anyio.sleep(0.1) + chunks = [file_content[i:i + 4096] for i in range(0, len(file_content), 4096)] + for chunk in chunks: + await ws.send_bytes(chunk) + await anyio.sleep(0.05) - await anyio.sleep(2.0) + # Send completion marker + await ws.send_bytes(b'') + await anyio.sleep(2.0) async def receiver(): + nonlocal received_bytes await anyio.sleep(1.0) headers = {'Accept': '*/*'} + # First 
download attempt - disconnect after receiving some data async with test_client.stream("GET", f"/{uid}?download=true", headers=headers) as response: await anyio.sleep(0.1) - response.raise_for_status() + i = 0 - with pytest.raises(ClientDisconnect): - async for chunk in response.aiter_bytes(4096): - if not chunk: - break - i += 1 - if i >= 5: - raise ClientDisconnect("Simulated disconnect") - await anyio.sleep(0.025) + async for chunk in response.aiter_bytes(4096): + if not chunk: + break + received_bytes += chunk + i += 1 + if i >= 5: # Disconnect after receiving 5 chunks + break + await anyio.sleep(0.025) + + # Wait a bit before resuming + await anyio.sleep(0.5) + + # Resume the download + async with test_client.stream("GET", f"/{uid}?download=true", headers=headers) as response: + response.raise_for_status() + assert response.status_code in (200, 206) # 206 for partial content on resume + + async for chunk in response.aiter_bytes(4096): + if not chunk: + break + received_bytes += chunk async with anyio.create_task_group() as tg: tg.start_soon(sender) tg.start_soon(receiver) + # Verify that the full file was received + assert len(received_bytes) == len(file_content) + assert received_bytes == file_content + @pytest.mark.anyio async def test_prefetcher_request(test_client: httpx.AsyncClient, websocket_client: WebSocketTestClient): diff --git a/tests/test_resumable.py b/tests/test_resumable.py new file mode 100644 index 0000000..e14c060 --- /dev/null +++ b/tests/test_resumable.py @@ -0,0 +1,139 @@ +import anyio +import pytest +import httpx +from tests.helpers import generate_test_file +from tests.ws_client import WebSocketTestClient + + +@pytest.mark.anyio +async def test_http_resumable_download(test_client: httpx.AsyncClient, websocket_client: WebSocketTestClient): + """Test that HTTP downloads can be resumed after disconnection.""" + uid = "test-resume-http" + file_content, file_metadata = generate_test_file(size_in_kb=256) + + # Start the sender + async def sender(): + async with websocket_client.websocket_connect(f"/send/{uid}") as ws: + await ws.send_json({ + 'file_name': file_metadata.name, + 'file_size': file_metadata.size, + 'file_type': file_metadata.type + }) + + response = await ws.recv() + assert response == "Go for file chunks" + + # Send file in chunks slowly to allow resume + chunk_size = 8192 + for i in range(0, len(file_content), chunk_size): + chunk = file_content[i:i + chunk_size] + await ws.send_bytes(chunk) + await anyio.sleep(0.05) # Slower sending to allow resume + + # Send completion marker + await ws.send_bytes(b'') + await anyio.sleep(1) # Keep connection open + + # Start receiver that will disconnect and resume + async def receiver(): + await anyio.sleep(0.5) # Let sender start + + received_bytes = b"" + + # First download - disconnect after 25% of the file + async with test_client.stream("GET", f"/{uid}?download=true") as response: + assert response.status_code == 200 + bytes_to_receive = file_metadata.size // 4 + + async for chunk in response.aiter_bytes(4096): + received_bytes += chunk + if len(received_bytes) >= bytes_to_receive: + break + + first_download_size = len(received_bytes) + assert first_download_size >= bytes_to_receive + + await anyio.sleep(0.2) # Small pause before resuming + + # Resume download + async with test_client.stream("GET", f"/{uid}?download=true") as response: + # Should get 206 Partial Content for resume + assert response.status_code == 206 + + # Check Content-Range header + assert 'content-range' in response.headers + + async for 
chunk in response.aiter_bytes(4096): + received_bytes += chunk + + # Verify we received the complete file + assert len(received_bytes) == file_metadata.size + assert received_bytes == file_content + + async with anyio.create_task_group() as tg: + tg.start_soon(sender) + tg.start_soon(receiver) + + +@pytest.mark.anyio +async def test_multiple_resume_attempts(test_client: httpx.AsyncClient, websocket_client: WebSocketTestClient): + """Test that transfers can be resumed multiple times.""" + uid = "test-multi-resume" + file_content, file_metadata = generate_test_file(size_in_kb=128) + + # Start the sender + async def sender(): + async with websocket_client.websocket_connect(f"/send/{uid}") as ws: + await ws.send_json({ + 'file_name': file_metadata.name, + 'file_size': file_metadata.size, + 'file_type': file_metadata.type + }) + + response = await ws.recv() + assert response == "Go for file chunks" + + # Send file slowly to allow multiple resume attempts + chunk_size = 4096 + for i in range(0, len(file_content), chunk_size): + chunk = file_content[i:i + chunk_size] + await ws.send_bytes(chunk) + await anyio.sleep(0.02) + + await ws.send_bytes(b'') + await anyio.sleep(2) + + # Receiver with multiple disconnects + async def receiver(): + await anyio.sleep(0.5) + + received_bytes = b"" + download_attempts = [0.2, 0.4, 0.6, 1.0] # Download percentages + + for attempt_idx, target_percentage in enumerate(download_attempts): + target_bytes = int(file_metadata.size * target_percentage) + + async with test_client.stream("GET", f"/{uid}?download=true") as response: + # First attempt gets 200, resumes get 206 + if attempt_idx == 0: + assert response.status_code == 200 + else: + assert response.status_code == 206 + + async for chunk in response.aiter_bytes(4096): + received_bytes += chunk + + # Stop at target percentage (except last attempt) + if target_percentage < 1.0 and len(received_bytes) >= target_bytes: + break + + if target_percentage < 1.0: + await anyio.sleep(0.1) # Pause between attempts + + # Verify complete file received + assert len(received_bytes) == file_metadata.size + assert received_bytes == file_content + + async with anyio.create_task_group() as tg: + tg.start_soon(sender) + tg.start_soon(receiver) \ No newline at end of file diff --git a/views/http.py b/views/http.py index 64b2402..1a24ef8 100644 --- a/views/http.py +++ b/views/http.py @@ -37,7 +37,7 @@ async def http_upload(request: Request, uid: str, filename: str): raise HTTPException(status_code=400, detail="Cannot decode file metadata from HTTP headers.") except ValidationError as e: log.error("△ Invalid file metadata.", exc_info=e) - raise HTTPException(status_code=400, detail="Invalid file metadata.") + raise HTTPException(status_code=400, detail=f"Invalid file metadata: {e.errors(include_url=False, include_context=True, include_input=False)}") if file.size > 1024**3: raise HTTPException(status_code=413, detail="File too large. 
1GiB maximum for HTTP.") @@ -54,13 +54,13 @@ async def http_upload(request: Request, uid: str, filename: str): raise HTTPException(status_code=400, detail="Invalid transfer ID or file metadata.") try: - await transfer.wait_for_client_connected() + await transfer.wait_for_receiver() except TimeoutError: log.warning("△ Receiver did not connect in time.") - raise HTTPException(status_code=408, detail="Client did not connect in time.") + raise HTTPException(status_code=408, detail="Receiver did not connect in time.") transfer.info("△ Starting upload...") - await transfer.collect_upload( + await transfer.consume_upload( stream=request.stream(), on_error=raise_http_exception(request), ) @@ -107,21 +107,33 @@ async def http_download(request: Request, uid: str): return templates.TemplateResponse(request, "preview.html", transfer.file.to_readable_dict()) if not is_curl and not request.query_params.get('download'): - log.info(f"▼ Browser request detected, serving download page. UA: ({request.headers.get('user-agent')})") - return templates.TemplateResponse(request, "download.html", transfer.file.to_readable_dict() | {'receiver_connected': await transfer.is_receiver_connected()}) - - elif not await transfer.set_receiver_connected(): - raise HTTPException(status_code=409, detail="A client is already downloading this file.") - - await transfer.set_client_connected() - - transfer.info("▼ Starting download...") - data_stream = StreamingResponse( - transfer.supply_download(on_error=raise_http_exception(request)), - status_code=200, + log.info(f"▼ Browser request detected, serving download page") + receiver_connected = await transfer.receiver_connected + return templates.TemplateResponse(request, "download.html", + transfer.file.to_readable_dict() | {'receiver_connected': receiver_connected}) + + await transfer.notify_receiver_connected() + + progress = await transfer.store.get_progress() + if progress > 0: + transfer.info(f"▼ Resuming from byte {progress}") + headers = { + "Content-Disposition": f"attachment; filename={file_name}", + "Content-Range": f"bytes {progress}-*/{file_size}" + } + status_code = 206 + else: + transfer.info("▼ Starting download") + headers = { + "Content-Disposition": f"attachment; filename={file_name}", + "Content-Length": str(file_size) + } + status_code = 200 + + return StreamingResponse( + transfer.produce_download(on_error=raise_http_exception(request)), + status_code=status_code, media_type=file_type, background=BackgroundTask(transfer.finalize_download), - headers={"Content-Disposition": f"attachment; filename={file_name}", "Content-Length": str(file_size)} + headers=headers ) - - return data_stream diff --git a/views/websockets.py b/views/websockets.py index 33656fb..ad1ad82 100644 --- a/views/websockets.py +++ b/views/websockets.py @@ -54,21 +54,17 @@ async def websocket_upload(websocket: WebSocket, uid: str): return try: - await transfer.wait_for_client_connected() + await transfer.wait_for_receiver() except TimeoutError: - log.warning("△ Receiver did not connect in time.") - await websocket.send_text(f"Error: Receiver did not connect in time.") - return - except Exception as e: - log.error("△ Error while waiting for receiver connection.", exc_info=e) - await websocket.send_text("Error: Error while waiting for receiver connection.") + log.warning("△ Receiver timeout") + await websocket.send_text("Error: Receiver timeout") return transfer.debug("△ Sending go-ahead...") await websocket.send_text("Go for file chunks") transfer.info("△ Starting upload...") - await 
transfer.collect_upload( + await transfer.consume_upload( stream=websocket.iter_bytes(), on_error=send_error_and_close(websocket), ) @@ -92,14 +88,15 @@ async def websocket_download(background_tasks: BackgroundTasks, websocket: WebSo await websocket.send_text("File not found") return - if await transfer.is_receiver_connected(): - log.warning("▼ A client is already downloading this file.") - await websocket.send_text("Error: A client is already downloading this file.") - return - + progress = await transfer.store.get_progress() file_name, file_size, file_type = transfer.get_file_info() - transfer.debug(f"▼ File: name={file_name}, size={file_size}, type={file_type}") - await websocket.send_json({'file_name': file_name, 'file_size': file_size, 'file_type': file_type}) + + metadata = {'file_name': file_name, 'file_size': file_size, 'file_type': file_type} + if progress > 0: + metadata['resume_from'] = progress + transfer.info(f"▼ Resuming from byte {progress}") + + await websocket.send_json(metadata) transfer.info("▼ Waiting for go-ahead...") while True: @@ -107,22 +104,15 @@ async def websocket_download(background_tasks: BackgroundTasks, websocket: WebSo msg = await websocket.receive_text() if msg == "Go for file chunks": break - transfer.warning(f"▼ Unexpected message: {msg}") except WebSocketDisconnect: - transfer.warning("▼ Client disconnected while waiting for go-ahead") + transfer.warning("▼ Disconnected while waiting") return - if not await transfer.set_receiver_connected(): - log.warning("▼ A client is already downloading this file.") - await websocket.send_text("Error: A client is already downloading this file.") - return - - transfer.info("▼ Notifying client is connected.") - await transfer.set_client_connected() + await transfer.notify_receiver_connected() background_tasks.add_task(transfer.finalize_download) - transfer.info("▼ Starting download...") - async for chunk in transfer.supply_download(on_error=send_error_and_close(websocket)): + transfer.info("▼ Starting download") + async for chunk in transfer.produce_download(on_error=send_error_and_close(websocket)): await websocket.send_bytes(chunk) await websocket.send_bytes(b'') - transfer.info("▼ Download complete.") + transfer.info("▼ Download complete")
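
For reference, a minimal sketch of the cURL flow the updated index.html copy refers to, assembled from the routes exercised elsewhere in this diff (PUT /{transfer-id}/{filename} to send, GET /{transfer-id}?download=true to receive, with the same GET re-issued to resume). The transfer id "demo-id" and the file name are placeholders, not values taken from this change:

# Send: upload with an HTTP PUT (the route used by tests/test_endpoints.py)
curl -T ./photo.jpg https://transit.sh/demo-id/photo.jpg

# Receive: stream the download to a file
curl -o photo.jpg "https://transit.sh/demo-id?download=true"

# Resume: repeating the request after an interruption should return
# 206 Partial Content starting at the saved progress; append to the partial file
curl "https://transit.sh/demo-id?download=true" >> photo.jpg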