Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,14 @@ func (h *ChatHandler) CreateChatCompletion(

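// Get provider based on the requested model. Requests whose X-Auth-Method header
// equals "apikey" (case-insensitive) use the IncludingInactive lookup; all other
// requests use the default lookup.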
observability.AddSpanEvent(ctx, "selecting_provider")
selectedProviderModel, selectedProvider, err := h.providerHandler.SelectProviderModelForModelPublicID(ctx, request.Model)
isAPIKeyAuth := strings.EqualFold(reqCtx.GetHeader("X-Auth-Method"), "apikey")
var selectedProviderModel *domainmodel.ProviderModel
var selectedProvider *domainmodel.Provider
if isAPIKeyAuth {
selectedProviderModel, selectedProvider, err = h.providerHandler.SelectProviderModelForModelPublicIDIncludingInactive(ctx, request.Model)
} else {
selectedProviderModel, selectedProvider, err = h.providerHandler.SelectProviderModelForModelPublicID(ctx, request.Model)
}
if err != nil {
observability.RecordError(ctx, err)
return nil, platformerrors.AsError(ctx, platformerrors.LayerHandler, err, "failed to select provider model")
Expand Down Expand Up @@ -220,7 +227,6 @@ func (h *ChatHandler) CreateChatCompletion(
// Check if we should use the instruct model instead
// This happens when enable_thinking is explicitly false and the model has an instruct model configured
// Skip instruct fallback for API key authentication (API users should get the model they requested)
isAPIKeyAuth := strings.EqualFold(reqCtx.GetHeader("X-Auth-Method"), "apikey")
if !isAPIKeyAuth && request.EnableThinking != nil && !*request.EnableThinking && selectedProviderModel.InstructModelID != nil && !imageRequested {
instructModel, instructProvider, err := h.providerHandler.GetProviderModelByID(ctx, *selectedProviderModel.InstructModelID)
if err == nil && instructModel != nil && instructProvider != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"go.opentelemetry.io/otel/codes"

"jan-server/services/llm-api/internal/domain/conversation"
domainmodel "jan-server/services/llm-api/internal/domain/model"
"jan-server/services/llm-api/internal/domain/tokenusage"
"jan-server/services/llm-api/internal/infrastructure/inference"
"jan-server/services/llm-api/internal/infrastructure/metrics"
Expand Down Expand Up @@ -74,7 +75,15 @@ func (h *MessagesHandler) CreateMessage(ctx context.Context, reqCtx *gin.Context

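// Get provider for the requested model. Requests whose X-Auth-Method header
// equals "apikey" (case-insensitive) use the IncludingInactive lookup; all other
// requests use the default lookup.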
observability.AddSpanEvent(ctx, "selecting_provider")
selectedProviderModel, selectedProvider, err := h.providerHandler.SelectProviderModelForModelPublicID(ctx, request.Model)
isAPIKeyAuth := strings.EqualFold(reqCtx.GetHeader("X-Auth-Method"), "apikey")
var selectedProviderModel *domainmodel.ProviderModel
var selectedProvider *domainmodel.Provider
var err error
if isAPIKeyAuth {
selectedProviderModel, selectedProvider, err = h.providerHandler.SelectProviderModelForModelPublicIDIncludingInactive(ctx, request.Model)
} else {
selectedProviderModel, selectedProvider, err = h.providerHandler.SelectProviderModelForModelPublicID(ctx, request.Model)
}
if err != nil {
observability.RecordError(ctx, err)
return h.writeErrorResponse(reqCtx, http.StatusBadRequest, "invalid_request_error", fmt.Sprintf("Model not found: %s", request.Model))
Expand Down Expand Up @@ -450,7 +459,14 @@ func (h *MessagesHandler) CountTokens(ctx context.Context, reqCtx *gin.Context,
)

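// Get provider for the requested model to validate it exists. Requests whose
// X-Auth-Method header equals "apikey" (case-insensitive) use the IncludingInactive
// lookup; all other requests use the default lookup.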
selectedProviderModel, _, err := h.providerHandler.SelectProviderModelForModelPublicID(ctx, request.Model)
isAPIKeyAuth := strings.EqualFold(reqCtx.GetHeader("X-Auth-Method"), "apikey")
var selectedProviderModel *domainmodel.ProviderModel
var err error
if isAPIKeyAuth {
selectedProviderModel, _, err = h.providerHandler.SelectProviderModelForModelPublicIDIncludingInactive(ctx, request.Model)
} else {
selectedProviderModel, _, err = h.providerHandler.SelectProviderModelForModelPublicID(ctx, request.Model)
}
if err != nil || selectedProviderModel == nil {
return h.writeErrorResponse(reqCtx, http.StatusBadRequest, "invalid_request_error", fmt.Sprintf("Model not found: %s", request.Model))
}
Expand Down
Loading