Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,40 @@ fmt.Println(result.Response.Content.Text())

🍔 For the full implementation and more [see the examples directory](https://github.com/charmbracelet/fantasy/tree/main/examples).

## Embeddings

Fantasy supports embeddings for OpenAI and OpenAI-compatible providers (e.g. OpenRouter, Azure, Together).
Providers that support embeddings implement `fantasy.EmbeddingProvider`.

```go
provider, err := openai.New(openai.WithAPIKey(myHotKey))
if err != nil {
fmt.Fprintln(os.Stderr, "Whoops:", err)
os.Exit(1)
}

embedProvider, ok := provider.(fantasy.EmbeddingProvider)
if !ok {
fmt.Fprintln(os.Stderr, "Embeddings not supported by this provider")
os.Exit(1)
}

embedModel, err := embedProvider.EmbeddingModel(ctx, "text-embedding-3-small")
if err != nil {
fmt.Fprintln(os.Stderr, "Dang:", err)
os.Exit(1)
}

input := "hello embeddings"
embeds, err := embedModel.Embed(ctx, fantasy.EmbeddingCall{Inputs: []string{input}})
if err != nil {
fmt.Fprintln(os.Stderr, "Oof:", err)
os.Exit(1)
}

fmt.Println(len(embeds.Embeddings[0].Vector))
```

## Multi-model? Multi-provider?

Yeah! Fantasy is designed to support a wide variety of providers and models under a single API. While many providers such as Microsoft Azure, Amazon Bedrock, and OpenRouter have dedicated packages in Fantasy, many others work just fine with `openaicompat`, the generic OpenAI-compatible layer. That said, if you find a provider that’s not compatible and needs special treatment, please let us know in an issue (or open a PR).
Expand Down
2 changes: 1 addition & 1 deletion doc.go
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
// Package fantasy provides a unified interface for interacting with various AI language models.
// Package fantasy provides a unified interface for interacting with various AI language and embedding models.
package fantasy
63 changes: 63 additions & 0 deletions embedding.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package fantasy

import (
"context"
"fmt"
)

// EmbeddingProvider represents a provider that can create embedding models.
// This is separate from Provider to avoid breaking changes.
//
// Callers holding a Provider can discover embedding support with a type
// assertion to EmbeddingProvider.
type EmbeddingProvider interface {
// EmbeddingModel returns the embedding model identified by modelID, or an
// error if the model cannot be created.
EmbeddingModel(ctx context.Context, modelID string) (EmbeddingModel, error)
}

// EmbeddingModel represents a model that can generate embeddings.
type EmbeddingModel interface {
// Embed generates embeddings for the inputs described by the call and
// returns them together with the model and usage information carried by
// EmbeddingResponse.
Embed(context.Context, EmbeddingCall) (*EmbeddingResponse, error)

// Provider returns the name of the provider backing this model.
Provider() string
// Model returns the ID of the model.
Model() string
}

// EmbeddingCall represents a request to generate embeddings.
// Inputs must contain at least one item, and every item must be non-empty;
// ValidateEmbeddingCall enforces both rules.
type EmbeddingCall struct {
// Inputs holds the texts to embed.
Inputs []string `json:"inputs,omitempty"`
// Dimensions is optional. When non-nil, the OpenAI implementation forwards
// it as the request's dimensions parameter.
Dimensions *int64 `json:"dimensions,omitempty"`

// ProviderOptions carries provider-specific options. The OpenAI
// implementation looks up the entry stored under its provider name and
// expects it to be an *openai.ProviderOptions.
ProviderOptions ProviderOptions `json:"provider_options,omitempty"`
}

// Embedding represents a single embedding vector.
type Embedding struct {
// Index is the position reported by the provider for this embedding.
// NOTE(review): presumably the index of the matching entry in
// EmbeddingCall.Inputs — confirm against the provider API.
Index int `json:"index"`
// Vector holds the embedding values as float32.
Vector []float32 `json:"vector"`
}

// EmbeddingResponse represents the response from an embedding model.
type EmbeddingResponse struct {
// Model is the model name reported in the provider's response.
Model string `json:"model"`
// Usage reports token usage for the request. The OpenAI implementation
// fills InputTokens and TotalTokens and leaves OutputTokens at zero.
Usage Usage `json:"usage"`
// Embeddings holds one entry per embedding returned by the provider.
Embeddings []Embedding `json:"embeddings"`
}

// ValidateEmbeddingCall validates the embedding request parameters.
//
// It returns an *Error titled "invalid argument" when:
//   - call.Inputs is empty,
//   - any element of call.Inputs is the empty string (the message names the
//     offending index), or
//   - call.Dimensions is set but is not a positive number.
//
// It returns nil when the call is valid. ProviderOptions are not inspected
// here; each provider checks its own options.
func ValidateEmbeddingCall(call EmbeddingCall) error {
	if len(call.Inputs) == 0 {
		return &Error{
			Title:   "invalid argument",
			Message: "embedding inputs are required",
		}
	}

	for i, input := range call.Inputs {
		if input == "" {
			return &Error{
				Title:   "invalid argument",
				Message: fmt.Sprintf("embedding inputs[%d] cannot be empty", i),
			}
		}
	}

	// A requested vector size below 1 can never be satisfied, so reject it
	// locally instead of spending a round trip on a provider-side error.
	if call.Dimensions != nil && *call.Dimensions <= 0 {
		return &Error{
			Title:   "invalid argument",
			Message: fmt.Sprintf("embedding dimensions must be positive, got %d", *call.Dimensions),
		}
	}

	return nil
}
29 changes: 29 additions & 0 deletions embedding_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package fantasy

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestValidateEmbeddingCall(t *testing.T) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For tests, especially when introducing new APIs, Fantasy uses charm.land/x/vcr to record requests & replay real data.

See https://github.com/charmbracelet/fantasy/tree/main/providertests/testdata/TestOpenAICommon/openai-o4-mini for example

I've got some captures for embeddings already done: cbca0e5#diff-2599f9d193307dc4e06a66ec239c05f98fdf3ee32363d9eb3ece3ca18256f79c

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added VCR provider tests + captures.

t.Run("requires inputs", func(t *testing.T) {
err := ValidateEmbeddingCall(EmbeddingCall{})
require.Error(t, err)
})

t.Run("rejects empty inputs", func(t *testing.T) {
err := ValidateEmbeddingCall(EmbeddingCall{Inputs: []string{""}})
require.Error(t, err)
})

t.Run("accepts single input in inputs", func(t *testing.T) {
err := ValidateEmbeddingCall(EmbeddingCall{Inputs: []string{"hello"}})
require.NoError(t, err)
})

t.Run("accepts batch inputs", func(t *testing.T) {
err := ValidateEmbeddingCall(EmbeddingCall{Inputs: []string{"a", "b"}})
require.NoError(t, err)
})
}
1 change: 1 addition & 0 deletions provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
)

// Provider represents a provider of language models.
// Providers that support embeddings also implement EmbeddingProvider.
type Provider interface {
Name() string
LanguageModel(ctx context.Context, modelID string) (LanguageModel, error)
Expand Down
2 changes: 1 addition & 1 deletion providers/azure/azure.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Package azure provides an implementation of the fantasy AI SDK for Azure's language models.
// Package azure provides an implementation of the fantasy AI SDK for Azure's language and embedding models.
package azure

import (
Expand Down
85 changes: 85 additions & 0 deletions providers/openai/embedding_model.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package openai

import (
"context"

"charm.land/fantasy"
"github.com/openai/openai-go/v2"
"github.com/openai/openai-go/v2/packages/param"
)

// embeddingModel is the OpenAI-backed implementation of fantasy.EmbeddingModel.
type embeddingModel struct {
// provider is the name returned by Provider.
provider string
// modelID is returned by Model and sent as the model of each request.
modelID string
// client is the OpenAI SDK client used to call the embeddings endpoint.
client openai.Client
}

// Model implements fantasy.EmbeddingModel.
// It returns the model ID stored on this embedding model.
func (e embeddingModel) Model() string {
return e.modelID
}

// Provider implements fantasy.EmbeddingModel.
// It returns the provider name stored on this embedding model.
func (e embeddingModel) Provider() string {
return e.provider
}

// Embed implements fantasy.EmbeddingModel.
//
// The call is validated with fantasy.ValidateEmbeddingCall, translated into an
// OpenAI embeddings request (model, inputs, optional dimensions and optional
// user), sent through the SDK client, and the reply is converted into a
// fantasy.EmbeddingResponse.
//
// It returns a *fantasy.Error when validation fails or when the options stored
// under this provider's name are not an *openai.ProviderOptions. Errors from
// the API call are passed through toProviderErr.
func (e embeddingModel) Embed(ctx context.Context, call fantasy.EmbeddingCall) (*fantasy.EmbeddingResponse, error) {
	if err := fantasy.ValidateEmbeddingCall(call); err != nil {
		return nil, err
	}

	request := openai.EmbeddingNewParams{
		Model: e.modelID,
		Input: openai.EmbeddingNewParamsInputUnion{
			OfArrayOfStrings: call.Inputs,
		},
	}
	if call.Dimensions != nil {
		request.Dimensions = param.NewOpt(*call.Dimensions)
	}

	// Looking up a key in a nil map is safe, so no separate nil check is needed.
	if raw, present := call.ProviderOptions[Name]; present {
		opts, isOpenAI := raw.(*ProviderOptions)
		if !isOpenAI {
			return nil, &fantasy.Error{Title: "invalid argument", Message: "openai provider options should be *openai.ProviderOptions"}
		}
		if opts.User != nil {
			request.User = param.NewOpt(*opts.User)
		}
	}

	response, err := e.client.Embeddings.New(ctx, request)
	if err != nil {
		return nil, toProviderErr(err)
	}

	result := &fantasy.EmbeddingResponse{
		Model: response.Model,
		Usage: fantasy.Usage{
			InputTokens:  response.Usage.PromptTokens,
			OutputTokens: 0,
			TotalTokens:  response.Usage.TotalTokens,
		},
		Embeddings: make([]fantasy.Embedding, len(response.Data)),
	}

	// The SDK returns float64 values; narrow them to the float32 vectors
	// exposed by fantasy.Embedding.
	for pos, item := range response.Data {
		vec := make([]float32, 0, len(item.Embedding))
		for _, component := range item.Embedding {
			vec = append(vec, float32(component))
		}
		result.Embeddings[pos] = fantasy.Embedding{
			Index:  int(item.Index),
			Vector: vec,
		}
	}

	return result, nil
}
33 changes: 32 additions & 1 deletion providers/openai/openai.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Package openai provides an implementation of the fantasy AI SDK for OpenAI's language models.
// Package openai provides an implementation of the fantasy AI SDK for OpenAI's language and embedding models.
package openai

import (
Expand Down Expand Up @@ -186,6 +186,37 @@ func (o *provider) LanguageModel(_ context.Context, modelID string) (fantasy.Lan
), nil
}

// EmbeddingModel implements fantasy.EmbeddingProvider.
//
// It builds a fresh OpenAI SDK client from the provider's options and wraps it,
// together with modelID and the configured provider name, in an embeddingModel.
// The context is unused and the returned error is always nil.
func (o *provider) EmbeddingModel(_ context.Context, modelID string) (fantasy.EmbeddingModel, error) {
	capacity := 5 + len(o.options.headers) + len(o.options.sdkOptions)
	requestOpts := make([]option.RequestOption, 0, capacity)

	// Order matters: options appended later take precedence, so caller-supplied
	// SDK options go last and can override everything set here.
	requestOpts = append(requestOpts, option.WithMaxRetries(0))
	if key := o.options.apiKey; key != "" {
		requestOpts = append(requestOpts, option.WithAPIKey(key))
	}
	if base := o.options.baseURL; base != "" {
		requestOpts = append(requestOpts, option.WithBaseURL(base))
	}
	for name, val := range o.options.headers {
		requestOpts = append(requestOpts, option.WithHeader(name, val))
	}
	if httpClient := o.options.client; httpClient != nil {
		requestOpts = append(requestOpts, option.WithHTTPClient(httpClient))
	}
	requestOpts = append(requestOpts, o.options.sdkOptions...)

	model := embeddingModel{
		provider: o.options.name,
		modelID:  modelID,
		client:   openai.NewClient(requestOpts...),
	}
	return model, nil
}

// Name implements fantasy.Provider.
// It returns the package-level Name value.
func (o *provider) Name() string {
return Name
}
55 changes: 55 additions & 0 deletions providers/openai/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3247,3 +3247,58 @@ func TestResponsesToPrompt_DropsEmptyMessages(t *testing.T) {
require.Empty(t, warnings)
})
}

// TestOpenAIEmbeddings exercises the embedding path end to end against the
// mock server: it checks that the provider exposes fantasy.EmbeddingProvider,
// that the response is converted into a fantasy.EmbeddingResponse, and that
// the outgoing request carries the model, input, dimensions and user.
func TestOpenAIEmbeddings(t *testing.T) {
server := newMockServer()
defer server.close()

// Canned embeddings response returned by the mock server.
server.response = map[string]any{
"object": "list",
"model": "text-embedding-3-small",
"data": []map[string]any{
{
"object": "embedding",
"index": 0,
"embedding": []float64{0.1, -0.2},
},
},
"usage": map[string]any{
"prompt_tokens": 5,
"total_tokens": 5,
},
}

provider, err := New(WithBaseURL(server.server.URL))
require.NoError(t, err)

// Embedding support is discovered through a type assertion.
embeddingProvider, ok := provider.(fantasy.EmbeddingProvider)
require.True(t, ok)

model, err := embeddingProvider.EmbeddingModel(t.Context(), "text-embedding-3-small")
require.NoError(t, err)
require.Equal(t, "text-embedding-3-small", model.Model())
require.Equal(t, Name, model.Provider())

dims := int64(2)
user := "alice"
response, err := model.Embed(t.Context(), fantasy.EmbeddingCall{
Inputs: []string{"hello"},
Dimensions: &dims,
ProviderOptions: NewProviderOptions(&ProviderOptions{
User: fantasy.Opt(user),
}),
})
// The response values must come back as float32 with usage populated.
require.NoError(t, err)
require.Len(t, response.Embeddings, 1)
require.Equal(t, []float32{0.1, -0.2}, response.Embeddings[0].Vector)
require.Equal(t, int64(5), response.Usage.InputTokens)
require.Equal(t, int64(5), response.Usage.TotalTokens)

// Exactly one request must have been sent, with the expected path and body.
require.Len(t, server.calls, 1)
call := server.calls[0]
require.Equal(t, "/embeddings", call.path)
require.Equal(t, "text-embedding-3-small", call.body["model"])
require.Equal(t, []any{"hello"}, call.body["input"])
// JSON numbers decode into float64 when the target is an any value.
require.Equal(t, float64(2), call.body["dimensions"])
require.Equal(t, "alice", call.body["user"])
}
2 changes: 1 addition & 1 deletion providers/openaicompat/openaicompat.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Package openaicompat provides an implementation of the fantasy AI SDK for OpenAI-compatible APIs.
// Package openaicompat provides an implementation of the fantasy AI SDK for OpenAI-compatible language and embedding APIs.
package openaicompat

import (
Expand Down
2 changes: 1 addition & 1 deletion providers/openrouter/openrouter.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Package openrouter provides an implementation of the fantasy AI SDK for OpenRouter's language models.
// Package openrouter provides an implementation of the fantasy AI SDK for OpenRouter's language and embedding models.
package openrouter

import (
Expand Down
Loading