Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion pkg-py/src/shinychat/_chat_normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,8 @@ def _(chunk: RawMessageStreamEvent):
# ------------------------------------------------------------------

try:
from google.generativeai.types.generation_types import (
from google.genai.types import (
Content,
GenerateContentResponse,
)

Expand All @@ -305,6 +306,19 @@ def _(message: GenerateContentResponse):
def _(chunk: GenerateContentResponse):
return ChatMessage(content=chunk.text)

@message_content.register
def _(message: Content):
content = ""
for part in message.parts:
if hasattr(part, "text") and part.text:
content += part.text
return ChatMessage(content=content, role=message.role or "model")

@message_content_chunk.register
def _(chunk: Content):
# reuse the message logic
return message_content(chunk)

except ImportError:
pass

Expand Down
16 changes: 8 additions & 8 deletions pkg-py/src/shinychat/_chat_provider_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
ChatCompletionUserMessageParam,
)

if sys.version_info >= (3, 9):
import google.generativeai.types as gtypes # pyright: ignore[reportMissingTypeStubs]
if sys.version_info >= (3, 10):
import google.genai.types as gtypes # pyright: ignore[reportMissingImports]

GoogleMessage = gtypes.ContentDict
GoogleMessage = gtypes.Content
else:
GoogleMessage = object

Expand Down Expand Up @@ -81,20 +81,20 @@ def as_anthropic_message(message: ChatMessageDict) -> "AnthropicMessage":


def as_google_message(message: ChatMessageDict) -> "GoogleMessage":
if sys.version_info < (3, 9):
raise ValueError("Google requires Python 3.9")
if sys.version_info < (3, 10):
raise ValueError("Google requires Python 3.10")

import google.generativeai.types as gtypes # pyright: ignore[reportMissingTypeStubs]
import google.genai.types as gtypes # pyright: ignore[reportMissingImports]

role = message["role"]

if role == "system":
raise ValueError(
"Google requires a system prompt to be specified in the `GenerativeModel()` constructor."
"Google requires a system prompt to be specified with `GenerateContentConfig.system_instruction`."
)
elif role == "assistant":
role = "model"
return gtypes.ContentDict(parts=[message["content"]], role=role)
return gtypes.Content(parts=[gtypes.Part(text=message["content"])], role=role)


def as_langchain_message(message: ChatMessageDict) -> "LangChainMessage":
Expand Down
82 changes: 51 additions & 31 deletions pkg-py/tests/pytest/test_chat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import inspect
import sys
import types
from datetime import datetime
Expand Down Expand Up @@ -260,25 +261,55 @@ def test_langchain_normalization():
assert m.role == "assistant"


def test_google_normalization():
# Not available for Python 3.8
if sys.version_info < (3, 9):
def test_google_content_object_normalization():
# Not available for Python 3.9
if sys.version_info < (3, 10):
return

from google.genai import types

# Test Content object normalization
c = types.Content(parts=[types.Part(text="Hello world!")], role="model")
m = message_content(c)
assert m.content == "Hello world!"
assert m.role == "model"


def test_google_multimodal_normalization():
# Not available for Python 3.9
if sys.version_info < (3, 10):
return

from google.generativeai.generative_models import (
GenerativeModel, # pyright: ignore[reportMissingTypeStubs]
from google.genai import types

# Text part, image part, text part.
c = types.Content(
parts=[
types.Part(text="Here is an image:"),
types.Part(inline_data={"mime_type": "image/png", "data": "AAAA"}),
types.Part(text=" described above."),
],
role="model",
)

generate_content = GenerativeModel.generate_content # type: ignore
m = message_content(c)
assert m.content == "Here is an image: described above."
assert m.role == "model"


def test_google_normalization():
# Not available for Python 3.9
if sys.version_info < (3, 10):
return

from google.genai.models import Models
from google.genai.types import GenerateContentResponse

assert (
generate_content.__annotations__["return"]
== "generation_types.GenerateContentResponse"
inspect.signature(Models.generate_content).return_annotation
== GenerateContentResponse
)

# Not worth mocking the return value of generate_content() since it's a complex object
# and fairly simple to normalize....


def test_anthropic_normalization():
if sys.version_info < (3, 11):
Expand Down Expand Up @@ -480,32 +511,21 @@ def test_as_anthropic_message():
def test_as_google_message():
from shinychat._chat_provider_types import as_google_message

# Not available for Python 3.8
if sys.version_info < (3, 9):
# Not available for Python 3.9
if sys.version_info < (3, 10):
return

from google.generativeai.generative_models import (
GenerativeModel, # pyright: ignore[reportMissingTypeStubs]
)

generate_content = GenerativeModel.generate_content # type: ignore
from google.genai import types
from google.genai.models import Models

assert (
generate_content.__annotations__["contents"]
== "content_types.ContentsType"
)

from google.generativeai.types import (
content_types, # pyright: ignore[reportMissingTypeStubs]
)

assert is_type_in_union(
content_types.ContentDict, content_types.ContentsType
contents_annotation = (
inspect.signature(Models.generate_content).parameters["contents"].annotation
)
assert is_type_in_union(types.Content, contents_annotation)

msg = ChatMessageDict(content="I have a question", role="user")
assert as_google_message(msg) == content_types.ContentDict(
parts=["I have a question"], role="user"
assert as_google_message(msg) == types.Content(
parts=[types.Part(text="I have a question")], role="user"
)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ providers = [
"anthropic;python_version>='3.11'",
"chatlas[mcp]>=0.12.0",
"pydantic",
"google-generativeai",
"google-genai",
"langchain-core>=1.0.0",
"ollama>=0.4.0",
"openai",
Expand Down