Skip to content
Merged

V1 #2

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions backend/.env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
GOOGLE_API_KEY=
ADMIN_ACCESS=true
SUPABASE_URL=
SUPABASE_KEY=
81 changes: 80 additions & 1 deletion backend/app/controllers/image.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,79 @@
import logging

from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, BackgroundTasks, Header, HTTPException

from app.models.image import ImageGenerationRequest, ImageGenerationResponse
from app.services.image import ImageService
from app.utils.database import db_client
from app.utils.storage import save_image_pair_to_db, upload_image_to_storage

log = logging.getLogger(__name__)


async def save_images_to_database(
authorization: str,
project_id: str,
input_image_data: str,
output_image_data: str,
prompt_text: str,
):
"""
Background task to upload images to storage and save the pair to database.
"""
try:
log.info(f"Starting background task to save images for project {project_id}")

# Skip if no input image data (can't save without input)
if not input_image_data:
log.warning("No input image data provided, skipping database save")
return

# Extract token from authorization header
token = authorization.replace("Bearer ", "") if authorization else ""

# Get database client
supabase_client = await db_client(token=token)

# Upload input image to storage
input_url, input_mime_type, input_width, input_height = (
await upload_image_to_storage(
supabase_client=supabase_client,
image_data=input_image_data,
folder="image_pairs/input",
)
)

# Upload output image to storage
output_url, output_mime_type, output_width, output_height = (
await upload_image_to_storage(
supabase_client=supabase_client,
image_data=output_image_data,
folder="image_pairs/output",
)
)

# Save image pair to database
await save_image_pair_to_db(
supabase_client=supabase_client,
project_id=project_id,
input_url=input_url,
input_mime_type=input_mime_type,
input_width=input_width,
input_height=input_height,
output_url=output_url,
output_mime_type=output_mime_type,
output_width=output_width,
output_height=output_height,
prompt_text=prompt_text,
)

log.info(f"Successfully saved image pair for project {project_id}")

except Exception as e:
log.error(f"Error in background task save_images_to_database: {e}")
# Don't raise - background tasks should not affect the response


class ImageController:
def __init__(self, service: ImageService):
self.router = APIRouter()
Expand All @@ -23,13 +89,26 @@ def setup_routes(self):
)
async def generate_image(
input: ImageGenerationRequest,
background_tasks: BackgroundTasks,
authorization: str = Header(None),
) -> ImageGenerationResponse:
log.info(f"Generating image with prompt: {input.prompt}")
try:
response: ImageGenerationResponse = await self.service.generate_image(
input=input
)
log.info("Image generation completed successfully")

# Add background task to save images to database
background_tasks.add_task(
save_images_to_database,
authorization=authorization,
project_id=input.project_id,
input_image_data=input.image_data,
output_image_data=response.image_data,
prompt_text=input.prompt,
)

return response
except ValueError as e:
log.error(f"Validation error: {e}")
Expand Down
3 changes: 3 additions & 0 deletions backend/app/models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ class ImageGenerationRequest(BaseModel):
default=None,
description="Optional base64 encoded image data to use as input for image generation.",
)
project_id: str = Field(
description="The project ID to associate with this image pair."
)


class ImageGenerationResponse(BaseModel):
Expand Down
5 changes: 4 additions & 1 deletion backend/app/services/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ async def generate_image(
"""
log.info(f"Generating image with prompt: {input.prompt}")

prompt = f"Generate an whiteboard drawingimage based on the following prompt and the reference image: {input.prompt}"
prompt = f"""Generate a drawing image based on the following prompt and the reference image.
The image background should match the reference image background. You are just drawing diagrams, so make sure you are not over verbose.
But if there are paragraphs in the image, you should keep them there.
Prompt: {input.prompt}"""

# Prepare reference image if provided
reference_image = None
Expand Down
Empty file added backend/app/utils/__init__.py
Empty file.
65 changes: 65 additions & 0 deletions backend/app/utils/database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import logging
import os
import uuid

from supabase import AsyncClientOptions

# from supabase import AsyncClientOptions
from supabase._async.client import AsyncClient as Client
from supabase._async.client import create_client

# Import config to ensure environment variables are loaded

log = logging.getLogger(__name__)


def is_valid_uuid(value):
try:
uuid.UUID(str(value))
return True
except ValueError:
return False


def _get_required_env_var(var_name: str) -> str:
"""Get a required environment variable with proper error handling."""
value = os.environ.get(var_name)
if not value:
raise ValueError(
f"Required environment variable '{var_name}' is not set or is empty"
)
return value


async def db_client(
token: str,
) -> Client:
try:
supabase_url = _get_required_env_var("SUPABASE_URL")
supabase_key = _get_required_env_var("SUPABASE_KEY")
except ValueError as e:
log.error(f"Failed to load required environment variables: {e}")
raise

"""
Note that if we set ADMIN_ACCESS to true, there won't be an org_id associated with the db request, which might be a cause of problem when the entry requires org_id to be non-null.
"""

# Development
if os.environ.get("ADMIN_ACCESS") == "true":
return await create_client(
supabase_url=supabase_url,
supabase_key=supabase_key,
)

# Production
return await create_client(
supabase_url=supabase_url,
supabase_key=supabase_key,
options=AsyncClientOptions(
headers={
"Authorization": f"Bearer {token}",
"apiKey": supabase_key,
}
),
)
117 changes: 117 additions & 0 deletions backend/app/utils/storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import base64
import logging
import uuid
from io import BytesIO
from typing import Tuple

from PIL import Image
from supabase._async.client import AsyncClient as Client

log = logging.getLogger(__name__)


async def upload_image_to_storage(
supabase_client: Client,
image_data: str,
bucket_name: str = "whisprdraw",
folder: str = "image_pairs",
) -> Tuple[str, str, int, int]:
"""
Upload a base64 encoded image to Supabase storage.

Args:
supabase_client: The Supabase client instance
image_data: Base64 encoded image data
bucket_name: The name of the storage bucket
folder: The folder path within the bucket

Returns:
Tuple of (public_url, mime_type, width, height)
"""
try:
# Decode base64 image
image_bytes = base64.b64decode(image_data)
image = Image.open(BytesIO(image_bytes))

# Get image properties
width, height = image.size
mime_type = f"image/{image.format.lower()}" if image.format else "image/png"

# Generate unique filename
file_extension = image.format.lower() if image.format else "png"
filename = f"{folder}/{uuid.uuid4()}.{file_extension}"

# Upload to Supabase storage
response = await supabase_client.storage.from_(bucket_name).upload(
path=filename,
file=image_bytes,
file_options={"content-type": mime_type},
)

# Get public URL
public_url_response = await supabase_client.storage.from_(
bucket_name
).get_public_url(filename)

log.info(f"Successfully uploaded image to {public_url_response}")
return public_url_response, mime_type, width, height

except Exception as e:
log.error(f"Error uploading image to storage: {e}")
raise RuntimeError(f"Failed to upload image: {e}")


async def save_image_pair_to_db(
supabase_client: Client,
project_id: str,
input_url: str,
input_mime_type: str,
input_width: int,
input_height: int,
output_url: str,
output_mime_type: str,
output_width: int,
output_height: int,
prompt_text: str,
metadata: dict = None,
):
"""
Save an image pair record to the database.

Args:
supabase_client: The Supabase client instance
project_id: The project ID
input_url: URL to the input image
input_mime_type: MIME type of the input image
input_width: Width of the input image
input_height: Height of the input image
output_url: URL to the output/generated image
output_mime_type: MIME type of the output image
output_width: Width of the output image
output_height: Height of the output image
prompt_text: The prompt used for generation
metadata: Optional metadata dictionary
"""
try:
data = {
"project_id": project_id,
"input_url": input_url,
"input_mime_type": input_mime_type,
"input_width": input_width,
"input_height": input_height,
"output_url": output_url,
"output_mime_type": output_mime_type,
"output_width": output_width,
"output_height": output_height,
"prompt_text": prompt_text,
"metadata": metadata,
}

response = await supabase_client.table("image_pairs").insert(data).execute()

log.info(f"Successfully saved image pair to database: {response.data}")
return response.data

except Exception as e:
log.error(f"Error saving image pair to database: {e}")
raise RuntimeError(f"Failed to save image pair: {e}")
Loading
Loading