-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathapi.py
More file actions
191 lines (158 loc) · 6.45 KB
/
api.py
File metadata and controls
191 lines (158 loc) · 6.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
#
# Gemini ImageGen v0.1.3
#
# GUI for Gemini Image Generation using PySide6
# (C) Copyright 2025 Mika Jussila
#
import os
import io
from typing import Optional, Tuple
from google import genai
from google.genai import types
from PIL import Image, ImageDraw, ImageFont
class BillingRequired(Exception):
    """Signal that image generation was refused because billing is not enabled.

    Raised by ``generate_image`` when an API error message mentions billing,
    so the caller (the GUI) can show a dedicated hint instead of a placeholder.
    """
# Model tried first when it is available to the API key (the original author
# notes it as the most reliable option that does not require billing).
_PREFERRED_IMAGE_MODEL = "models/gemini-2.5-flash-image"

# Models tried, in order, when discovery fails or finds nothing usable.
_FALLBACK_IMAGE_MODELS = (
    "models/gemini-2.5-flash-image",
    "models/imagen-4.0-generate-001",
    "models/imagen-4.0-fast-generate-001",
)


def _is_billing_error(message: str) -> bool:
    """Return True when an API error message points at a billing requirement."""
    lowered = message.lower()
    return "billing" in lowered or "billed" in lowered


def _discover_image_models(client) -> list:
    """Return names of image-capable models visible to *client*, preferred first."""
    try:
        names = [
            getattr(m, "name", None) or getattr(m, "model", None) or str(m)
            for m in client.models.list()
        ]
    except Exception:
        # Discovery is best-effort: the caller falls back to a known model list.
        return []

    # Heuristic: Imagen models, or Gemini models whose name mentions "image".
    candidates = []
    for name in names:
        lowered = name.lower()
        if "imagen" in lowered or ("gemini" in lowered and "image" in lowered):
            candidates.append(name)

    if _PREFERRED_IMAGE_MODEL in candidates:
        candidates.remove(_PREFERRED_IMAGE_MODEL)
        candidates.insert(0, _PREFERRED_IMAGE_MODEL)
    return candidates


def _imagen_bytes(response) -> Optional[bytes]:
    """Extract raw image bytes from a ``generate_images`` response, if any."""
    generated = getattr(response, "generated_images", None)
    if not generated:
        return None
    image = getattr(generated[0], "image", None)
    if image is None:
        return None
    raw = getattr(image, "image_bytes", None)
    if raw:
        return raw
    # Fallback for PIL-like image objects that only expose save(); reached
    # whenever image_bytes is missing *or* empty.
    if hasattr(image, "save"):
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        return buffer.getvalue() or None
    return None


def _gemini_bytes(response) -> Optional[bytes]:
    """Extract the first inline image payload from a ``generate_content`` response."""
    for candidate in getattr(response, "candidates", None) or []:
        content = getattr(candidate, "content", None)
        for part in getattr(content, "parts", None) or []:
            data = getattr(getattr(part, "inline_data", None), "data", None)
            if data:
                return data
    return None


def _generate_with_model(client, model: str, prompt: str, aspect_ratio: str) -> Optional[bytes]:
    """Request one image from *model* through the API method that model family uses."""
    if "imagen" in model.lower():
        # Imagen models are served by the dedicated image-generation endpoint.
        response = client.models.generate_images(
            model=model,
            prompt=prompt,
            config=types.GenerateImagesConfig(
                number_of_images=1,
                aspect_ratio=aspect_ratio,
            ),
        )
        return _imagen_bytes(response)

    # Gemini image models return images as inline parts of generate_content.
    config_kwargs = {"response_modalities": ["TEXT", "IMAGE"]}
    image_config_cls = getattr(types, "ImageConfig", None)
    if image_config_cls is not None:
        # Older SDK releases have no ImageConfig; then the model default ratio is used.
        config_kwargs["image_config"] = image_config_cls(aspect_ratio=aspect_ratio)
    response = client.models.generate_content(
        model=model,
        contents=prompt,
        config=types.GenerateContentConfig(**config_kwargs),
    )
    return _gemini_bytes(response)


def generate_image(prompt: str, aspect_ratio: str = "1:1") -> Tuple[Optional[bytes], str]:
    """
    Generate an image with the Google Gemini / Imagen API.

    Authenticates with the GEMINI_API_KEY environment variable. Candidate
    models are discovered from the API (falling back to a built-in list) and
    tried in order until one returns image data. If the key is missing or no
    model produces an image, a locally rendered placeholder is returned instead.

    Args:
        prompt: Text description of the image to generate.
        aspect_ratio: Aspect ratio as string (e.g., "1:1", "16:9", "9:16", "4:3", "3:4").

    Returns:
        A ``(image_bytes, source)`` tuple. ``source`` is the name of the model
        that produced the image, or ``"placeholder"`` when the bytes are the
        fallback PNG.

    Raises:
        BillingRequired: If an API error message indicates that billing must
            be enabled for the account.
    """
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        return _placeholder_image("(No GEMINI_API_KEY set)"), "placeholder"

    try:
        client = genai.Client(api_key=api_key)
        models_to_try = _discover_image_models(client) or list(_FALLBACK_IMAGE_MODELS)

        last_error = ""
        for model in models_to_try:
            try:
                image_bytes = _generate_with_model(client, model, prompt, aspect_ratio)
            except Exception as model_error:
                msg = str(model_error)
                if _is_billing_error(msg):
                    raise BillingRequired(
                        "Imagen requires a billed Google account. Enable billing in AI Studio."
                    ) from model_error
                # Remember why this model failed, then try the next candidate.
                last_error = msg
                continue
            if image_bytes:
                return image_bytes, model

        # Every model failed or returned nothing: explain it in the placeholder.
        message = f"No images generated for: {prompt[:40]} (check model access/billing)"
        if last_error:
            message += f" Last error: {last_error[:100]}"
        return _placeholder_image(message), "placeholder"
    except BillingRequired:
        # Billing problems are surfaced to the caller so the GUI can handle them.
        raise
    except Exception as e:
        error_msg = str(e)[:100]
        return _placeholder_image(f"(API error: {error_msg})"), "placeholder"
def _placeholder_image(text: str) -> bytes:
    """Render *text* onto a dark 1024x768 canvas and return it as PNG bytes.

    The text is word-wrapped to the canvas width (minus a margin) and drawn
    line by line from the top-left corner. Arial at 28pt is used when it can
    be loaded; otherwise Pillow's built-in default font is used.
    """
    canvas_w, canvas_h = 1024, 768
    pad = 20
    line_gap = 8
    background = (28, 28, 30)
    foreground = (220, 220, 220)

    canvas = Image.new("RGB", (canvas_w, canvas_h), color=background)
    pen = ImageDraw.Draw(canvas)

    try:
        typeface = ImageFont.truetype("arial.ttf", 28)
    except Exception:
        # Font file not available on this system: use the bundled bitmap font.
        typeface = ImageFont.load_default()

    cursor_y = pad
    for row in _wrap_text(pen, text, typeface, canvas_w - 2 * pad):
        pen.text((pad, cursor_y), row, font=typeface, fill=foreground)
        # Advance by the measured height of the row just drawn plus a fixed gap.
        row_height = _text_size(pen, row, typeface)[1]
        cursor_y += row_height + line_gap

    with io.BytesIO() as out:
        canvas.save(out, format="PNG")
        return out.getvalue()
def _wrap_text(draw: ImageDraw.Draw, text: str, font: ImageFont.ImageFont, max_width: int):
    """Greedily split *text* into lines no wider than *max_width* pixels.

    Words are separated on whitespace and appended to the current line while
    the measured width still fits. A single word wider than *max_width* is
    placed on a line of its own rather than being broken. Returns the list of
    lines (empty when *text* contains no words).
    """
    wrapped = []
    current = ""
    for word in text.split():
        candidate = f"{current} {word}" if current else word
        width = _text_size(draw, candidate, font)[0]
        if width <= max_width:
            current = candidate
            continue
        # The word does not fit: flush the line built so far and start anew.
        if current:
            wrapped.append(current)
        current = word
    if current:
        wrapped.append(current)
    return wrapped
def _text_size(draw: ImageDraw.Draw, text: str, font: ImageFont.ImageFont):
"""Return (width, height) of text using available Pillow APIs.
Tries multiple methods for compatibility across Pillow versions.
"""
# Preferred: ImageDraw.textsize
try:
size = draw.textsize(text, font=font)
return size
except Exception:
pass
# Newer Pillow: ImageDraw.textbbox
try:
bbox = draw.textbbox((0, 0), text, font=font)
return (bbox[2] - bbox[0], bbox[3] - bbox[1])
except Exception:
pass
# Fallback to font methods
try:
size = font.getsize(text)
return size
except Exception:
pass
try:
bbox = font.getbbox(text)
return (bbox[2] - bbox[0], bbox[3] - bbox[1])
except Exception:
# Last resort: estimate
return (len(text) * 7, 16)