Skip to content

Commit 44af274

Browse files
committed
Merge branch 'dev' for v0.3.5 release
2 parents 3d83775 + 0fca120 commit 44af274

File tree

4 files changed

+282
-48
lines changed

4 files changed

+282
-48
lines changed

internal/adk/openai/convert.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,17 @@ func convertChatCompletionResponse(resp *openai.ChatCompletionResponse) (*model.
474474
}
475475
}
476476

477+
// 兼容旧版 function_call 字段
478+
if choice.Message.FunctionCall != nil && choice.Message.FunctionCall.Name != "" {
479+
content.Parts = append(content.Parts, &genai.Part{
480+
FunctionCall: &genai.FunctionCall{
481+
ID: "legacy_function_call",
482+
Name: choice.Message.FunctionCall.Name,
483+
Args: parseJSONArgs(choice.Message.FunctionCall.Arguments),
484+
},
485+
})
486+
}
487+
477488
// 处理 usage
478489
var usageMetadata *genai.GenerateContentResponseUsageMetadata
479490
if resp.Usage.TotalTokens > 0 {

internal/adk/openai/model.go

Lines changed: 144 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@ package openai
22

33
import (
44
"context"
5+
"encoding/json"
56
"errors"
67
"fmt"
78
"io"
89
"iter"
9-
"slices"
10+
"strings"
1011

1112
"github.com/sashabaranov/go-openai"
1213
"google.golang.org/adk/model"
@@ -124,7 +125,7 @@ func (o *OpenAIModel) processStream(stream *openai.ChatCompletionStream, yield f
124125
}
125126
var finishReason genai.FinishReason
126127
var usageMetadata *genai.GenerateContentResponseUsageMetadata
127-
toolCallsMap := make(map[int]*toolCallBuilder)
128+
toolCalls := newChatStreamToolCallAggregator()
128129
var textContent string
129130
var thoughtContent string
130131
thinkParser := newThinkTagStreamParser()
@@ -186,24 +187,8 @@ func (o *OpenAIModel) processStream(stream *openai.ChatCompletionStream, yield f
186187
}
187188

188189
// 处理标准工具调用
189-
for _, toolCall := range choice.Delta.ToolCalls {
190-
idx := 0
191-
if toolCall.Index != nil {
192-
idx = *toolCall.Index
193-
}
194-
195-
if _, exists := toolCallsMap[idx]; !exists {
196-
toolCallsMap[idx] = &toolCallBuilder{}
197-
}
198-
199-
builder := toolCallsMap[idx]
200-
if toolCall.ID != "" {
201-
builder.id = toolCall.ID
202-
}
203-
if toolCall.Function.Name != "" {
204-
builder.name = toolCall.Function.Name
205-
}
206-
builder.args += toolCall.Function.Arguments
190+
for pos, toolCall := range choice.Delta.ToolCalls {
191+
toolCalls.AddDelta(pos, toolCall)
207192
}
208193

209194
if choice.FinishReason != "" {
@@ -248,19 +233,18 @@ func (o *OpenAIModel) processStream(stream *openai.ChatCompletionStream, yield f
248233
}
249234

250235
// 聚合标准工具调用
251-
if len(toolCallsMap) > 0 {
252-
indices := sortedKeys(toolCallsMap)
253-
for _, idx := range indices {
254-
builder := toolCallsMap[idx]
255-
part := &genai.Part{
256-
FunctionCall: &genai.FunctionCall{
257-
ID: builder.id,
258-
Name: builder.name,
259-
Args: parseJSONArgs(builder.args),
260-
},
261-
}
262-
aggregatedContent.Parts = append(aggregatedContent.Parts, part)
236+
for _, builder := range toolCalls.OrderedBuilders() {
237+
if builder == nil {
238+
continue
239+
}
240+
part := &genai.Part{
241+
FunctionCall: &genai.FunctionCall{
242+
ID: builder.id,
243+
Name: builder.name,
244+
Args: parseJSONArgs(builder.args),
245+
},
263246
}
247+
aggregatedContent.Parts = append(aggregatedContent.Parts, part)
264248
}
265249

266250
if streamErr != nil {
@@ -285,12 +269,133 @@ type toolCallBuilder struct {
285269
args string
286270
}
287271

288-
// sortedKeys 返回排序后的 map keys
289-
func sortedKeys(m map[int]*toolCallBuilder) []int {
290-
keys := make([]int, 0, len(m))
291-
for k := range m {
292-
keys = append(keys, k)
272+
// chatStreamToolCallAggregator 聚合兼容 OpenAI 的流式工具调用。
273+
// 一些兼容接口不会稳定返回 index,这里优先按 index / id 归并,
274+
// 都缺失时再按当前位置兜底,并在检测到新 JSON 对象时自动切分。
275+
type chatStreamToolCallAggregator struct {
276+
builders map[string]*toolCallBuilder
277+
order []string
278+
indexKeys map[int]string
279+
idKeys map[string]string
280+
fallbackKeys map[int]string
281+
nextSynthetic int
282+
}
283+
284+
func newChatStreamToolCallAggregator() *chatStreamToolCallAggregator {
285+
return &chatStreamToolCallAggregator{
286+
builders: make(map[string]*toolCallBuilder),
287+
indexKeys: make(map[int]string),
288+
idKeys: make(map[string]string),
289+
fallbackKeys: make(map[int]string),
290+
}
291+
}
292+
293+
func (a *chatStreamToolCallAggregator) AddDelta(pos int, toolCall openai.ToolCall) {
294+
key := a.lookupKey(pos, toolCall)
295+
if builder := a.builders[key]; shouldRotateToolCallBuilder(builder, toolCall) {
296+
key = a.newSyntheticKey()
297+
}
298+
299+
a.bindAliases(pos, toolCall, key)
300+
builder := a.ensureBuilder(key)
301+
if toolCall.ID != "" {
302+
builder.id = toolCall.ID
303+
}
304+
if toolCall.Function.Name != "" {
305+
builder.name = toolCall.Function.Name
306+
}
307+
if toolCall.Function.Arguments != "" {
308+
builder.args += toolCall.Function.Arguments
309+
}
310+
}
311+
312+
func (a *chatStreamToolCallAggregator) OrderedBuilders() []*toolCallBuilder {
313+
builders := make([]*toolCallBuilder, 0, len(a.order))
314+
for _, key := range a.order {
315+
builders = append(builders, a.builders[key])
316+
}
317+
return builders
318+
}
319+
320+
func (a *chatStreamToolCallAggregator) lookupKey(pos int, toolCall openai.ToolCall) string {
321+
if toolCall.Index != nil {
322+
if key, ok := a.indexKeys[*toolCall.Index]; ok {
323+
return key
324+
}
325+
}
326+
if toolCall.ID != "" {
327+
if key, ok := a.idKeys[toolCall.ID]; ok {
328+
return key
329+
}
330+
}
331+
if key, ok := a.fallbackKeys[pos]; ok {
332+
return key
333+
}
334+
if toolCall.Index != nil {
335+
return fmt.Sprintf("idx:%d", *toolCall.Index)
336+
}
337+
if toolCall.ID != "" {
338+
return fmt.Sprintf("id:%s", toolCall.ID)
339+
}
340+
return a.newSyntheticKey()
341+
}
342+
343+
func (a *chatStreamToolCallAggregator) bindAliases(pos int, toolCall openai.ToolCall, key string) {
344+
if toolCall.Index != nil {
345+
a.indexKeys[*toolCall.Index] = key
346+
}
347+
if toolCall.ID != "" {
348+
a.idKeys[toolCall.ID] = key
349+
}
350+
if toolCall.Index == nil {
351+
a.fallbackKeys[pos] = key
352+
}
353+
}
354+
355+
func (a *chatStreamToolCallAggregator) ensureBuilder(key string) *toolCallBuilder {
356+
if builder, ok := a.builders[key]; ok {
357+
return builder
358+
}
359+
builder := &toolCallBuilder{}
360+
a.builders[key] = builder
361+
a.order = append(a.order, key)
362+
return builder
363+
}
364+
365+
func (a *chatStreamToolCallAggregator) newSyntheticKey() string {
366+
key := fmt.Sprintf("anon:%d", a.nextSynthetic)
367+
a.nextSynthetic++
368+
return key
369+
}
370+
371+
func shouldRotateToolCallBuilder(builder *toolCallBuilder, toolCall openai.ToolCall) bool {
372+
if builder == nil {
373+
return false
374+
}
375+
if toolCall.ID != "" && builder.id != "" && toolCall.ID != builder.id {
376+
return true
377+
}
378+
if toolCall.Function.Name != "" && builder.name != "" && toolCall.Function.Name != builder.name {
379+
return true
380+
}
381+
382+
if !isCompleteJSONObject(builder.args) {
383+
return false
384+
}
385+
386+
if toolCall.Function.Arguments != "" {
387+
return true
388+
}
389+
if toolCall.ID != "" && builder.id == "" {
390+
return true
391+
}
392+
return false
393+
}
394+
395+
func isCompleteJSONObject(s string) bool {
396+
trimmed := strings.TrimSpace(s)
397+
if trimmed == "" {
398+
return false
293399
}
294-
slices.Sort(keys)
295-
return keys
400+
return json.Valid([]byte(trimmed))
296401
}

internal/adk/openai/responses_convert.go

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ func toResponsesRequest(req *model.LLMRequest, modelName string, noSystemRole bo
7070

7171
// 转换工具定义
7272
if len(req.Config.Tools) > 0 {
73-
apiReq.Tools = convertResponsesTools(req.Config.Tools)
73+
apiReq.Tools, err = convertResponsesTools(req.Config.Tools)
74+
if err != nil {
75+
return CreateResponseRequest{}, err
76+
}
7477
}
7578

7679
// 应用生成参数
@@ -144,7 +147,6 @@ func toResponsesInputItem(content *genai.Content) ([]ResponsesInputItem, error)
144147
}
145148
toolCallItems = append(toolCallItems, ResponsesInputItem{
146149
Type: "function_call",
147-
ID: part.FunctionCall.ID,
148150
CallID: part.FunctionCall.ID,
149151
Name: part.FunctionCall.Name,
150152
Arguments: string(argsJSON),
@@ -182,7 +184,7 @@ func convertRoleForResponses(role string) string {
182184
}
183185

184186
// convertResponsesTools 转换工具定义为 Responses API 扁平化格式
185-
func convertResponsesTools(genaiTools []*genai.Tool) []ResponsesTool {
187+
func convertResponsesTools(genaiTools []*genai.Tool) ([]ResponsesTool, error) {
186188
var tools []ResponsesTool
187189
for _, genaiTool := range genaiTools {
188190
if genaiTool == nil {
@@ -193,6 +195,10 @@ func convertResponsesTools(genaiTools []*genai.Tool) []ResponsesTool {
193195
if params == nil {
194196
params = funcDecl.Parameters
195197
}
198+
params, err := normalizeResponsesToolSchema(params)
199+
if err != nil {
200+
return nil, fmt.Errorf("normalize responses tool schema %s: %w", funcDecl.Name, err)
201+
}
196202
tools = append(tools, ResponsesTool{
197203
Type: "function",
198204
Name: funcDecl.Name,
@@ -201,7 +207,95 @@ func convertResponsesTools(genaiTools []*genai.Tool) []ResponsesTool {
201207
})
202208
}
203209
}
204-
return tools
210+
return tools, nil
211+
}
212+
213+
func normalizeResponsesToolSchema(schema any) (any, error) {
214+
if schema == nil {
215+
return defaultResponsesToolSchema(), nil
216+
}
217+
218+
raw, err := json.Marshal(schema)
219+
if err != nil {
220+
return nil, fmt.Errorf("marshal schema: %w", err)
221+
}
222+
if string(raw) == "null" {
223+
return defaultResponsesToolSchema(), nil
224+
}
225+
226+
var normalized any
227+
if err := json.Unmarshal(raw, &normalized); err != nil {
228+
return nil, fmt.Errorf("unmarshal schema: %w", err)
229+
}
230+
if normalized == nil {
231+
return defaultResponsesToolSchema(), nil
232+
}
233+
234+
normalizeResponsesSchemaNode(normalized)
235+
if root, ok := normalized.(map[string]any); ok {
236+
ensureResponsesObjectSchema(root)
237+
}
238+
239+
return normalized, nil
240+
}
241+
242+
func defaultResponsesToolSchema() map[string]any {
243+
return map[string]any{
244+
"type": "object",
245+
"properties": map[string]any{},
246+
"additionalProperties": false,
247+
}
248+
}
249+
250+
func normalizeResponsesSchemaNode(node any) {
251+
switch typed := node.(type) {
252+
case map[string]any:
253+
ensureResponsesObjectSchema(typed)
254+
for key, value := range typed {
255+
if key == "additionalProperties" {
256+
if schemaMap, ok := value.(map[string]any); ok && isAlwaysFalseSchema(schemaMap) {
257+
typed[key] = false
258+
continue
259+
}
260+
}
261+
normalizeResponsesSchemaNode(value)
262+
}
263+
case []any:
264+
for _, item := range typed {
265+
normalizeResponsesSchemaNode(item)
266+
}
267+
}
268+
}
269+
270+
func ensureResponsesObjectSchema(schema map[string]any) {
271+
if !isObjectSchemaType(schema["type"]) {
272+
return
273+
}
274+
if props, ok := schema["properties"]; !ok || props == nil {
275+
schema["properties"] = map[string]any{}
276+
}
277+
}
278+
279+
func isObjectSchemaType(schemaType any) bool {
280+
switch typed := schemaType.(type) {
281+
case string:
282+
return typed == "object"
283+
case []any:
284+
for _, item := range typed {
285+
if item == "object" {
286+
return true
287+
}
288+
}
289+
}
290+
return false
291+
}
292+
293+
func isAlwaysFalseSchema(schema map[string]any) bool {
294+
if len(schema) != 1 {
295+
return false
296+
}
297+
notSchema, ok := schema["not"].(map[string]any)
298+
return ok && len(notSchema) == 0
205299
}
206300

207301
// convertResponsesResponse 将 Responses API 响应转换为 ADK LLMResponse

0 commit comments

Comments
 (0)