Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions goosebit/updates/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,14 @@ async def generate_chunk(request: Request, device: Device) -> list[UpdateChunk]:
if software is None:
return []

# Always use the download endpoint for consistency, the endpoint
# will handle both local and remote files appropriately.
href = str(request.url_for("download_artifact", dev_id=device.id))
# For remote http(s) URLs, pass the original URL directly to the device.
# This preserves credentials in the URL and allows relative path resolution (e.g. for casync).
# For s3:// or file:// URIs, use the download endpoint which handles proxying.
parsed_uri = urlparse(software.uri)
if parsed_uri.scheme in ("http", "https"):
href = software.uri
else:
href = str(request.url_for("download_artifact", dev_id=device.id))

return [
UpdateChunk(
Expand Down
140 changes: 140 additions & 0 deletions tests/unit/updates/test_generate_chunk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from typing import cast
from unittest.mock import MagicMock

import pytest
import pytest_asyncio

from goosebit.db.models import (
Device,
Hardware,
Software,
UpdateModeEnum,
UpdateStateEnum,
)
from goosebit.updates import generate_chunk


@pytest.fixture
def mock_request() -> MagicMock:
"""Create a mock request object."""
request = MagicMock()
request.url_for.return_value = "http://test/DEFAULT/controller/v1/test-device/download"
return request


@pytest_asyncio.fixture
async def hardware(db: None) -> Hardware:
"""Create test hardware."""
return cast(Hardware, await Hardware.create(model="test-model", revision="1.0"))


@pytest_asyncio.fixture
async def device(db: None, hardware: Hardware) -> Device:
"""Create test device."""
return cast(
Device,
await Device.create(
id="test-device",
last_state=UpdateStateEnum.REGISTERED,
update_mode=UpdateModeEnum.ASSIGNED,
hardware=hardware,
),
)


async def create_software_with_uri(uri: str, hardware: Hardware) -> Software:
"""Helper to create software with a specific URI."""
software = cast(
Software,
await Software.create(
version="1.0.0",
hash="testhash123",
size=1024,
uri=uri,
),
)
await software.compatibility.add(hardware)
return software


@pytest.mark.asyncio
async def test_generate_chunk_http_url_direct(mock_request: MagicMock, device: Device, hardware: Hardware) -> None:
"""Test that http:// URLs are passed directly to the device."""
uri = "http://example.com/firmware.swu"
software = await create_software_with_uri(uri, hardware)
device.assigned_software = software
await device.save()

chunks = await generate_chunk(mock_request, device)

assert len(chunks) == 1
assert chunks[0].artifacts[0].links["download"]["href"] == uri


@pytest.mark.asyncio
async def test_generate_chunk_https_url_direct(mock_request: MagicMock, device: Device, hardware: Hardware) -> None:
"""Test that https:// URLs are passed directly to the device."""
uri = "https://example.com/firmware.swu"
software = await create_software_with_uri(uri, hardware)
device.assigned_software = software
await device.save()

chunks = await generate_chunk(mock_request, device)

assert len(chunks) == 1
assert chunks[0].artifacts[0].links["download"]["href"] == uri


@pytest.mark.asyncio
async def test_generate_chunk_https_with_credentials_direct(
mock_request: MagicMock, device: Device, hardware: Hardware
) -> None:
"""Test that https:// URLs with credentials are passed directly, preserving credentials."""
uri = "https://user:secretpass@example.com/firmware.swu"
software = await create_software_with_uri(uri, hardware)
device.assigned_software = software
await device.save()

chunks = await generate_chunk(mock_request, device)

assert len(chunks) == 1
assert chunks[0].artifacts[0].links["download"]["href"] == uri


@pytest.mark.asyncio
async def test_generate_chunk_file_url_proxied(mock_request: MagicMock, device: Device, hardware: Hardware) -> None:
"""Test that file:// URLs use the proxy download endpoint."""
uri = "file:///path/to/firmware.swu"
software = await create_software_with_uri(uri, hardware)
device.assigned_software = software
await device.save()

chunks = await generate_chunk(mock_request, device)

assert len(chunks) == 1
assert chunks[0].artifacts[0].links["download"]["href"] == "http://test/DEFAULT/controller/v1/test-device/download"
mock_request.url_for.assert_called_with("download_artifact", dev_id=device.id)


@pytest.mark.asyncio
async def test_generate_chunk_s3_url_proxied(mock_request: MagicMock, device: Device, hardware: Hardware) -> None:
"""Test that s3:// URLs use the proxy download endpoint."""
uri = "s3://bucket-name/path/to/firmware.swu"
software = await create_software_with_uri(uri, hardware)
device.assigned_software = software
await device.save()

chunks = await generate_chunk(mock_request, device)

assert len(chunks) == 1
assert chunks[0].artifacts[0].links["download"]["href"] == "http://test/DEFAULT/controller/v1/test-device/download"
mock_request.url_for.assert_called_with("download_artifact", dev_id=device.id)


@pytest.mark.asyncio
async def test_generate_chunk_no_software_assigned(mock_request: MagicMock, device: Device) -> None:
"""Test that an empty list is returned when no software is assigned."""
# Device has no assigned software and no rollout
chunks = await generate_chunk(mock_request, device)

assert chunks == []
Loading