Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 11 additions & 72 deletions pkg/fake/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -338,87 +338,26 @@ func SimulatedStreamCopy(c echo.Context, resp *http.Response, chunkDelay time.Du
ctx := c.Request().Context()
writer := c.Response().Writer

reader := bufio.NewReaderSize(resp.Body, 64*1024)
w := c.Response().Writer

// Reuse timer to avoid allocations per chunk
timer := time.NewTimer(chunkDelay)
defer timer.Stop()

dataPrefix := []byte("data:")
rf, ok := w.(io.ReaderFrom)
if !ok {
// fallback seguro
_, err := io.Copy(w, resp.Body)
return err
}

for {
select {
case <-ctx.Done():
slog.WarnContext(ctx, "client disconnected, stop streaming")
return nil
default:
}

line, err := reader.ReadBytes('\n')
if err != nil {
if err == io.EOF {
// Write any remaining data without newline
if len(line) > 0 {
_, _ = writer.Write(line)
c.Response().Flush()
n, err := rf.ReadFrom(io.LimitReader(resp.Body, 256))
if n > 0 {
if flusher, ok := w.(http.Flusher); ok {
flusher.Flush()
}
return nil
}
return err
}

// Write the line (already includes newline from ReadBytes)
if _, err := writer.Write(line); err != nil {
return err
}

// Add delay after data lines (SSE events start with "data:")
if bytes.HasPrefix(line, dataPrefix) {
c.Response().Flush()
timer.Reset(chunkDelay)
select {
case <-ctx.Done():
return nil
case <-timer.C:
}
}
}
}

// streamReadResult holds the result of a streaming read operation.
// StreamCopy's read goroutine sends one value of this type on its result
// channel after each ReadFrom call.
type streamReadResult struct {
// n is the count returned by the ReadFrom call for this read.
n int64
// err is the error returned by the ReadFrom call, if any.
err error
}

// StreamCopy copies a streaming response to the client.
// It properly handles context cancellation during blocking reads.
func StreamCopy(c echo.Context, resp *http.Response) error {
ctx := c.Request().Context()
writer := c.Response().Writer.(io.ReaderFrom)

// Use a channel to receive read results from a goroutine.
// This allows us to properly select on context cancellation
// even when the read is blocking.
resultCh := make(chan streamReadResult, 1)

for {
// Start a goroutine to perform the blocking read
go func() {
n, err := writer.ReadFrom(io.LimitReader(resp.Body, 256))
resultCh <- streamReadResult{n: n, err: err}
}()

// Wait for either context cancellation or read completion
select {
case <-ctx.Done():
slog.WarnContext(ctx, "client disconnected, stop streaming")
// Close the response body to unblock the read goroutine
resp.Body.Close()
return nil
case result := <-resultCh:
if result.n > 0 {
c.Response().Flush() // keep flushing to client
}
if result.err != nil {
// io.EOF or context canceled means normal completion
Expand Down
11 changes: 7 additions & 4 deletions pkg/runtime/connectrpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,11 +350,14 @@ func (c *ConnectRPCClient) convertProtoEventToRuntimeEvent(e *cagentv1.Event) Ev
LastMessage: convertProtoMessageUsage(ev.TokenUsage.Usage.LastMessage),
}
}

return &TokenUsageEvent{
Type: "token_usage",
SessionID: ev.TokenUsage.SessionId,
Usage: usage,
AgentContext: AgentContext{AgentName: ev.TokenUsage.AgentName},
Type: "token_usage",
SessionID: ev.TokenUsage.SessionId,
Usage: usage,
AgentContext: AgentContext{
AgentName: ev.TokenUsage.AgentName,
},
}

case *cagentv1.Event_SessionTitle:
Expand Down
71 changes: 59 additions & 12 deletions pkg/runtime/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package runtime

import (
"cmp"
"time"

"github.com/docker/cagent/pkg/chat"
"github.com/docker/cagent/pkg/config/types"
Expand Down Expand Up @@ -192,9 +193,14 @@ type TokenUsageEvent struct {
Type string `json:"type"`
SessionID string `json:"session_id"`
Usage *Usage `json:"usage"`

AgentContext
}

// GetType returns the type identifier for token usage events, which is
// always "token_usage". The receiver is not inspected.
func (*TokenUsageEvent) GetType() string {
	const eventType = "token_usage"
	return eventType
}

type Usage struct {
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
Expand All @@ -212,11 +218,29 @@ type MessageUsage struct {
Model string
}

// TokenUsage builds a token usage event that carries no per-message usage.
// It forwards every argument to TokenUsageWithMessage and passes nil for the
// message usage.
func TokenUsage(sessionID, agentName string, inputTokens, outputTokens, contextLength, contextLimit int64, cost float64) Event {
	event := TokenUsageWithMessage(sessionID, agentName, inputTokens, outputTokens, contextLength, contextLimit, cost, nil)
	return event
}

func TokenUsageWithMessage(sessionID, agentName string, inputTokens, outputTokens, contextLength, contextLimit int64, cost float64, msgUsage *MessageUsage) Event {
// TokenUsage builds a token usage event that carries no per-message usage.
// It forwards every argument to TokenUsageWithMessage and passes nil for the
// message usage.
func TokenUsage(
	sessionID, agentName string,
	inputTokens, outputTokens, contextLength, contextLimit int64,
	cost float64,
) Event {
	return TokenUsageWithMessage(
		sessionID, agentName,
		inputTokens, outputTokens, contextLength, contextLimit,
		cost,
		nil, // no per-message usage for this variant
	)
}

func TokenUsageWithMessage(
sessionID, agentName string,
inputTokens, outputTokens, contextLength, contextLimit int64,
cost float64,
msgUsage *MessageUsage,
) Event {
return &TokenUsageEvent{
Type: "token_usage",
SessionID: sessionID,
Expand All @@ -232,13 +256,6 @@ func TokenUsageWithMessage(sessionID, agentName string, inputTokens, outputToken
}
}

// SessionTitleEvent is the event that carries a session's title.
type SessionTitleEvent struct {
// Type is the event type discriminator; the SessionTitle constructor
// sets it to "session_title".
Type string `json:"type"`
// SessionID identifies the session the title belongs to.
SessionID string `json:"session_id"`
// Title is the session title text.
Title string `json:"title"`
// AgentContext is embedded, so its fields are promoted onto the event.
AgentContext
}

func SessionTitle(sessionID, title string) Event {
return &SessionTitleEvent{
Type: "session_title",
Expand Down Expand Up @@ -525,3 +542,33 @@ func HookBlocked(toolCall tools.ToolCall, toolDefinition tools.Tool, message, ag
AgentContext: AgentContext{AgentName: agentName},
}
}

// SessionMetricsEvent is an event payload carrying per-session counters and
// the start/end timestamps of an execution.
type SessionMetricsEvent struct {
	// Type is the event type discriminator: "session_metrics".
	Type string `json:"type"`

	// SessionID identifies the session the metrics belong to.
	SessionID string `json:"session_id"`

	// Message and tool counters for the session.
	UserMessages      int `json:"user_messages"`
	AssistantMessages int `json:"assistant_messages"`
	ToolCalls         int `json:"tool_calls"`
	ToolErrors        int `json:"tool_errors"`

	// Execution time window.
	StartedAt time.Time `json:"started_at"`
	EndedAt   time.Time `json:"ended_at"`
}

// GetType returns the type identifier for session metrics events, which is
// always "session_metrics". The receiver is not inspected.
func (*SessionMetricsEvent) GetType() string {
	const eventType = "session_metrics"
	return eventType
}

// SessionTitleEvent is the event that carries a session's title.
type SessionTitleEvent struct {
	// Type is the event type discriminator ("session_title").
	Type string `json:"type"`
	// SessionID identifies the session the title belongs to.
	SessionID string `json:"session_id"`
	// Title is the session title text.
	Title string `json:"title"`

	// AgentContext is embedded, so its fields are promoted onto the event.
	AgentContext
}

// GetType returns the type identifier for session title events, which is
// always "session_title". The receiver is not inspected.
func (*SessionTitleEvent) GetType() string {
	const eventType = "session_title"
	return eventType
}
59 changes: 53 additions & 6 deletions pkg/server/session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,29 +130,47 @@ func (sm *SessionManager) DeleteSession(ctx context.Context, sessionID string) e
}

// RunSession runs a session with the given messages.
func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilename, currentAgent string, messages []api.Message) (<-chan runtime.Event, error) {
func (sm *SessionManager) RunSession(
ctx context.Context,
sessionID, agentFilename, currentAgent string,
messages []api.Message,
) (<-chan runtime.Event, error) {
sm.mux.Lock()
defer sm.mux.Unlock()

// Load persisted session
sess, err := sm.sessionStore.GetSession(ctx, sessionID)
if err != nil {
sm.mux.Unlock()
return nil, err
}

// Mark execution start (observability only)
sess.Metrics = session.Metrics{}
sess.Metrics.StartedAt = time.Now()

// Clone runtime config and inherit working dir
rc := sm.runConfig.Clone()
rc.WorkingDir = sess.WorkingDir

// Append user messages and count them
for _, msg := range messages {
sess.AddMessage(session.UserMessage(msg.Content, msg.MultiContent...))
sess.Metrics.UserMessages++
}

if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil {
sm.mux.Unlock()
return nil, err
}

// Get or create runtime for this session
runtimeSession, exists := sm.runtimeSessions.Load(sessionID)
streamCtx, cancel := context.WithCancel(ctx)

if !exists {
rt, err := sm.runtimeForSession(ctx, sess, agentFilename, currentAgent, rc)
if err != nil {
sm.mux.Unlock()
cancel()
return nil, err
}
Expand All @@ -163,22 +181,51 @@ func (sm *SessionManager) RunSession(ctx context.Context, sessionID, agentFilena
sm.runtimeSessions.Store(sessionID, runtimeSession)
}

sm.mux.Unlock()

streamChan := make(chan runtime.Event)

go func() {
stream := runtimeSession.runtime.RunStream(streamCtx, sess)
defer cancel()
defer close(streamChan)

stream := runtimeSession.runtime.RunStream(streamCtx, sess)

for event := range stream {
if streamCtx.Err() != nil {
return
}

// Collect session-level observability metrics
if e, ok := event.(interface{ GetType() string }); ok {
switch e.GetType() {
case "assistant_message":
sess.Metrics.AssistantMessages++
case "tool_call":
sess.Metrics.ToolCalls++
case "tool_error":
sess.Metrics.ToolErrors++
}
}

streamChan <- event
}

if err := sm.sessionStore.UpdateSession(ctx, sess); err != nil {
return
}
// Mark execution end
sess.Metrics.EndedAt = time.Now()

streamChan <- runtime.TokenUsage(
sess.ID,
currentAgent,
sess.InputTokens,
sess.OutputTokens,
0,
0,
sess.Cost,
)

// Persist updated session state (metrics are ephemeral)
_ = sm.sessionStore.UpdateSession(ctx, sess)
}()

return streamChan, nil
Expand Down
36 changes: 32 additions & 4 deletions pkg/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,35 @@ func (si *Item) IsSubSession() bool {
return si.SubSession != nil
}

// Metrics captures runtime-level counters and timestamps gathered while a
// session executes. The values are not persisted; they exist for
// observability, debugging, and UX purposes.
type Metrics struct {
	// StartedAt records when execution of the session began.
	StartedAt time.Time

	// EndedAt records when execution of the session finished.
	EndedAt time.Time

	// UserMessages counts the user messages sent during the session.
	UserMessages int

	// AssistantMessages counts the assistant messages generated.
	AssistantMessages int

	// ToolCalls counts every tool invocation.
	ToolCalls int

	// ToolErrors counts the tool invocations that failed.
	ToolErrors int
}

// Reset restores every metric to its zero value.
// MUST be called at the beginning of each RunSession execution.
func (m *Metrics) Reset() {
	var zero Metrics
	*m = zero
}

// Session represents the agent's state including conversation history and variables
type Session struct {
// ID is the unique identifier for the session
Expand All @@ -59,13 +88,13 @@ type Session struct {
// CreatedAt is the time the session was created
CreatedAt time.Time `json:"created_at"`

// Metrics holds performance and interaction metrics for this session
Metrics Metrics `json:"-"`

// ToolsApproved is a flag to indicate if the tools have been approved
ToolsApproved bool `json:"tools_approved"`

// Thinking is a session-level flag to enable thinking/interleaved thinking
// defaults for all providers. When false, providers will not apply auto-thinking budgets
// or interleaved thinking, regardless of model config. This is controlled by the /think
// command in the TUI. Defaults to true (thinking enabled).
Thinking bool `json:"thinking"`

// HideToolResults is a flag to indicate if tool results should be hidden
Expand Down Expand Up @@ -94,7 +123,6 @@ type Session struct {

// AgentModelOverrides stores per-agent model overrides for this session.
// Key is the agent name, value is the model reference (e.g., "openai/gpt-4o" or a named model from config).
// When a session is loaded, these overrides are reapplied to the runtime.
AgentModelOverrides map[string]string `json:"agent_model_overrides,omitempty"`

// CustomModelsUsed tracks custom models (provider/model format) used during this session.
Expand Down
Loading