# Review notes. Commentary only: this text sits before the first "diff --git" header,
# which patch-reading tools treat as leading text to skip.
#
# Purpose: adds GET /{project_id}/samples/download and POST /{project_id}/samples/upload
# to api/project/routes.py, the backing download_samples_as_tsv / upload_samples_from_tsv
# services in api/samples/services.py, and one success-path test for each endpoint.
#
# NOTE(review), upload_samples_from_tsv:
#   * `row.get(key, "").strip()` (used for project_id, sample_id and every attribute column):
#     csv.DictReader fills the missing trailing fields of a short row with None, so .get()
#     returns None and .strip() raises AttributeError instead of producing an HTTP 400.
#     Suggested fix: `(row.get(key) or "").strip()`.
#   * Validation and creation share one loop, so a bad row raises HTTPException after earlier
#     rows were already passed to add_sample_to_project. Whether those samples persist depends
#     on add_sample_to_project, whose body this patch does not show - TODO confirm. Validating
#     every row before creating any sample would rule out partial uploads.
#   * The `isinstance(tsv_content, bytes)` branch is never taken by the route in this patch,
#     which decodes before calling; it also disagrees with the `tsv_content: str` annotation.
#
# NOTE(review), upload_samples_to_project:
#   * `content.decode("utf-8")` is not wrapped in any handler, so a non-UTF-8 upload raises
#     UnicodeDecodeError out of the route rather than an HTTPException.
#   * "utf-8" keeps a leading byte-order mark, which would turn the first header into
#     "\ufeffproject_id" and trip the required-columns check; "utf-8-sig" strips it.
#
# NOTE(review), download_samples_as_tsv:
#   * Only Sample is queried, so an unknown project and an existing project with no samples
#     both return the same 404 message.
#   * The Content-Disposition filename is emitted unquoted.
#   * `sorted(list(all_keys))` can be `sorted(all_keys)` (same expression in upload_samples_from_tsv).
#
# NOTE(review), tests: only the 200/201 success paths are exercised; none of the 400/404
# branches added by this patch has a test here.
#
diff --git a/api/project/routes.py b/api/project/routes.py index 113d03c..462ceec 100755 --- a/api/project/routes.py +++ b/api/project/routes.py @@ -3,7 +3,8 @@ """ from typing import Literal -from fastapi import APIRouter, Query, status +from fastapi import APIRouter, Query, status, UploadFile, File +from fastapi.responses import StreamingResponse from core.deps import SessionDep, OpenSearchDep from api.project.models import Project, ProjectCreate, ProjectPublic, ProjectsPublic from api.samples.models import SampleCreate, SamplePublic, SamplesPublic, Attribute @@ -162,6 +163,53 @@ def get_samples( ) +@router.get( + "/{project_id}/samples/download", + response_class=StreamingResponse, + response_model=None, + status_code=status.HTTP_200_OK, + tags=["Sample Endpoints"], +) +def download_samples( + session: SessionDep, + project_id: str, +) -> StreamingResponse: + """ + Download all samples as a TSV for a given project. + """ + return sample_services.download_samples_as_tsv( + session=session, + project_id=project_id, + ) + + +@router.post( + "/{project_id}/samples/upload", + response_model=SamplesPublic, + tags=["Sample Endpoints"], + status_code=status.HTTP_201_CREATED, +) +async def upload_samples_to_project( + session: SessionDep, + opensearch_client: OpenSearchDep, + project_id: str, + file: UploadFile = File(...), +) -> SamplesPublic: + """ + Upload samples from a TSV file to a specific project. 
+ """ + # Read file content + content = await file.read() + tsv_content = content.decode("utf-8") + + return sample_services.upload_samples_from_tsv( + session=session, + opensearch_client=opensearch_client, + project_id=project_id, + tsv_content=tsv_content, + ) + + @router.put( "/{project_id}/samples/{sample_id}", response_model=SamplePublic, diff --git a/api/samples/services.py b/api/samples/services.py index 942b092..8d03483 100644 --- a/api/samples/services.py +++ b/api/samples/services.py @@ -1,23 +1,30 @@ -from fastapi import HTTPException, status +""" +Service functions for managing samples within projects. +""" from typing import Literal +from io import StringIO +import csv + +from fastapi import HTTPException, status +from fastapi.responses import StreamingResponse from pydantic import PositiveInt from sqlmodel import Session, select, func +from opensearchpy import OpenSearch from api.samples.models import ( + Attribute, Sample, SampleAttribute, SampleCreate, SamplePublic, SamplesPublic, - Attribute, ) from api.project.models import Project from api.search.models import ( SearchDocument, ) -from opensearchpy import OpenSearch from api.search.services import add_object_to_index @@ -80,6 +87,64 @@ def add_sample_to_project( return sample +def download_samples_as_tsv( + session: Session, + project_id: str, +) -> StreamingResponse: + """ + Download samples for a specific project as a TSV file. 
+ + Args: + session: Database session + project_id: Project ID to filter samples by + Returns: + StreamingResponse: TSV file response containing samples + """ + + # Query samples for the project + samples = session.exec( + select(Sample).where(Sample.project_id == project_id) + ).all() + + if not samples: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"No samples found for project {project_id}.", + ) + + # Collect all unique attribute keys across all samples for header + all_keys = set() + for sample in samples: + if sample.attributes: + for attr in sample.attributes: + all_keys.add(attr.key) + attribute_keys = sorted(list(all_keys)) + + # Create TSV in memory + output = StringIO() + writer = csv.writer(output, delimiter="\t") + + # Write header + header = ["project_id", "sample_id"] + attribute_keys + writer.writerow(header) + + # Write sample rows + for sample in samples: + row = [project_id, sample.sample_id] + attr_dict = {attr.key: attr.value for attr in (sample.attributes or [])} + for key in attribute_keys: + row.append(attr_dict.get(key, "")) + writer.writerow(row) + + output.seek(0) + + return StreamingResponse( + output, + media_type="text/tab-separated-values", + headers={"Content-Disposition": f"attachment; filename={project_id}_samples.tsv"}, + ) + + def get_samples( *, session: Session, @@ -226,3 +291,130 @@ def update_sample_in_project( Attribute(key=attr.key, value=attr.value) for attr in (sample.attributes or []) ] if sample.attributes else [] ) + + +def upload_samples_from_tsv( + session: Session, + opensearch_client: OpenSearch, + project_id: str, + tsv_content: str, +) -> SamplesPublic: + """ + Upload samples to a specific project from a TSV file content. 
+ + Args: + session: Database session + opensearch_client: OpenSearch client for indexing + project_id: Project ID to associate samples with + tsv_content: Content of the TSV file as a string + Returns: + SamplesPublic: List of created samples + """ + # Check if project exists + project = session.exec( + select(Project).where(Project.project_id == project_id) + ).first() + if not project: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Project {project_id} not found.", + ) + + # Parse TSV content + # If tsv_content is bytes, decode it + if isinstance(tsv_content, bytes): + tsv_content = tsv_content.decode("utf-8") + + tsv_io = StringIO(tsv_content) + reader = csv.DictReader(tsv_io, delimiter="\t") + + # Validate header + if not reader.fieldnames: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="TSV file is empty or has no header.", + ) + + # Expected columns: project_id, sample_id, and any number of attribute columns + if "project_id" not in reader.fieldnames or "sample_id" not in reader.fieldnames: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="TSV file must contain 'project_id' and 'sample_id' columns.", + ) + + # Get attribute column names (all columns except project_id and sample_id) + attribute_keys = [ + col for col in reader.fieldnames + if col not in ["project_id", "sample_id"] + ] + + # Create samples from TSV rows + created_samples = [] + for row in reader: + # Validate project_id matches + row_project_id = row.get("project_id", "").strip() + if row_project_id != project_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=( + f"Row project_id '{row_project_id}' does not match " + f"URL project_id '{project_id}'." 
+ ), + ) + + sample_id = row.get("sample_id", "").strip() + if not sample_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Each row must have a non-empty sample_id.", + ) + + # Build attributes from remaining columns + attributes = [] + for key in attribute_keys: + value = row.get(key, "").strip() + if value: # Only add non-empty attributes + attributes.append(Attribute(key=key, value=value)) + + # Create sample using existing service function + sample_in = SampleCreate( + sample_id=sample_id, + attributes=attributes if attributes else None, + ) + + sample = add_sample_to_project( + session=session, + opensearch_client=opensearch_client, + project_id=project_id, + sample_in=sample_in, + ) + created_samples.append(sample) + + # Convert to SamplePublic format + public_samples = [ + SamplePublic( + sample_id=sample.sample_id, + project_id=sample.project_id, + attributes=sample.attributes, + ) + for sample in created_samples + ] + + # Collect all unique attribute keys for data_cols + all_keys = set() + for sample in created_samples: + if sample.attributes: + for attr in sample.attributes: + all_keys.add(attr.key) + data_cols = sorted(list(all_keys)) if all_keys else None + + return SamplesPublic( + data=public_samples, + data_cols=data_cols, + total_items=len(public_samples), + total_pages=1, + current_page=1, + per_page=len(public_samples), + has_next=False, + has_prev=False, + ) diff --git a/tests/api/test_samples.py b/tests/api/test_samples.py index c27dd15..020cb15 100644 --- a/tests/api/test_samples.py +++ b/tests/api/test_samples.py @@ -319,3 +319,83 @@ def test_update_sample_attribute(client: TestClient, session: Session): attr["key"] == "Condition" and attr["value"] == "Diseased" for attr in response.json()["attributes"] ) + + +def test_download_samples_tsv(client: TestClient, session: Session): + """ + Test that we can download samples as a TSV file + """ + # Add a project to the database + new_project = Project(name="Test 
Project") + new_project.project_id = generate_project_id(session=session) + new_project.attributes = [] + session.add(new_project) + + # Add sample 1 + sample_1 = Sample(sample_id="Sample_1", project_id=new_project.project_id) + session.add(sample_1) + session.flush() # Flush to get the sample ID for attributes + + # Add attributes for Sample_1 + attr_1_1 = SampleAttribute(sample_id=sample_1.id, key="Tissue", value="Liver") + attr_1_2 = SampleAttribute(sample_id=sample_1.id, key="Condition", value="Healthy") + session.add(attr_1_1) + session.add(attr_1_2) + + # Add sample 2 + sample_2 = Sample(sample_id="Sample_2", project_id=new_project.project_id) + session.add(sample_2) + session.flush() + + # Add attributes for Sample_2 + attr_2_1 = SampleAttribute(sample_id=sample_2.id, key="Tissue", value="Heart") + attr_2_2 = SampleAttribute(sample_id=sample_2.id, key="Condition", value="Disease") + session.add(attr_2_1) + session.add(attr_2_2) + + session.commit() + + # Download samples as TSV + response = client.get( + f"/api/v1/projects/{new_project.project_id}/samples/download", + ) + assert response.status_code == 200 + assert response.headers["content-type"] == "text/tab-separated-values; charset=utf-8" + content = response.content.decode("utf-8") + lines = [line.strip() for line in content.strip().split("\n")] + assert lines[0] == "project_id\tsample_id\tCondition\tTissue" + assert f"{new_project.project_id}\tSample_1\tHealthy\tLiver" in lines + assert f"{new_project.project_id}\tSample_2\tDisease\tHeart" in lines + + +def test_upload_samples_tsv(client: TestClient, session: Session): + """ + Test that we can upload samples via a TSV file + """ + # Add a project to the database + new_project = Project(name="Test Project") + new_project.project_id = generate_project_id(session=session) + new_project.attributes = [] + session.add(new_project) + session.commit() + + # Create TSV content + tsv_content = ( + "project_id\tsample_id\tTissue\tCondition\n" + 
f"{new_project.project_id}\tSample_1\tLiver\tHealthy\n" + f"{new_project.project_id}\tSample_2\tHeart\tDisease\n" + ) + + # Upload samples via TSV + response = client.post( + f"/api/v1/projects/{new_project.project_id}/samples/upload", + files={"file": ("samples.tsv", tsv_content, "text/tab-separated-values")}, + ) + assert response.status_code == 201 + response_data = response.json() + assert len(response_data["data"]) == 2 + + # Verify that samples were added correctly + sample_ids = {sample["sample_id"] for sample in response_data["data"]} + assert "Sample_1" in sample_ids + assert "Sample_2" in sample_ids