Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 49 additions & 1 deletion api/project/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"""

from typing import Literal
from fastapi import APIRouter, Query, status
from fastapi import APIRouter, File, HTTPException, Query, UploadFile, status
from fastapi.responses import StreamingResponse
from core.deps import SessionDep, OpenSearchDep
from api.project.models import Project, ProjectCreate, ProjectPublic, ProjectsPublic
from api.samples.models import SampleCreate, SamplePublic, SamplesPublic, Attribute
Expand Down Expand Up @@ -162,6 +163,53 @@ def get_samples(
)


@router.get(
    "/{project_id}/samples/download",
    response_class=StreamingResponse,
    response_model=None,
    status_code=status.HTTP_200_OK,
    tags=["Sample Endpoints"],
)
def download_samples(
    session: SessionDep,
    project_id: str,
) -> StreamingResponse:
    """
    Stream every sample belonging to ``project_id`` as a TSV attachment.

    TSV assembly is delegated entirely to the sample service layer,
    which raises a 404 when the project has no samples.
    """
    return sample_services.download_samples_as_tsv(session=session, project_id=project_id)


@router.post(
    "/{project_id}/samples/upload",
    response_model=SamplesPublic,
    tags=["Sample Endpoints"],
    status_code=status.HTTP_201_CREATED,
)
async def upload_samples_to_project(
    session: SessionDep,
    opensearch_client: OpenSearchDep,
    project_id: str,
    file: UploadFile = File(...),
) -> SamplesPublic:
    """
    Upload samples from a TSV file to a specific project.

    Args:
        session: Database session dependency.
        opensearch_client: OpenSearch client dependency used for indexing.
        project_id: Project the uploaded samples belong to.
        file: Uploaded TSV file; must be UTF-8 encoded text.

    Returns:
        SamplesPublic: The samples created from the uploaded file.

    Raises:
        HTTPException: 400 if the upload is not valid UTF-8 text.
    """
    content = await file.read()
    try:
        tsv_content = content.decode("utf-8")
    except UnicodeDecodeError as exc:
        # A binary or mis-encoded upload would otherwise surface as an
        # unhandled 500; report it as a client error instead.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Uploaded file must be UTF-8 encoded text.",
        ) from exc

    return sample_services.upload_samples_from_tsv(
        session=session,
        opensearch_client=opensearch_client,
        project_id=project_id,
        tsv_content=tsv_content,
    )


@router.put(
"/{project_id}/samples/{sample_id}",
response_model=SamplePublic,
Expand Down
198 changes: 195 additions & 3 deletions api/samples/services.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,30 @@
from fastapi import HTTPException, status
"""
Service functions for managing samples within projects.
"""
from typing import Literal
from io import StringIO
import csv

from fastapi import HTTPException, status
from fastapi.responses import StreamingResponse
from pydantic import PositiveInt

from sqlmodel import Session, select, func

from opensearchpy import OpenSearch

from api.samples.models import (
Attribute,
Sample,
SampleAttribute,
SampleCreate,
SamplePublic,
SamplesPublic,
Attribute,
)
from api.project.models import Project
from api.search.models import (
SearchDocument,
)
from opensearchpy import OpenSearch
from api.search.services import add_object_to_index


Expand Down Expand Up @@ -80,6 +87,64 @@ def add_sample_to_project(
return sample


def download_samples_as_tsv(
    session: Session,
    project_id: str,
) -> StreamingResponse:
    """
    Download samples for a specific project as a TSV file.

    The TSV has ``project_id`` and ``sample_id`` columns followed by one
    column per distinct attribute key (sorted alphabetically); samples
    missing a given attribute get an empty cell.

    Args:
        session: Database session
        project_id: Project ID to filter samples by
    Returns:
        StreamingResponse: TSV file response containing samples
    Raises:
        HTTPException: 404 if no samples exist for the project.
    """

    # Query samples for the project
    samples = session.exec(
        select(Sample).where(Sample.project_id == project_id)
    ).all()

    if not samples:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"No samples found for project {project_id}.",
        )

    # Union of attribute keys across all samples forms the header columns.
    all_keys = set()
    for sample in samples:
        for attr in sample.attributes or []:
            all_keys.add(attr.key)
    attribute_keys = sorted(all_keys)

    # Build the TSV in memory.  Explicit "\n" terminator: csv defaults to
    # "\r\n", which produces CRLF line endings that Unix TSV consumers
    # often mishandle.
    output = StringIO()
    writer = csv.writer(output, delimiter="\t", lineterminator="\n")

    # Header row
    writer.writerow(["project_id", "sample_id"] + attribute_keys)

    # One row per sample; absent attributes become empty cells.
    for sample in samples:
        attr_dict = {attr.key: attr.value for attr in (sample.attributes or [])}
        writer.writerow(
            [project_id, sample.sample_id]
            + [attr_dict.get(key, "") for key in attribute_keys]
        )

    output.seek(0)

    return StreamingResponse(
        output,
        media_type="text/tab-separated-values",
        # Filename is quoted per RFC 6266 so IDs containing spaces or
        # separators do not break the header.
        headers={
            "Content-Disposition": f'attachment; filename="{project_id}_samples.tsv"'
        },
    )


def get_samples(
*,
session: Session,
Expand Down Expand Up @@ -226,3 +291,130 @@ def update_sample_in_project(
Attribute(key=attr.key, value=attr.value) for attr in (sample.attributes or [])
] if sample.attributes else []
)


def upload_samples_from_tsv(
    session: Session,
    opensearch_client: OpenSearch,
    project_id: str,
    tsv_content: str,
) -> SamplesPublic:
    """
    Upload samples to a specific project from a TSV file content.

    The TSV must have a header containing at least ``project_id`` and
    ``sample_id``; every other column is treated as a sample attribute.
    Each row's ``project_id`` must match the target project.

    Args:
        session: Database session
        opensearch_client: OpenSearch client for indexing
        project_id: Project ID to associate samples with
        tsv_content: Content of the TSV file as a string
    Returns:
        SamplesPublic: List of created samples
    Raises:
        HTTPException: 404 if the project does not exist; 400 for a
            missing/invalid header, a project_id mismatch, or an empty
            sample_id.
    """
    # Check if project exists
    project = session.exec(
        select(Project).where(Project.project_id == project_id)
    ).first()
    if not project:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Project {project_id} not found.",
        )

    # Defensive: accept bytes from callers that skipped decoding.
    if isinstance(tsv_content, bytes):
        tsv_content = tsv_content.decode("utf-8")

    reader = csv.DictReader(StringIO(tsv_content), delimiter="\t")

    # Validate header
    if not reader.fieldnames:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="TSV file is empty or has no header.",
        )

    if "project_id" not in reader.fieldnames or "sample_id" not in reader.fieldnames:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="TSV file must contain 'project_id' and 'sample_id' columns.",
        )

    # Every non-reserved column becomes an attribute key.
    attribute_keys = [
        col for col in reader.fieldnames
        if col not in ("project_id", "sample_id")
    ]

    created_samples = []
    for row in reader:
        # DictReader fills short rows with None, so guard before strip():
        # row.get(key, "") still returns None when the key is present but
        # the cell is missing, which would crash with AttributeError.
        row_project_id = (row.get("project_id") or "").strip()
        if row_project_id != project_id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=(
                    f"Row project_id '{row_project_id}' does not match "
                    f"URL project_id '{project_id}'."
                ),
            )

        sample_id = (row.get("sample_id") or "").strip()
        if not sample_id:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Each row must have a non-empty sample_id.",
            )

        # Build attributes from the remaining columns.
        attributes = []
        for key in attribute_keys:
            value = (row.get(key) or "").strip()
            if value:  # Only add non-empty attributes
                attributes.append(Attribute(key=key, value=value))

        # Create sample using the existing service function so OpenSearch
        # indexing and uniqueness checks are applied consistently.
        sample_in = SampleCreate(
            sample_id=sample_id,
            attributes=attributes if attributes else None,
        )

        sample = add_sample_to_project(
            session=session,
            opensearch_client=opensearch_client,
            project_id=project_id,
            sample_in=sample_in,
        )
        created_samples.append(sample)

    # Convert to SamplePublic format
    public_samples = [
        SamplePublic(
            sample_id=sample.sample_id,
            project_id=sample.project_id,
            attributes=sample.attributes,
        )
        for sample in created_samples
    ]

    # data_cols mirrors the union of attribute keys actually persisted.
    all_keys = {
        attr.key
        for sample in created_samples
        for attr in (sample.attributes or [])
    }
    data_cols = sorted(all_keys) if all_keys else None

    # A single upload is returned as one "page" of results.
    return SamplesPublic(
        data=public_samples,
        data_cols=data_cols,
        total_items=len(public_samples),
        total_pages=1,
        current_page=1,
        per_page=len(public_samples),
        has_next=False,
        has_prev=False,
    )
80 changes: 80 additions & 0 deletions tests/api/test_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,83 @@ def test_update_sample_attribute(client: TestClient, session: Session):
attr["key"] == "Condition" and attr["value"] == "Diseased"
for attr in response.json()["attributes"]
)


def test_download_samples_tsv(client: TestClient, session: Session):
    """
    Test that we can download samples as a TSV file
    """
    # Seed a project with two samples, each carrying two attributes.
    project = Project(name="Test Project")
    project.project_id = generate_project_id(session=session)
    project.attributes = []
    session.add(project)

    fixtures = {
        "Sample_1": {"Tissue": "Liver", "Condition": "Healthy"},
        "Sample_2": {"Tissue": "Heart", "Condition": "Disease"},
    }
    for sample_id, attrs in fixtures.items():
        sample = Sample(sample_id=sample_id, project_id=project.project_id)
        session.add(sample)
        session.flush()  # Populate sample.id so attributes can reference it
        for key, value in attrs.items():
            session.add(SampleAttribute(sample_id=sample.id, key=key, value=value))

    session.commit()

    # Download samples as TSV
    response = client.get(
        f"/api/v1/projects/{project.project_id}/samples/download",
    )
    assert response.status_code == 200
    assert response.headers["content-type"] == "text/tab-separated-values; charset=utf-8"

    body = response.content.decode("utf-8")
    rows = [row.strip() for row in body.strip().split("\n")]
    # Header columns: reserved IDs first, then attribute keys sorted.
    assert rows[0] == "project_id\tsample_id\tCondition\tTissue"
    assert f"{project.project_id}\tSample_1\tHealthy\tLiver" in rows
    assert f"{project.project_id}\tSample_2\tDisease\tHeart" in rows


def test_upload_samples_tsv(client: TestClient, session: Session):
    """
    Test that we can upload samples via a TSV file
    """
    # Seed an empty project to upload into.
    project = Project(name="Test Project")
    project.project_id = generate_project_id(session=session)
    project.attributes = []
    session.add(project)
    session.commit()

    # Two sample rows sharing the Tissue/Condition attribute columns.
    header = "project_id\tsample_id\tTissue\tCondition"
    rows = [
        f"{project.project_id}\tSample_1\tLiver\tHealthy",
        f"{project.project_id}\tSample_2\tHeart\tDisease",
    ]
    tsv_content = "\n".join([header, *rows]) + "\n"

    # Upload samples via TSV
    response = client.post(
        f"/api/v1/projects/{project.project_id}/samples/upload",
        files={"file": ("samples.tsv", tsv_content, "text/tab-separated-values")},
    )
    assert response.status_code == 201
    payload = response.json()
    assert len(payload["data"]) == 2

    # Verify both samples were created.
    uploaded_ids = {sample["sample_id"] for sample in payload["data"]}
    assert "Sample_1" in uploaded_ids
    assert "Sample_2" in uploaded_ids