@@ -2,11 +2,12 @@ package openai
22
33import (
44 "context"
5+ "encoding/json"
56 "errors"
67 "fmt"
78 "io"
89 "iter"
9- "slices "
10+ "strings "
1011
1112 "github.com/sashabaranov/go-openai"
1213 "google.golang.org/adk/model"
@@ -124,7 +125,7 @@ func (o *OpenAIModel) processStream(stream *openai.ChatCompletionStream, yield f
124125 }
125126 var finishReason genai.FinishReason
126127 var usageMetadata * genai.GenerateContentResponseUsageMetadata
127- toolCallsMap := make ( map [ int ] * toolCallBuilder )
128+ toolCalls := newChatStreamToolCallAggregator ( )
128129 var textContent string
129130 var thoughtContent string
130131 thinkParser := newThinkTagStreamParser ()
@@ -186,24 +187,8 @@ func (o *OpenAIModel) processStream(stream *openai.ChatCompletionStream, yield f
186187 }
187188
188189 // 处理标准工具调用
189- for _ , toolCall := range choice .Delta .ToolCalls {
190- idx := 0
191- if toolCall .Index != nil {
192- idx = * toolCall .Index
193- }
194-
195- if _ , exists := toolCallsMap [idx ]; ! exists {
196- toolCallsMap [idx ] = & toolCallBuilder {}
197- }
198-
199- builder := toolCallsMap [idx ]
200- if toolCall .ID != "" {
201- builder .id = toolCall .ID
202- }
203- if toolCall .Function .Name != "" {
204- builder .name = toolCall .Function .Name
205- }
206- builder .args += toolCall .Function .Arguments
190+ for pos , toolCall := range choice .Delta .ToolCalls {
191+ toolCalls .AddDelta (pos , toolCall )
207192 }
208193
209194 if choice .FinishReason != "" {
@@ -248,19 +233,18 @@ func (o *OpenAIModel) processStream(stream *openai.ChatCompletionStream, yield f
248233 }
249234
250235 // 聚合标准工具调用
251- if len (toolCallsMap ) > 0 {
252- indices := sortedKeys (toolCallsMap )
253- for _ , idx := range indices {
254- builder := toolCallsMap [idx ]
255- part := & genai.Part {
256- FunctionCall : & genai.FunctionCall {
257- ID : builder .id ,
258- Name : builder .name ,
259- Args : parseJSONArgs (builder .args ),
260- },
261- }
262- aggregatedContent .Parts = append (aggregatedContent .Parts , part )
236+ for _ , builder := range toolCalls .OrderedBuilders () {
237+ if builder == nil {
238+ continue
239+ }
240+ part := & genai.Part {
241+ FunctionCall : & genai.FunctionCall {
242+ ID : builder .id ,
243+ Name : builder .name ,
244+ Args : parseJSONArgs (builder .args ),
245+ },
263246 }
247+ aggregatedContent .Parts = append (aggregatedContent .Parts , part )
264248 }
265249
266250 if streamErr != nil {
@@ -285,12 +269,133 @@ type toolCallBuilder struct {
285269 args string
286270}
287271
288- // sortedKeys 返回排序后的 map keys
289- func sortedKeys (m map [int ]* toolCallBuilder ) []int {
290- keys := make ([]int , 0 , len (m ))
291- for k := range m {
292- keys = append (keys , k )
272+ // chatStreamToolCallAggregator 聚合兼容 OpenAI 的流式工具调用。
273+ // 一些兼容接口不会稳定返回 index,这里优先按 index / id 归并,
274+ // 都缺失时再按当前位置兜底,并在检测到新 JSON 对象时自动切分。
275+ type chatStreamToolCallAggregator struct {
276+ builders map [string ]* toolCallBuilder
277+ order []string
278+ indexKeys map [int ]string
279+ idKeys map [string ]string
280+ fallbackKeys map [int ]string
281+ nextSynthetic int
282+ }
283+
284+ func newChatStreamToolCallAggregator () * chatStreamToolCallAggregator {
285+ return & chatStreamToolCallAggregator {
286+ builders : make (map [string ]* toolCallBuilder ),
287+ indexKeys : make (map [int ]string ),
288+ idKeys : make (map [string ]string ),
289+ fallbackKeys : make (map [int ]string ),
290+ }
291+ }
292+
293+ func (a * chatStreamToolCallAggregator ) AddDelta (pos int , toolCall openai.ToolCall ) {
294+ key := a .lookupKey (pos , toolCall )
295+ if builder := a .builders [key ]; shouldRotateToolCallBuilder (builder , toolCall ) {
296+ key = a .newSyntheticKey ()
297+ }
298+
299+ a .bindAliases (pos , toolCall , key )
300+ builder := a .ensureBuilder (key )
301+ if toolCall .ID != "" {
302+ builder .id = toolCall .ID
303+ }
304+ if toolCall .Function .Name != "" {
305+ builder .name = toolCall .Function .Name
306+ }
307+ if toolCall .Function .Arguments != "" {
308+ builder .args += toolCall .Function .Arguments
309+ }
310+ }
311+
312+ func (a * chatStreamToolCallAggregator ) OrderedBuilders () []* toolCallBuilder {
313+ builders := make ([]* toolCallBuilder , 0 , len (a .order ))
314+ for _ , key := range a .order {
315+ builders = append (builders , a .builders [key ])
316+ }
317+ return builders
318+ }
319+
320+ func (a * chatStreamToolCallAggregator ) lookupKey (pos int , toolCall openai.ToolCall ) string {
321+ if toolCall .Index != nil {
322+ if key , ok := a .indexKeys [* toolCall .Index ]; ok {
323+ return key
324+ }
325+ }
326+ if toolCall .ID != "" {
327+ if key , ok := a .idKeys [toolCall .ID ]; ok {
328+ return key
329+ }
330+ }
331+ if key , ok := a .fallbackKeys [pos ]; ok {
332+ return key
333+ }
334+ if toolCall .Index != nil {
335+ return fmt .Sprintf ("idx:%d" , * toolCall .Index )
336+ }
337+ if toolCall .ID != "" {
338+ return fmt .Sprintf ("id:%s" , toolCall .ID )
339+ }
340+ return a .newSyntheticKey ()
341+ }
342+
343+ func (a * chatStreamToolCallAggregator ) bindAliases (pos int , toolCall openai.ToolCall , key string ) {
344+ if toolCall .Index != nil {
345+ a .indexKeys [* toolCall .Index ] = key
346+ }
347+ if toolCall .ID != "" {
348+ a .idKeys [toolCall .ID ] = key
349+ }
350+ if toolCall .Index == nil {
351+ a .fallbackKeys [pos ] = key
352+ }
353+ }
354+
355+ func (a * chatStreamToolCallAggregator ) ensureBuilder (key string ) * toolCallBuilder {
356+ if builder , ok := a .builders [key ]; ok {
357+ return builder
358+ }
359+ builder := & toolCallBuilder {}
360+ a .builders [key ] = builder
361+ a .order = append (a .order , key )
362+ return builder
363+ }
364+
365+ func (a * chatStreamToolCallAggregator ) newSyntheticKey () string {
366+ key := fmt .Sprintf ("anon:%d" , a .nextSynthetic )
367+ a .nextSynthetic ++
368+ return key
369+ }
370+
371+ func shouldRotateToolCallBuilder (builder * toolCallBuilder , toolCall openai.ToolCall ) bool {
372+ if builder == nil {
373+ return false
374+ }
375+ if toolCall .ID != "" && builder .id != "" && toolCall .ID != builder .id {
376+ return true
377+ }
378+ if toolCall .Function .Name != "" && builder .name != "" && toolCall .Function .Name != builder .name {
379+ return true
380+ }
381+
382+ if ! isCompleteJSONObject (builder .args ) {
383+ return false
384+ }
385+
386+ if toolCall .Function .Arguments != "" {
387+ return true
388+ }
389+ if toolCall .ID != "" && builder .id == "" {
390+ return true
391+ }
392+ return false
393+ }
394+
395+ func isCompleteJSONObject (s string ) bool {
396+ trimmed := strings .TrimSpace (s )
397+ if trimmed == "" {
398+ return false
293399 }
294- slices .Sort (keys )
295- return keys
400+ return json .Valid ([]byte (trimmed ))
296401}
0 commit comments