diff --git a/.github/workflows/train-drain3-weights.yml b/.github/workflows/train-drain3-weights.yml new file mode 100644 index 0000000000..d30224bf6b --- /dev/null +++ b/.github/workflows/train-drain3-weights.yml @@ -0,0 +1,92 @@ +name: Train Log Pattern Weights + +on: + schedule: + - cron: "0 4 * * *" # Daily at 04:00 UTC + workflow_dispatch: + +permissions: {} + +jobs: + train: + name: Download logs and train drain3 weights + runs-on: ubuntu-latest + timeout-minutes: 30 + permissions: + contents: write + pull-requests: write + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Set up Go + uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # v6 + with: + go-version-file: go.mod + cache: true + + - name: Build gh-aw + run: make build + + - name: Download run logs and train weights + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + mkdir -p /tmp/drain3-logs + ./gh-aw logs --train --output /tmp/drain3-logs --count 50 + + - name: Copy trained weights to source tree + run: | + if [ -f /tmp/drain3-logs/drain3_weights.json ]; then + cp /tmp/drain3-logs/drain3_weights.json pkg/agentdrain/data/default_weights.json + echo "✅ Weights file updated successfully" + else + echo "⚠️ No drain3_weights.json produced – skipping PR creation" + exit 0 + fi + + - name: Check for changes + id: check-changes + run: | + if git diff --quiet pkg/agentdrain/data/default_weights.json; then + echo "changes=false" >> "$GITHUB_OUTPUT" + echo "No changes to default_weights.json – weights are already up to date" + else + echo "changes=true" >> "$GITHUB_OUTPUT" + echo "Changes detected in default_weights.json" + fi + + - name: Configure Git + if: steps.check-changes.outputs.changes == 'true' + run: | + git config --global user.name "github-actions[bot]" + git config --global user.email "github-actions[bot]@users.noreply.github.com" + + - name: Create pull request with updated weights + if: 
steps.check-changes.outputs.changes == 'true' + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + BRANCH_NAME="ci/train-drain3-weights-$(date +%Y%m%d)" + + git checkout -b "$BRANCH_NAME" + git add pkg/agentdrain/data/default_weights.json + git commit -m "chore: update drain3 default weights from daily training run" + + git push origin "$BRANCH_NAME" + + gh pr create \ + --title "chore: update drain3 default log pattern weights" \ + --body "This pull request updates the default Drain3 log pattern weights (\`pkg/agentdrain/data/default_weights.json\`) by training on the most recent workflow run logs. + + ## What changed + - Re-trained log template clusters from the latest run logs using \`gh aw logs --train\` + - Copied resulting \`drain3_weights.json\` to the embedded defaults path + + ## How to verify + 1. Build the binary with \`make build\` + 2. Run \`gh aw audit\` or \`gh aw logs --train\` and confirm the anomaly analysis reflects the updated patterns + + This PR was created automatically by the [train-drain3-weights](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) workflow." \ + --head "$BRANCH_NAME" \ + --base main diff --git a/pkg/agentdrain/anomaly.go b/pkg/agentdrain/anomaly.go new file mode 100644 index 0000000000..19053eb966 --- /dev/null +++ b/pkg/agentdrain/anomaly.go @@ -0,0 +1,75 @@ +package agentdrain + +import "strings" + +// AnomalyDetector evaluates match results and produces AnomalyReports. +type AnomalyDetector struct { + threshold float64 + rareThreshold int +} + +// NewAnomalyDetector creates an AnomalyDetector with the given thresholds. +func NewAnomalyDetector(simThreshold float64, rareClusterThreshold int) *AnomalyDetector { + return &AnomalyDetector{ + threshold: simThreshold, + rareThreshold: rareClusterThreshold, + } +} + +// Analyze produces an AnomalyReport for a match result. +// +// - isNew indicates the line created a brand-new cluster. 
+// - cluster is the cluster that was matched or created. +func (d *AnomalyDetector) Analyze(result *MatchResult, isNew bool, cluster *Cluster) *AnomalyReport { + report := &AnomalyReport{ + IsNewTemplate: isNew, + NewClusterCreated: isNew, + } + + if !isNew { + report.LowSimilarity = result.Similarity < d.threshold + } + + if cluster != nil { + report.RareCluster = cluster.Size <= d.rareThreshold + } + + // Weighted anomaly score. + var score float64 + if report.IsNewTemplate { + score += 1.0 + } + if report.LowSimilarity { + score += 0.7 + } + if report.RareCluster { + score += 0.3 + } + // Normalize to [0, 1]. + const maxScore = 2.0 + if score > maxScore { + score = maxScore + } + report.AnomalyScore = score / maxScore + + report.Reason = buildReason(report) + return report +} + +// buildReason constructs a human-readable summary of detected anomalies. +func buildReason(r *AnomalyReport) string { + var parts []string + if r.IsNewTemplate { + parts = append(parts, "new log template discovered") + } + if r.LowSimilarity { + parts = append(parts, "low similarity to known template") + } + if r.RareCluster { + parts = append(parts, "rare cluster (few observations)") + } + if len(parts) == 0 { + return "no anomaly detected" + } + return strings.Join(parts, "; ") +} diff --git a/pkg/agentdrain/anomaly_test.go b/pkg/agentdrain/anomaly_test.go new file mode 100644 index 0000000000..d244b581b1 --- /dev/null +++ b/pkg/agentdrain/anomaly_test.go @@ -0,0 +1,123 @@ +//go:build !integration + +package agentdrain + +import ( + "testing" +) + +func TestAnomalyDetection_NewTemplate(t *testing.T) { + d := NewAnomalyDetector(0.4, 2) + c := &Cluster{ID: 1, Template: []string{"stage=plan"}, Size: 1} + result := &MatchResult{ClusterID: 1, Similarity: 1.0} + + report := d.Analyze(result, true, c) + + if !report.IsNewTemplate { + t.Error("expected IsNewTemplate=true") + } + if !report.NewClusterCreated { + t.Error("expected NewClusterCreated=true") + } + if report.AnomalyScore <= 0 { + 
t.Errorf("expected positive anomaly score for new template, got %v", report.AnomalyScore) + } +} + +func TestAnomalyDetection_LowSimilarity(t *testing.T) { + d := NewAnomalyDetector(0.4, 2) + // Size=5 means not rare; not new. + c := &Cluster{ID: 1, Template: []string{"a", "b", "c"}, Size: 5} + result := &MatchResult{ClusterID: 1, Similarity: 0.2} + + report := d.Analyze(result, false, c) + + if !report.LowSimilarity { + t.Error("expected LowSimilarity=true for similarity below threshold") + } + if report.IsNewTemplate { + t.Error("expected IsNewTemplate=false") + } + if report.AnomalyScore <= 0 { + t.Errorf("expected positive anomaly score, got %v", report.AnomalyScore) + } +} + +func TestAnomalyDetection_RareCluster(t *testing.T) { + d := NewAnomalyDetector(0.4, 2) + c := &Cluster{ID: 1, Template: []string{"a"}, Size: 1} + result := &MatchResult{ClusterID: 1, Similarity: 0.9} + + report := d.Analyze(result, false, c) + + if !report.RareCluster { + t.Error("expected RareCluster=true for size=1 with rareThreshold=2") + } + if report.AnomalyScore <= 0 { + t.Errorf("expected positive anomaly score, got %v", report.AnomalyScore) + } +} + +func TestAnomalyDetection_Normal(t *testing.T) { + d := NewAnomalyDetector(0.4, 2) + // High size, high similarity, not new. 
+ c := &Cluster{ID: 1, Template: []string{"a", "b"}, Size: 100} + result := &MatchResult{ClusterID: 1, Similarity: 0.9} + + report := d.Analyze(result, false, c) + + if report.IsNewTemplate { + t.Error("expected IsNewTemplate=false") + } + if report.LowSimilarity { + t.Error("expected LowSimilarity=false") + } + if report.RareCluster { + t.Error("expected RareCluster=false") + } + if report.AnomalyScore > 0 { + t.Errorf("expected zero anomaly score for normal event, got %v", report.AnomalyScore) + } + if report.Reason != "no anomaly detected" { + t.Errorf("expected 'no anomaly detected', got %q", report.Reason) + } +} + +func TestAnalyzeEvent(t *testing.T) { + cfg := DefaultConfig() + m, err := NewMiner(cfg) + if err != nil { + t.Fatalf("NewMiner: %v", err) + } + + // First occurrence → new template. + evt := AgentEvent{ + Stage: "plan", + Fields: map[string]string{"action": "start", "model": "gpt-4"}, + } + result, report, err := m.AnalyzeEvent(evt) + if err != nil { + t.Fatalf("AnalyzeEvent: %v", err) + } + if result == nil { + t.Fatal("AnalyzeEvent: expected non-nil result") + } + if report == nil { + t.Fatal("AnalyzeEvent: expected non-nil report") + } + if !report.IsNewTemplate { + t.Error("first event should be a new template") + } + + // Second occurrence of the same event → not new. + result2, report2, err := m.AnalyzeEvent(evt) + if err != nil { + t.Fatalf("AnalyzeEvent (second): %v", err) + } + if result2 == nil || report2 == nil { + t.Fatal("AnalyzeEvent (second): expected non-nil results") + } + if report2.IsNewTemplate { + t.Error("second identical event should not be a new template") + } +} diff --git a/pkg/agentdrain/cluster.go b/pkg/agentdrain/cluster.go new file mode 100644 index 0000000000..078ec9684b --- /dev/null +++ b/pkg/agentdrain/cluster.go @@ -0,0 +1,88 @@ +package agentdrain + +// clusterStore manages the set of known log template clusters. 
+type clusterStore struct { + clusters map[int]*Cluster + nextID int +} + +func newClusterStore() *clusterStore { + return &clusterStore{ + clusters: make(map[int]*Cluster), + nextID: 1, + } +} + +// add creates a new Cluster for the given template and returns a pointer to it. +func (s *clusterStore) add(template []string, stage string) *Cluster { + id := s.nextID + s.nextID++ + tmpl := make([]string, len(template)) + copy(tmpl, template) + c := &Cluster{ + ID: id, + Template: tmpl, + Size: 1, + Stage: stage, + } + s.clusters[id] = c + return c +} + +// get retrieves a cluster by ID. +func (s *clusterStore) get(id int) (*Cluster, bool) { + c, ok := s.clusters[id] + return c, ok +} + +// all returns a snapshot of all clusters as a value slice. +func (s *clusterStore) all() []Cluster { + out := make([]Cluster, 0, len(s.clusters)) + for _, c := range s.clusters { + out = append(out, *c) + } + return out +} + +// computeSimilarity returns the fraction of positions where tokens a and b +// match exactly, considering only positions that are not paramToken in a. +// Returns 0 when the slices have different lengths. +func computeSimilarity(a, b []string, paramToken string) float64 { + if len(a) != len(b) { + return 0 + } + nonParam := 0 + matches := 0 + for i, tok := range a { + if tok == paramToken { + continue + } + nonParam++ + if tok == b[i] { + matches++ + } + } + if nonParam == 0 { + // All positions are wildcards – treat as a perfect structural match. + return 1.0 + } + return float64(matches) / float64(nonParam) +} + +// mergeTemplate produces a new template by replacing positions where the two +// token slices differ with paramToken. Positions where either token already is +// paramToken also become paramToken. 
+func mergeTemplate(existing, incoming []string, paramToken string) []string {
+	if len(existing) != len(incoming) {
+		return existing
+	}
+	merged := make([]string, len(existing))
+	for i, tok := range existing {
+		if tok == paramToken || incoming[i] == paramToken || tok != incoming[i] {
+			merged[i] = paramToken
+		} else {
+			merged[i] = tok
+		}
+	}
+	return merged
+}
diff --git a/pkg/agentdrain/config.go b/pkg/agentdrain/config.go
new file mode 100644
index 0000000000..f87628e823
--- /dev/null
+++ b/pkg/agentdrain/config.go
@@ -0,0 +1,45 @@
+package agentdrain
+
+// DefaultConfig returns a Config pre-loaded with sensible production defaults.
+//
+// Every mask rule rewrites volatile substrings to the wildcard token so that
+// structurally identical log lines collapse into the same template. The
+// replacement strings deliberately use the same token as ParamToken ("<*>")
+// so masked positions are treated as parameters during similarity scoring.
+func DefaultConfig() Config {
+	return Config{
+		Depth:                4,
+		SimThreshold:         0.4,
+		MaxChildren:          100,
+		ParamToken:           "<*>",
+		RareClusterThreshold: 2,
+		ExcludeFields:        []string{"session_id", "trace_id", "span_id", "timestamp"},
+		MaskRules: []MaskRule{
+			{
+				Name:        "uuid",
+				Pattern:     `[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}`,
+				Replacement: "<*>",
+			},
+			{
+				Name:        "session_id",
+				Pattern:     `session=[a-z0-9]+`,
+				Replacement: "session=<*>",
+			},
+			{
+				Name:        "number_value",
+				Pattern:     `=\d+`,
+				Replacement: "=<*>",
+			},
+			{
+				Name:        "url",
+				Pattern:     `https?://[^\s]+`,
+				Replacement: "<*>",
+			},
+			{
+				Name:        "quoted_string",
+				Pattern:     `"[^"]*"`,
+				Replacement: `"<*>"`,
+			},
+			{
+				Name:        "timestamp",
+				Pattern:     `\d{4}-\d{2}-\d{2}[T ]\d{2}:\d{2}:\d{2}(?:\.\d+)?(?:Z|[+-]\d{2}:\d{2})?`,
+				Replacement: "<*>",
+			},
+		},
+	}
+}
diff --git a/pkg/agentdrain/coordinator.go b/pkg/agentdrain/coordinator.go
new file mode 100644
index 0000000000..365e9118e3
--- /dev/null
+++ b/pkg/agentdrain/coordinator.go
@@ -0,0 +1,167 @@
+package agentdrain
+
+import (
+	"encoding/json"
+	"fmt"
+	"strings"
+	"sync"
+)
+
+// Coordinator manages one Miner per agent pipeline stage.
+type Coordinator struct {
+	miners map[string]*Miner
+	cfg    Config
+	mu     sync.RWMutex
+}
+
+// NewCoordinator creates a Coordinator with one Miner for each provided stage name.
+func NewCoordinator(cfg Config, stages []string) (*Coordinator, error) { + miners := make(map[string]*Miner, len(stages)) + for _, stage := range stages { + m, err := NewMiner(cfg) + if err != nil { + return nil, fmt.Errorf("agentdrain: NewCoordinator: stage %q: %w", stage, err) + } + miners[stage] = m + } + return &Coordinator{miners: miners, cfg: cfg}, nil +} + +// TrainEvent routes the event to the miner responsible for evt.Stage. +// Returns an error when the stage has no associated miner. +func (c *Coordinator) TrainEvent(evt AgentEvent) (*MatchResult, error) { + m, err := c.minerFor(evt.Stage) + if err != nil { + return nil, err + } + return m.TrainEvent(evt) +} + +// AnalyzeEvent routes the event to the correct stage miner and returns both +// the match result and an anomaly report. +func (c *Coordinator) AnalyzeEvent(evt AgentEvent) (*MatchResult, *AnomalyReport, error) { + m, err := c.minerFor(evt.Stage) + if err != nil { + return nil, nil, err + } + return m.AnalyzeEvent(evt) +} + +// Stages returns the list of stage names managed by this Coordinator. +func (c *Coordinator) Stages() []string { + c.mu.RLock() + defer c.mu.RUnlock() + stages := make([]string, 0, len(c.miners)) + for s := range c.miners { + stages = append(stages, s) + } + return stages +} + +// MinerForStage returns the Miner for the given stage, or false if not found. +func (c *Coordinator) MinerForStage(stage string) (*Miner, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + m, ok := c.miners[stage] + return m, ok +} + +// AllClusters returns a map from stage name to the list of clusters in that miner. +func (c *Coordinator) AllClusters() map[string][]Cluster { + c.mu.RLock() + defer c.mu.RUnlock() + result := make(map[string][]Cluster, len(c.miners)) + for stage, m := range c.miners { + result[stage] = m.Clusters() + } + return result +} + +// SaveSnapshots serializes each stage miner's state and returns a map from +// stage name to JSON bytes. 
+func (c *Coordinator) SaveSnapshots() (map[string][]byte, error) { + c.mu.RLock() + defer c.mu.RUnlock() + out := make(map[string][]byte, len(c.miners)) + for stage, m := range c.miners { + data, err := m.SaveJSON() + if err != nil { + return nil, fmt.Errorf("agentdrain: SaveSnapshots: stage %q: %w", stage, err) + } + out[stage] = data + } + return out, nil +} + +// LoadSnapshots restores each stage miner from the provided JSON bytes map. +// Stages that are not present in snapshots retain their current state. +func (c *Coordinator) LoadSnapshots(snapshots map[string][]byte) error { + c.mu.Lock() + defer c.mu.Unlock() + for stage, data := range snapshots { + m, ok := c.miners[stage] + if !ok { + // Create a new miner for previously unknown stages. + var err error + m, err = NewMiner(c.cfg) + if err != nil { + return fmt.Errorf("agentdrain: LoadSnapshots: stage %q: %w", stage, err) + } + c.miners[stage] = m + } + if err := m.LoadJSON(data); err != nil { + return fmt.Errorf("agentdrain: LoadSnapshots: stage %q: %w", stage, err) + } + } + return nil +} + +// minerFor retrieves the miner for the given stage, returning an error if missing. +func (c *Coordinator) minerFor(stage string) (*Miner, error) { + c.mu.RLock() + m, ok := c.miners[stage] + c.mu.RUnlock() + if !ok { + return nil, fmt.Errorf("agentdrain: no miner registered for stage %q", stage) + } + return m, nil +} + +// SaveWeightsJSON serializes all stage snapshots into a single combined JSON blob. +// The result can be written to pkg/agentdrain/data/default_weights.json and +// committed to embed it as the default starting weights for future runs. 
+func (c *Coordinator) SaveWeightsJSON() ([]byte, error) { + snapshots, err := c.SaveSnapshots() + if err != nil { + return nil, err + } + combined := make(map[string]json.RawMessage, len(snapshots)) + for stage, data := range snapshots { + combined[stage] = json.RawMessage(data) + } + return json.Marshal(combined) +} + +// LoadWeightsJSON restores all stage miners from a combined JSON blob produced +// by SaveWeightsJSON. +func (c *Coordinator) LoadWeightsJSON(data []byte) error { + var combined map[string]json.RawMessage + if err := json.Unmarshal(data, &combined); err != nil { + return fmt.Errorf("agentdrain: LoadWeightsJSON: %w", err) + } + snapshots := make(map[string][]byte, len(combined)) + for stage, raw := range combined { + snapshots[stage] = []byte(raw) + } + return c.LoadSnapshots(snapshots) +} + +// StageSequence converts a slice of AgentEvents into a space-separated string +// of their stage names, e.g. "plan tool_call tool_result finish". +func StageSequence(events []AgentEvent) string { + stages := make([]string, 0, len(events)) + for _, e := range events { + stages = append(stages, e.Stage) + } + return strings.Join(stages, " ") +} diff --git a/pkg/agentdrain/data/default_weights.json b/pkg/agentdrain/data/default_weights.json new file mode 100644 index 0000000000..0967ef424b --- /dev/null +++ b/pkg/agentdrain/data/default_weights.json @@ -0,0 +1 @@ +{} diff --git a/pkg/agentdrain/defaults.go b/pkg/agentdrain/defaults.go new file mode 100644 index 0000000000..4360a0a01e --- /dev/null +++ b/pkg/agentdrain/defaults.go @@ -0,0 +1,30 @@ +package agentdrain + +import ( + "bytes" + _ "embed" +) + +//go:embed data/default_weights.json +var defaultWeightsJSON []byte + +// LoadDefaultWeights restores all stage miners from the embedded default weights file +// (pkg/agentdrain/data/default_weights.json). When the file is empty or contains +// only an empty JSON object the call is a no-op and returns nil. 
+// +// Update the default weights by running: +// +// gh aw logs --train --output +// +// and copying the resulting drain3_weights.json to pkg/agentdrain/data/default_weights.json, +// then rebuilding the binary. +func (c *Coordinator) LoadDefaultWeights() error { + if len(defaultWeightsJSON) == 0 { + return nil + } + // A bare "{}" file means no weights have been trained yet. + if string(bytes.TrimSpace(defaultWeightsJSON)) == "{}" { + return nil + } + return c.LoadWeightsJSON(defaultWeightsJSON) +} diff --git a/pkg/agentdrain/mask.go b/pkg/agentdrain/mask.go new file mode 100644 index 0000000000..651038c23b --- /dev/null +++ b/pkg/agentdrain/mask.go @@ -0,0 +1,79 @@ +package agentdrain + +import ( + "fmt" + "regexp" + "sort" + "strings" +) + +// Masker applies a sequence of regex substitution rules to normalize log lines. +type Masker struct { + rules []compiledRule +} + +type compiledRule struct { + name string + re *regexp.Regexp + replacement string +} + +// NewMasker compiles the given MaskRules into a Masker ready for use. +// Returns an error if any pattern fails to compile. +func NewMasker(rules []MaskRule) (*Masker, error) { + compiled := make([]compiledRule, 0, len(rules)) + for _, r := range rules { + re, err := regexp.Compile(r.Pattern) + if err != nil { + return nil, fmt.Errorf("agentdrain: mask rule %q: %w", r.Name, err) + } + compiled = append(compiled, compiledRule{ + name: r.Name, + re: re, + replacement: r.Replacement, + }) + } + return &Masker{rules: compiled}, nil +} + +// Mask applies all mask rules in order and returns the transformed line. +func (m *Masker) Mask(line string) string { + for _, r := range m.rules { + line = r.re.ReplaceAllString(line, r.replacement) + } + return line +} + +// FlattenEvent converts an AgentEvent into a deterministic string suitable for +// template mining. Field keys are sorted alphabetically; fields listed in +// excludeFields are omitted. 
The result looks like: +// +// stage=tool_call key1=val1 key2=val2 +func FlattenEvent(evt AgentEvent, excludeFields []string) string { + excluded := make(map[string]bool, len(excludeFields)) + for _, f := range excludeFields { + excluded[f] = true + } + + keys := make([]string, 0, len(evt.Fields)) + for k := range evt.Fields { + if !excluded[k] { + keys = append(keys, k) + } + } + sort.Strings(keys) + + parts := make([]string, 0, len(keys)+1) + if evt.Stage != "" { + parts = append(parts, "stage="+evt.Stage) + } + for _, k := range keys { + parts = append(parts, k+"="+evt.Fields[k]) + } + return strings.Join(parts, " ") +} + +// Tokenize splits a log line on whitespace and returns the individual tokens. +func Tokenize(line string) []string { + return strings.Fields(line) +} diff --git a/pkg/agentdrain/miner.go b/pkg/agentdrain/miner.go new file mode 100644 index 0000000000..067d8128b6 --- /dev/null +++ b/pkg/agentdrain/miner.go @@ -0,0 +1,188 @@ +package agentdrain + +import ( + "errors" + "fmt" + "strings" + "sync" +) + +// Miner is a concurrent Drain-style log template miner. +// Use NewMiner to create an instance. +type Miner struct { + cfg Config + masker *Masker + tree *parseTree + store *clusterStore + mu sync.RWMutex +} + +// NewMiner creates a Miner from the given Config. +func NewMiner(cfg Config) (*Miner, error) { + masker, err := NewMasker(cfg.MaskRules) + if err != nil { + return nil, fmt.Errorf("agentdrain: NewMiner: %w", err) + } + return &Miner{ + cfg: cfg, + masker: masker, + tree: newParseTree(), + store: newClusterStore(), + }, nil +} + +// Train processes a raw log line, updates the miner state, and returns the +// match result. It is safe to call from multiple goroutines. 
+func (m *Miner) Train(line string) (*MatchResult, error) { + masked := m.masker.Mask(line) + tokens := Tokenize(masked) + if len(tokens) == 0 { + return nil, errors.New("agentdrain: Train: empty line after masking") + } + + m.mu.Lock() + defer m.mu.Unlock() + + result, _ := m.match(tokens) + if result != nil { + // Merge and update existing cluster. + c, _ := m.store.get(result.ClusterID) + c.Template = mergeTemplate(c.Template, tokens, m.cfg.ParamToken) + c.Size++ + result.Template = strings.Join(c.Template, " ") + result.Params = extractParams(tokens, c.Template, m.cfg.ParamToken) + return result, nil + } + + // Create new cluster. + c := m.store.add(tokens, "") + m.tree.addCluster(tokens, c.ID, m.cfg.Depth, m.cfg.MaxChildren, m.cfg.ParamToken) + return &MatchResult{ + ClusterID: c.ID, + Template: strings.Join(c.Template, " "), + Params: []string{}, + Similarity: 1.0, + Stage: c.Stage, + }, nil +} + +// Match performs inference only: it finds the best matching cluster but does +// not mutate any state. Returns (result, true) when a match is found. +// It is safe to call from multiple goroutines. +func (m *Miner) Match(line string) (*MatchResult, bool, error) { + masked := m.masker.Mask(line) + tokens := Tokenize(masked) + if len(tokens) == 0 { + return nil, false, errors.New("agentdrain: Match: empty line after masking") + } + + m.mu.RLock() + defer m.mu.RUnlock() + + result, _ := m.match(tokens) + if result == nil { + return nil, false, nil + } + return result, true, nil +} + +// match is the internal (non-locking) lookup. Must be called with mu held. 
+func (m *Miner) match(tokens []string) (*MatchResult, bool) { + candidates := m.tree.search(tokens, m.cfg.Depth, m.cfg.ParamToken) + bestSim := -1.0 + var best *Cluster + for _, id := range candidates { + c, ok := m.store.get(id) + if !ok { + continue + } + sim := computeSimilarity(c.Template, tokens, m.cfg.ParamToken) + if sim > bestSim { + bestSim = sim + best = c + } + } + if best == nil || bestSim < m.cfg.SimThreshold { + return nil, false + } + params := extractParams(tokens, best.Template, m.cfg.ParamToken) + return &MatchResult{ + ClusterID: best.ID, + Template: strings.Join(best.Template, " "), + Params: params, + Similarity: bestSim, + Stage: best.Stage, + }, true +} + +// TrainEvent flattens the AgentEvent and calls Train. +func (m *Miner) TrainEvent(evt AgentEvent) (*MatchResult, error) { + line := FlattenEvent(evt, m.cfg.ExcludeFields) + result, err := m.Train(line) + if err != nil { + return nil, err + } + result.Stage = evt.Stage + // Propagate stage to cluster. + m.mu.Lock() + if c, ok := m.store.get(result.ClusterID); ok && c.Stage == "" { + c.Stage = evt.Stage + } + m.mu.Unlock() + return result, nil +} + +// AnalyzeEvent performs inference on the event, builds an AnomalyReport, and +// then calls TrainEvent to update the miner. Returns the match result and report. 
+func (m *Miner) AnalyzeEvent(evt AgentEvent) (*MatchResult, *AnomalyReport, error) { + line := FlattenEvent(evt, m.cfg.ExcludeFields) + masked := m.masker.Mask(line) + tokens := Tokenize(masked) + if len(tokens) == 0 { + return nil, nil, errors.New("agentdrain: AnalyzeEvent: empty event after masking") + } + + m.mu.RLock() + inferResult, _ := m.match(tokens) + m.mu.RUnlock() + + isNew := inferResult == nil + result, err := m.TrainEvent(evt) + if err != nil { + return nil, nil, err + } + + var cluster *Cluster + m.mu.RLock() + cluster, _ = m.store.get(result.ClusterID) + m.mu.RUnlock() + + detector := NewAnomalyDetector(m.cfg.SimThreshold, m.cfg.RareClusterThreshold) + report := detector.Analyze(result, isNew, cluster) + return result, report, nil +} + +// Clusters returns a snapshot of all known clusters. +func (m *Miner) Clusters() []Cluster { + m.mu.RLock() + defer m.mu.RUnlock() + return m.store.all() +} + +// ClusterCount returns the number of known clusters. +func (m *Miner) ClusterCount() int { + m.mu.RLock() + defer m.mu.RUnlock() + return len(m.store.clusters) +} + +// extractParams returns the token values at positions where the template has paramToken. 
+func extractParams(tokens []string, template []string, paramToken string) []string { + params := []string{} + for i, tok := range template { + if tok == paramToken && i < len(tokens) { + params = append(params, tokens[i]) + } + } + return params +} diff --git a/pkg/agentdrain/miner_test.go b/pkg/agentdrain/miner_test.go new file mode 100644 index 0000000000..1cf7ae3533 --- /dev/null +++ b/pkg/agentdrain/miner_test.go @@ -0,0 +1,376 @@ +//go:build !integration + +package agentdrain + +import ( + "fmt" + "strings" + "sync" + "testing" +) + +func TestNewMiner(t *testing.T) { + cfg := DefaultConfig() + m, err := NewMiner(cfg) + if err != nil { + t.Fatalf("NewMiner: unexpected error: %v", err) + } + if m == nil { + t.Fatal("NewMiner: expected non-nil miner") + } + if m.ClusterCount() != 0 { + t.Errorf("NewMiner: expected 0 clusters, got %d", m.ClusterCount()) + } +} + +func TestTrain_ClusterCreation(t *testing.T) { + m, err := NewMiner(DefaultConfig()) + if err != nil { + t.Fatalf("NewMiner: %v", err) + } + result, err := m.Train("stage=plan action=start") + if err != nil { + t.Fatalf("Train: unexpected error: %v", err) + } + if result.ClusterID == 0 { + t.Error("Train: expected non-zero ClusterID") + } + if m.ClusterCount() != 1 { + t.Errorf("Train: expected 1 cluster, got %d", m.ClusterCount()) + } +} + +func TestTrain_ClusterMerge(t *testing.T) { + cfg := DefaultConfig() + cfg.SimThreshold = 0.4 + m, err := NewMiner(cfg) + if err != nil { + t.Fatalf("NewMiner: %v", err) + } + + // These two lines differ only in the tool name value. + _, err = m.Train("stage=tool_call tool=search") + if err != nil { + t.Fatalf("Train 1: %v", err) + } + result, err := m.Train("stage=tool_call tool=read_file") + if err != nil { + t.Fatalf("Train 2: %v", err) + } + + // Should merge into one cluster. 
+	if m.ClusterCount() != 1 {
+		t.Errorf("expected 1 cluster after merge, got %d", m.ClusterCount())
+	}
+	if !strings.Contains(result.Template, "<*>") {
+		t.Errorf("expected wildcard in merged template, got: %q", result.Template)
+	}
+}
+
+func TestMatch_InferenceOnly(t *testing.T) {
+	m, err := NewMiner(DefaultConfig())
+	if err != nil {
+		t.Fatalf("NewMiner: %v", err)
+	}
+	// Train once.
+	_, err = m.Train("stage=plan action=start")
+	if err != nil {
+		t.Fatalf("Train: %v", err)
+	}
+	before := m.ClusterCount()
+
+	// Match should not create new clusters.
+	_, _, err = m.Match("stage=plan action=unknown_value")
+	if err != nil {
+		t.Fatalf("Match: unexpected error: %v", err)
+	}
+	if m.ClusterCount() != before {
+		t.Errorf("Match: cluster count changed from %d to %d", before, m.ClusterCount())
+	}
+}
+
+func TestMasking(t *testing.T) {
+	masker, err := NewMasker(DefaultConfig().MaskRules)
+	if err != nil {
+		t.Fatalf("NewMasker: %v", err)
+	}
+
+	// Each check asserts that the volatile substring was rewritten to the
+	// wildcard token "<*>"; a bare Contains(s, "") would be vacuously true.
+	tests := []struct {
+		input string
+		check func(string) bool
+		name  string
+	}{
+		{
+			name:  "UUID replaced",
+			input: "id=550e8400-e29b-41d4-a716-446655440000 msg=ok",
+			check: func(s string) bool { return strings.Contains(s, "<*>") },
+		},
+		{
+			name:  "URL replaced",
+			input: "fetching https://example.com/api/v1",
+			check: func(s string) bool { return strings.Contains(s, "<*>") },
+		},
+		{
+			name:  "Number value replaced",
+			input: "latency_ms=250",
+			check: func(s string) bool { return strings.Contains(s, "=<*>") },
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			out := masker.Mask(tt.input)
+			if !tt.check(out) {
+				t.Errorf("Mask(%q) = %q, check failed", tt.input, out)
+			}
+		})
+	}
+}
+
+func TestFlattenEvent(t *testing.T) {
+	evt := AgentEvent{
+		Stage: "tool_call",
+		Fields: map[string]string{
+			"tool":       "search",
+			"query":      "foo",
+			"session_id": "abc123",
+			"latency_ms": "42",
+		},
+	}
+	exclude := []string{"session_id"}
+	result := FlattenEvent(evt, exclude)
+
+	// session_id must be excluded.
+ if strings.Contains(result, "session_id") { + t.Errorf("FlattenEvent: excluded field present: %q", result) + } + // Keys should be sorted: latency_ms < query < tool. + idx := func(s string) int { return strings.Index(result, s) } + if idx("latency_ms=") > idx("query=") || idx("query=") > idx("tool=") { + t.Errorf("FlattenEvent: keys not sorted: %q", result) + } + // Stage should appear first. + if !strings.HasPrefix(result, "stage=tool_call") { + t.Errorf("FlattenEvent: stage not first: %q", result) + } +} + +func TestPreTrainTemplate(t *testing.T) { + cfg := DefaultConfig() + cfg.SimThreshold = 0.4 + m, err := NewMiner(cfg) + if err != nil { + t.Fatalf("NewMiner: %v", err) + } + + m.PreTrainTemplate("stage=tool_call tool=<*> latency_ms=<*>", 5) + if m.ClusterCount() != 1 { + t.Fatalf("PreTrainTemplate: expected 1 cluster, got %d", m.ClusterCount()) + } + + // A real line matching that pattern should hit the pre-trained cluster. + result, ok, err := m.Match("stage=tool_call tool=search latency_ms=<*>") + if err != nil { + t.Fatalf("Match: %v", err) + } + if !ok { + t.Error("Match: expected to find pre-trained cluster, got no match") + } + if result != nil && result.ClusterID == 0 { + t.Error("Match: expected valid cluster ID") + } +} + +func TestSaveLoadJSON(t *testing.T) { + cfg := DefaultConfig() + m, err := NewMiner(cfg) + if err != nil { + t.Fatalf("NewMiner: %v", err) + } + lines := []string{ + "stage=plan action=start", + "stage=plan action=start", + "stage=tool_call tool=search", + } + for _, l := range lines { + if _, err := m.Train(l); err != nil { + t.Fatalf("Train(%q): %v", l, err) + } + } + originalCount := m.ClusterCount() + + data, err := m.SaveJSON() + if err != nil { + t.Fatalf("SaveJSON: %v", err) + } + + m2, err := LoadMinerJSON(data) + if err != nil { + t.Fatalf("LoadMinerJSON: %v", err) + } + if m2.ClusterCount() != originalCount { + t.Errorf("round-trip: expected %d clusters, got %d", originalCount, m2.ClusterCount()) + } +} + +func 
TestConcurrency(t *testing.T) { + m, err := NewMiner(DefaultConfig()) + if err != nil { + t.Fatalf("NewMiner: %v", err) + } + + var wg sync.WaitGroup + const goroutines = 10 + const linesEach = 50 + + for g := range goroutines { + wg.Add(1) + go func(id int) { + defer wg.Done() + for i := range linesEach { + line := fmt.Sprintf("stage=work goroutine=%d iter=%d", id, i) + if _, err := m.Train(line); err != nil { + t.Errorf("Train: %v", err) + } + } + }(g) + } + wg.Wait() + + if m.ClusterCount() == 0 { + t.Error("expected clusters after concurrent training") + } +} + +func TestStageRouting(t *testing.T) { + cfg := DefaultConfig() + stages := []string{"plan", "tool_call", "finish"} + coord, err := NewCoordinator(cfg, stages) + if err != nil { + t.Fatalf("NewCoordinator: %v", err) + } + + events := []AgentEvent{ + {Stage: "plan", Fields: map[string]string{"action": "start"}}, + {Stage: "tool_call", Fields: map[string]string{"tool": "search", "query": "foo"}}, + {Stage: "finish", Fields: map[string]string{"status": "ok"}}, + } + for _, evt := range events { + if _, err := coord.TrainEvent(evt); err != nil { + t.Fatalf("TrainEvent(%q): %v", evt.Stage, err) + } + } + + for _, stage := range stages { + m, ok := coord.MinerForStage(stage) + if !ok { + t.Errorf("MinerForStage(%q): not found", stage) + continue + } + if m.ClusterCount() == 0 { + t.Errorf("stage %q: expected at least one cluster", stage) + } + } + + // Unknown stage should error. 
+ _, err = coord.TrainEvent(AgentEvent{Stage: "unknown", Fields: map[string]string{}}) + if err == nil { + t.Error("expected error for unknown stage, got nil") + } +} + +func TestComputeSimilarity(t *testing.T) { + param := "<*>" + tests := []struct { + name string + a []string + b []string + expected float64 + }{ + { + name: "identical", + a: []string{"stage=plan", "action=start"}, + b: []string{"stage=plan", "action=start"}, + expected: 1.0, + }, + { + name: "one diff", + a: []string{"stage=plan", "action=start"}, + b: []string{"stage=plan", "action=stop"}, + expected: 0.5, + }, + { + name: "length mismatch", + a: []string{"a", "b"}, + b: []string{"a"}, + expected: 0.0, + }, + { + name: "wildcard ignored", + a: []string{"stage=plan", param}, + b: []string{"stage=plan", "anything"}, + expected: 1.0, + }, + { + name: "all wildcards", + a: []string{param, param}, + b: []string{"x", "y"}, + expected: 1.0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := computeSimilarity(tt.a, tt.b, param) + if got != tt.expected { + t.Errorf("computeSimilarity = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestMergeTemplate(t *testing.T) { + param := "<*>" + tests := []struct { + name string + existing []string + incoming []string + expected []string + }{ + { + name: "no difference", + existing: []string{"a", "b"}, + incoming: []string{"a", "b"}, + expected: []string{"a", "b"}, + }, + { + name: "one diff becomes wildcard", + existing: []string{"a", "b"}, + incoming: []string{"a", "c"}, + expected: []string{"a", param}, + }, + { + name: "existing wildcard preserved", + existing: []string{param, "b"}, + incoming: []string{"x", "b"}, + expected: []string{param, "b"}, + }, + { + name: "length mismatch returns existing", + existing: []string{"a", "b"}, + incoming: []string{"a"}, + expected: []string{"a", "b"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mergeTemplate(tt.existing, tt.incoming, param) + 
if len(got) != len(tt.expected) { + t.Fatalf("mergeTemplate len = %d, want %d", len(got), len(tt.expected)) + } + for i, tok := range got { + if tok != tt.expected[i] { + t.Errorf("mergeTemplate[%d] = %q, want %q", i, tok, tt.expected[i]) + } + } + }) + } +} diff --git a/pkg/agentdrain/persist.go b/pkg/agentdrain/persist.go new file mode 100644 index 0000000000..094f73190e --- /dev/null +++ b/pkg/agentdrain/persist.go @@ -0,0 +1,96 @@ +package agentdrain + +import ( + "encoding/json" + "fmt" +) + +// Snapshot is the serializable representation of a Miner's state. +type Snapshot struct { + Config Config `json:"config"` + Clusters []SnapshotCluster `json:"clusters"` + NextID int `json:"next_id"` +} + +// SnapshotCluster is the serializable form of a single Cluster. +type SnapshotCluster struct { + ID int `json:"id"` + Template []string `json:"template"` + Size int `json:"size"` + Stage string `json:"stage"` +} + +// SaveJSON serializes the miner's current state to JSON bytes. +func (m *Miner) SaveJSON() ([]byte, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + snap := Snapshot{ + Config: m.cfg, + NextID: m.store.nextID, + } + for _, c := range m.store.clusters { + tmpl := make([]string, len(c.Template)) + copy(tmpl, c.Template) + snap.Clusters = append(snap.Clusters, SnapshotCluster{ + ID: c.ID, + Template: tmpl, + Size: c.Size, + Stage: c.Stage, + }) + } + return json.Marshal(snap) +} + +// LoadJSON restores miner state from JSON bytes produced by SaveJSON. +// The existing state is replaced; the parse tree is rebuilt from the snapshot. 
+func (m *Miner) LoadJSON(data []byte) error { + var snap Snapshot + if err := json.Unmarshal(data, &snap); err != nil { + return fmt.Errorf("agentdrain: LoadJSON: %w", err) + } + + masker, err := NewMasker(snap.Config.MaskRules) + if err != nil { + return fmt.Errorf("agentdrain: LoadJSON: %w", err) + } + + m.mu.Lock() + defer m.mu.Unlock() + + m.cfg = snap.Config + m.masker = masker + m.store = newClusterStore() + m.tree = newParseTree() + m.store.nextID = snap.NextID + + for _, sc := range snap.Clusters { + tmpl := make([]string, len(sc.Template)) + copy(tmpl, sc.Template) + c := &Cluster{ + ID: sc.ID, + Template: tmpl, + Size: sc.Size, + Stage: sc.Stage, + } + m.store.clusters[c.ID] = c + m.tree.addCluster(c.Template, c.ID, m.cfg.Depth, m.cfg.MaxChildren, m.cfg.ParamToken) + } + return nil +} + +// LoadMinerJSON creates a new Miner by restoring state from JSON bytes. +func LoadMinerJSON(data []byte) (*Miner, error) { + var snap Snapshot + if err := json.Unmarshal(data, &snap); err != nil { + return nil, fmt.Errorf("agentdrain: LoadMinerJSON: %w", err) + } + m, err := NewMiner(snap.Config) + if err != nil { + return nil, err + } + if err := m.LoadJSON(data); err != nil { + return nil, err + } + return m, nil +} diff --git a/pkg/agentdrain/pretrain.go b/pkg/agentdrain/pretrain.go new file mode 100644 index 0000000000..67187714de --- /dev/null +++ b/pkg/agentdrain/pretrain.go @@ -0,0 +1,50 @@ +package agentdrain + +import "strings" + +// PreTrainTemplate seeds the miner with a known template string and a +// synthetic observation count. The template is tokenized but not masked, +// so callers should pass already-normalized templates. +func (m *Miner) PreTrainTemplate(template string, count int) { + tokens := Tokenize(template) + if len(tokens) == 0 { + return + } + + m.mu.Lock() + defer m.mu.Unlock() + + // Check for an existing identical template. 
+ candidates := m.tree.search(tokens, m.cfg.Depth, m.cfg.ParamToken) + for _, id := range candidates { + c, ok := m.store.get(id) + if !ok { + continue + } + if strings.Join(c.Template, " ") == strings.Join(tokens, " ") { + c.Size += count + return + } + } + + // Create a new cluster pre-seeded with the desired count. + c := m.store.add(tokens, "") + c.Size = count + m.tree.addCluster(tokens, c.ID, m.cfg.Depth, m.cfg.MaxChildren, m.cfg.ParamToken) +} + +// PreTrainTemplates seeds the miner with a slice of template strings, each +// with an initial count of 1. +func (m *Miner) PreTrainTemplates(templates []string) { + for _, t := range templates { + m.PreTrainTemplate(t, 1) + } +} + +// PreTrainTemplateCounts seeds the miner with a map of template strings to +// their initial observation counts. +func (m *Miner) PreTrainTemplateCounts(templates map[string]int) { + for t, count := range templates { + m.PreTrainTemplate(t, count) + } +} diff --git a/pkg/agentdrain/tree.go b/pkg/agentdrain/tree.go new file mode 100644 index 0000000000..a26a386de2 --- /dev/null +++ b/pkg/agentdrain/tree.go @@ -0,0 +1,75 @@ +package agentdrain + +// treeNode is an internal node in the Drain parse tree. +type treeNode struct { + // children maps a token string to its subtree node. + children map[string]*treeNode + // clusterIDs holds the IDs of clusters stored at this leaf. + clusterIDs []int +} + +func newTreeNode() *treeNode { + return &treeNode{children: make(map[string]*treeNode)} +} + +// parseTree is the two-level prefix tree used by Drain to bucket candidate clusters. +// Level 0 (root) → level 1 keyed by token count → level 2 keyed by first token +// → leaf containing cluster IDs. +type parseTree struct { + // root maps token-count (as int key via an inner map) to first-token nodes. 
+ // Structure: tokenCount → firstToken → *treeNode (leaf) + root map[int]map[string]*treeNode +} + +func newParseTree() *parseTree { + return &parseTree{root: make(map[int]map[string]*treeNode)} +} + +// addCluster inserts clusterID into the leaf for the given tokens. +func (t *parseTree) addCluster(tokens []string, clusterID int, depth int, maxChildren int, paramToken string) { + n := len(tokens) + if t.root[n] == nil { + t.root[n] = make(map[string]*treeNode) + } + key := t.firstKey(tokens, depth, paramToken) + leaf := t.root[n][key] + if leaf == nil { + leaf = newTreeNode() + t.root[n][key] = leaf + } + leaf.clusterIDs = append(leaf.clusterIDs, clusterID) +} + +// search returns candidate cluster IDs for the given tokens. +func (t *parseTree) search(tokens []string, depth int, paramToken string) []int { + n := len(tokens) + byCount, ok := t.root[n] + if !ok { + return nil + } + key := t.firstKey(tokens, depth, paramToken) + leaf, ok := byCount[key] + if !ok { + // Also try the wildcard bucket. + leaf, ok = byCount[paramToken] + if !ok { + return nil + } + } + out := make([]int, len(leaf.clusterIDs)) + copy(out, leaf.clusterIDs) + return out +} + +// firstKey returns the routing key derived from the first meaningful token. +// When depth == 1, all lines with the same length share a single bucket. +func (t *parseTree) firstKey(tokens []string, depth int, paramToken string) string { + if depth <= 1 || len(tokens) == 0 { + return "*" + } + tok := tokens[0] + if tok == paramToken { + return paramToken + } + return tok +} diff --git a/pkg/agentdrain/types.go b/pkg/agentdrain/types.go new file mode 100644 index 0000000000..08df039cd4 --- /dev/null +++ b/pkg/agentdrain/types.go @@ -0,0 +1,79 @@ +package agentdrain + +// Config holds tuning parameters for the Drain log template miner. +type Config struct { + // Depth controls how many levels of the parse tree are used. 
	Depth int
	// SimThreshold is the minimum similarity score (0–1) required to match an existing cluster.
	SimThreshold float64
	// MaxChildren limits the number of children per internal tree node.
	MaxChildren int
	// ParamToken is the wildcard string inserted where tokens differ across log lines
	// (also used as the parse tree's wildcard routing bucket).
	ParamToken string
	// RareClusterThreshold marks clusters with size ≤ this value as rare.
	RareClusterThreshold int
	// MaskRules are applied before tokenization to normalize variable parts of log lines.
	MaskRules []MaskRule
	// ExcludeFields lists AgentEvent field keys that are omitted when flattening events.
	ExcludeFields []string
}

// MaskRule describes a regex substitution applied to log lines before processing.
type MaskRule struct {
	// Name is a human-readable identifier for the rule.
	Name string
	// Pattern is the regular expression to match.
	Pattern string
	// Replacement is the string substituted for each match.
	Replacement string
}

// Cluster represents a group of log lines that share the same template.
// Clusters are mined by the Miner and serialized via SnapshotCluster.
type Cluster struct {
	// ID is the unique cluster identifier.
	ID int
	// Template is the tokenized log template with wildcards at variable positions.
	Template []string
	// Size is the number of log lines that have been assigned to this cluster.
	Size int
	// Stage identifies which agent stage generated this cluster.
	Stage string
}

// MatchResult is returned after processing a log line through the miner.
type MatchResult struct {
	// ClusterID is the ID of the matched or newly created cluster.
	ClusterID int
	// Template is the space-joined template string.
	Template string
	// Params holds the actual token values at wildcard positions.
	Params []string
	// Similarity is the fraction of non-wildcard positions that matched exactly.
	Similarity float64
	// Stage is the agent stage associated with the matched cluster.
	Stage string
}

// AnomalyReport describes anomalies detected for a log line.
+type AnomalyReport struct { + // IsNewTemplate is true when the log line created a new cluster. + IsNewTemplate bool + // LowSimilarity is true when the best match score was below the configured threshold. + LowSimilarity bool + // RareCluster is true when the matched cluster has been seen fewer times than the rare threshold. + RareCluster bool + // NewClusterCreated is true when this event produced a brand-new cluster. + NewClusterCreated bool + // AnomalyScore is a weighted composite score in the range [0, 1]. + AnomalyScore float64 + // Reason is a human-readable description of all anomalies that were detected. + Reason string +} + +// AgentEvent is a structured log event emitted by an agent pipeline stage. +type AgentEvent struct { + // Stage identifies the pipeline stage (e.g., "plan", "tool_call", "finish"). + Stage string + // Fields contains the key-value pairs parsed from the log line. + Fields map[string]string +} diff --git a/pkg/cli/audit_cross_run.go b/pkg/cli/audit_cross_run.go index 10b9511ad1..cde4e9dc3c 100644 --- a/pkg/cli/audit_cross_run.go +++ b/pkg/cli/audit_cross_run.go @@ -36,6 +36,7 @@ type CrossRunAuditReport struct { ErrorTrend ErrorTrendData `json:"error_trend"` DomainInventory []DomainInventoryEntry `json:"domain_inventory"` PerRunBreakdown []PerRunFirewallBreakdown `json:"per_run_breakdown"` + Drain3Insights []ObservabilityInsight `json:"drain3_insights,omitempty"` } // CrossRunSummary provides top-level statistics across all analyzed runs. 
@@ -366,6 +367,9 @@ func buildCrossRunAuditReport(inputs []crossRunInput) *CrossRunAuditReport { auditCrossRunLog.Printf("Cross-run audit report built: runs=%d, with_data=%d, unique_domains=%d, mcp_servers=%d", report.RunsAnalyzed, report.RunsWithData, report.Summary.UniqueDomains, len(report.MCPHealth)) + // --- Phase 7: drain3 multi-run pattern analysis --- + report.Drain3Insights = buildDrain3InsightsFromCrossRunInputs(inputs) + return report } @@ -466,3 +470,30 @@ func buildMetricsTrend(rows []metricsRawRow) MetricsTrendData { return trend } + +// buildDrain3InsightsFromCrossRunInputs converts cross-run inputs to ProcessedRuns and +// delegates to the shared multi-run drain3 analysis function. +// Returns nil if inputs is empty or if no events could be extracted. +func buildDrain3InsightsFromCrossRunInputs(inputs []crossRunInput) []ObservabilityInsight { + if len(inputs) == 0 { + return nil + } + runs := make([]ProcessedRun, 0, len(inputs)) + for _, in := range inputs { + pr := ProcessedRun{ + Run: WorkflowRun{ + DatabaseID: in.RunID, + WorkflowName: in.WorkflowName, + Conclusion: in.Conclusion, + Duration: in.Duration, + Turns: in.Metrics.Turns, + TokenUsage: in.Metrics.TokenUsage, + EstimatedCost: in.Metrics.EstimatedCost, + ErrorCount: in.ErrorCount, + }, + MCPFailures: in.MCPFailures, + } + runs = append(runs, pr) + } + return buildDrain3InsightsMultiRun(runs) +} diff --git a/pkg/cli/audit_cross_run_render.go b/pkg/cli/audit_cross_run_render.go index bee22a1a38..a53266795f 100644 --- a/pkg/cli/audit_cross_run_render.go +++ b/pkg/cli/audit_cross_run_render.go @@ -139,6 +139,29 @@ func renderCrossRunReportMarkdown(report *CrossRunAuditReport) { fmt.Println() } + // Drain3 insights + if len(report.Drain3Insights) > 0 { + fmt.Println("## Agent Event Pattern Analysis") + fmt.Println() + for _, insight := range report.Drain3Insights { + severityIcon := "ℹ" + switch insight.Severity { + case "high": + severityIcon = "🔴" + case "medium": + severityIcon = "🟠" + case 
"low": + severityIcon = "🟡" + } + fmt.Printf("### %s %s\n\n", severityIcon, insight.Title) + fmt.Printf("**Category:** %s | **Severity:** %s\n\n", insight.Category, insight.Severity) + fmt.Printf("%s\n\n", insight.Summary) + if insight.Evidence != "" { + fmt.Printf("_Evidence:_ `%s`\n\n", insight.Evidence) + } + } + } + // Per-run breakdown if len(report.PerRunBreakdown) > 0 { fmt.Println("## Per-Run Breakdown") @@ -274,6 +297,28 @@ func renderCrossRunReportPretty(report *CrossRunAuditReport) { fmt.Fprintln(os.Stderr) } + // Drain3 insights + if len(report.Drain3Insights) > 0 { + fmt.Fprintln(os.Stderr, console.FormatInfoMessage(fmt.Sprintf("Agent Event Pattern Analysis (%d insights)", len(report.Drain3Insights)))) + for _, insight := range report.Drain3Insights { + severityIcon := "ℹ" + switch insight.Severity { + case "high": + severityIcon = "🔴" + case "medium": + severityIcon = "🟠" + case "low": + severityIcon = "🟡" + } + fmt.Fprintf(os.Stderr, " %s [%s/%s] %s\n", severityIcon, insight.Category, insight.Severity, insight.Title) + fmt.Fprintf(os.Stderr, " %s\n", insight.Summary) + if insight.Evidence != "" { + fmt.Fprintf(os.Stderr, " evidence: %s\n", insight.Evidence) + } + } + fmt.Fprintln(os.Stderr) + } + // Per-run breakdown if len(report.PerRunBreakdown) > 0 { fmt.Fprintln(os.Stderr, console.FormatInfoMessage("Per-Run Breakdown")) diff --git a/pkg/cli/audit_cross_run_test.go b/pkg/cli/audit_cross_run_test.go index c3f48f0db1..133c2d54a6 100644 --- a/pkg/cli/audit_cross_run_test.go +++ b/pkg/cli/audit_cross_run_test.go @@ -758,3 +758,148 @@ func TestRenderCrossRunReportMarkdown_IncludesNewSections(t *testing.T) { assert.Contains(t, output, "Per-Run Breakdown", "Should have per-run breakdown") assert.Contains(t, output, "⚠", "Should have spike warnings") } + +func TestBuildDrain3InsightsFromCrossRunInputs_Empty(t *testing.T) { + insights := buildDrain3InsightsFromCrossRunInputs(nil) + assert.Nil(t, insights, "should return nil for empty inputs") + + insights 
= buildDrain3InsightsFromCrossRunInputs([]crossRunInput{}) + assert.Nil(t, insights, "should return nil for empty slice") +} + +func TestBuildDrain3InsightsFromCrossRunInputs_WithInputs(t *testing.T) { + inputs := []crossRunInput{ + { + RunID: 1, + WorkflowName: "test-workflow", + Conclusion: "success", + Metrics: LogMetrics{ + Turns: 5, + TokenUsage: 1000, + EstimatedCost: 0.05, + }, + ErrorCount: 0, + }, + { + RunID: 2, + WorkflowName: "test-workflow", + Conclusion: "failure", + Metrics: LogMetrics{ + Turns: 8, + TokenUsage: 2000, + EstimatedCost: 0.1, + }, + ErrorCount: 2, + MCPFailures: []MCPFailureReport{ + {ServerName: "github", Status: "timeout"}, + }, + }, + } + + // Verify the conversion maps fields correctly by checking via a converted ProcessedRun. + runs := make([]ProcessedRun, 0, len(inputs)) + for _, in := range inputs { + runs = append(runs, ProcessedRun{ + Run: WorkflowRun{ + DatabaseID: in.RunID, + WorkflowName: in.WorkflowName, + Conclusion: in.Conclusion, + Turns: in.Metrics.Turns, + TokenUsage: in.Metrics.TokenUsage, + EstimatedCost: in.Metrics.EstimatedCost, + ErrorCount: in.ErrorCount, + }, + MCPFailures: in.MCPFailures, + }) + } + require.Equal(t, int64(1), runs[0].Run.DatabaseID, "first run ID should map to 1") + require.Equal(t, "test-workflow", runs[0].Run.WorkflowName, "workflow name should be mapped") + require.Equal(t, "success", runs[0].Run.Conclusion, "conclusion should be mapped") + require.Equal(t, 5, runs[0].Run.Turns, "turns should be mapped from Metrics.Turns") + require.Equal(t, 1000, runs[0].Run.TokenUsage, "tokens should be mapped from Metrics.TokenUsage") + require.Len(t, runs[1].MCPFailures, 1, "MCPFailures should be mapped") + require.Equal(t, "github", runs[1].MCPFailures[0].ServerName, "MCP server name should be mapped") + + insights := buildDrain3InsightsFromCrossRunInputs(inputs) + // Drain3 insights may or may not be generated depending on event count, + // but the function should not panic or error. 
+ // If insights are generated they should have valid fields. + for _, insight := range insights { + assert.NotEmpty(t, insight.Category, "insight should have a category") + assert.NotEmpty(t, insight.Severity, "insight should have a severity") + assert.NotEmpty(t, insight.Title, "insight should have a title") + } +} + +func TestBuildCrossRunAuditReport_IncludesDrain3Insights(t *testing.T) { + inputs := []crossRunInput{ + { + RunID: 100, + WorkflowName: "test-workflow", + Conclusion: "success", + Metrics: LogMetrics{Turns: 5, TokenUsage: 500, EstimatedCost: 0.05}, + ErrorCount: 1, + MCPFailures: []MCPFailureReport{{ServerName: "github", Status: "timeout"}}, + }, + { + RunID: 101, + WorkflowName: "test-workflow", + Conclusion: "failure", + Metrics: LogMetrics{Turns: 8, TokenUsage: 2000, EstimatedCost: 0.1}, + ErrorCount: 2, + }, + } + + report := buildCrossRunAuditReport(inputs) + require.NotNil(t, report, "report should not be nil") + + // Phase 7 should have run and may produce insights. Even if no events are + // extracted the field must be initialised (nil is acceptable). + // Verify that Phase 7 fired without panic; if insights were produced, check + // they have the required fields. 
+ for _, insight := range report.Drain3Insights { + assert.NotEmpty(t, insight.Category, "Drain3 insight should have a category") + assert.NotEmpty(t, insight.Severity, "Drain3 insight should have a severity") + assert.NotEmpty(t, insight.Title, "Drain3 insight should have a title") + } +} + +func TestRenderCrossRunReportMarkdown_IncludesDrain3Section(t *testing.T) { + report := &CrossRunAuditReport{ + RunsAnalyzed: 1, + Drain3Insights: []ObservabilityInsight{ + { + Category: "execution", + Severity: "info", + Title: "Log template patterns mined", + Summary: "Analysis identified 2 event templates.", + Evidence: "plan=1 finish=1", + }, + { + Category: "reliability", + Severity: "high", + Title: "2 anomalous event pattern(s) detected", + Summary: "Unusual events detected.", + }, + }, + } + + oldStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + + renderCrossRunReportMarkdown(report) + + w.Close() + os.Stdout = oldStdout + + var buf bytes.Buffer + _, _ = buf.ReadFrom(r) + output := buf.String() + + assert.Contains(t, output, "Agent Event Pattern Analysis", "Should include agent event pattern analysis section header") + assert.Contains(t, output, "Log template patterns mined", "Should include first insight title") + assert.Contains(t, output, "2 anomalous event pattern(s) detected", "Should include second insight title") + assert.Contains(t, output, "plan=1 finish=1", "Should include evidence") + assert.Contains(t, output, "🔴", "Should include high severity icon") +} diff --git a/pkg/cli/audit_report.go b/pkg/cli/audit_report.go index f402198856..2f648700ce 100644 --- a/pkg/cli/audit_report.go +++ b/pkg/cli/audit_report.go @@ -338,6 +338,7 @@ func buildAuditData(processedRun ProcessedRun, metrics LogMetrics, mcpToolUsage recommendations = append(recommendations, generateAgenticAssessmentRecommendations(agenticAssessments)...) 
observabilityInsights := buildAuditObservabilityInsights(processedRun, metricsData, toolUsage, createdItems) + observabilityInsights = append(observabilityInsights, buildDrain3Insights(processedRun, metricsData, toolUsage)...) // Generate performance metrics performanceMetrics := generatePerformanceMetrics(processedRun, metricsData, toolUsage) diff --git a/pkg/cli/context_cancellation_test.go b/pkg/cli/context_cancellation_test.go index df9fb11e71..b9b9109002 100644 --- a/pkg/cli/context_cancellation_test.go +++ b/pkg/cli/context_cancellation_test.go @@ -71,7 +71,7 @@ func TestDownloadWorkflowLogsWithCancellation(t *testing.T) { cancel() // Try to download logs with a cancelled context - err := DownloadWorkflowLogs(ctx, "", 10, "", "", "/tmp/test-logs", "", "", 0, 0, "", false, false, false, false, false, false, false, 0, "", "", false) + err := DownloadWorkflowLogs(ctx, "", 10, "", "", "/tmp/test-logs", "", "", 0, 0, "", false, false, false, false, false, false, false, 0, "", "", false, false) // Should return context.Canceled error assert.ErrorIs(t, err, context.Canceled, "Should return context.Canceled error when context is cancelled") @@ -111,7 +111,7 @@ func TestDownloadWorkflowLogsTimeoutRespected(t *testing.T) { start := time.Now() // Use a workflow name that doesn't exist to avoid actual network calls - _ = DownloadWorkflowLogs(ctx, "nonexistent-workflow-12345", 100, "", "", "/tmp/test-logs", "", "", 0, 0, "", false, false, false, false, false, false, false, 1, "", "", false) + _ = DownloadWorkflowLogs(ctx, "nonexistent-workflow-12345", 100, "", "", "/tmp/test-logs", "", "", 0, 0, "", false, false, false, false, false, false, false, 1, "", "", false, false) elapsed := time.Since(start) // Should complete within reasonable time (give 5 seconds buffer for test overhead) diff --git a/pkg/cli/drain3_integration.go b/pkg/cli/drain3_integration.go new file mode 100644 index 0000000000..5f93cfc1f3 --- /dev/null +++ b/pkg/cli/drain3_integration.go @@ -0,0 +1,371 @@ 
+package cli + +import ( + "fmt" + "strconv" + "strings" + + "github.com/github/gh-aw/pkg/agentdrain" + "github.com/github/gh-aw/pkg/logger" +) + +var drain3Log = logger.New("cli:drain3_integration") + +// defaultAgentDrainStages lists the stage names recognised by the coordinator. +var defaultAgentDrainStages = []string{ + "plan", "tool_call", "tool_result", "retry", "error", "finish", +} + +// buildDrain3Insights analyses a single ProcessedRun using Drain3-style template +// mining and returns additional ObservabilityInsights to be appended to the +// existing insight list. +func buildDrain3Insights(processedRun ProcessedRun, metrics MetricsData, toolUsage []ToolUsageInfo) []ObservabilityInsight { + drain3Log.Printf("Building drain3 insights: run_id=%d turns=%d tools=%d mcpFailures=%d missingTools=%d", + processedRun.Run.DatabaseID, metrics.Turns, len(toolUsage), len(processedRun.MCPFailures), len(processedRun.MissingTools)) + + cfg := agentdrain.DefaultConfig() + coordinator, err := agentdrain.NewCoordinator(cfg, defaultAgentDrainStages) + if err != nil { + drain3Log.Printf("Failed to create drain3 coordinator: %v", err) + return nil + } + if err := coordinator.LoadDefaultWeights(); err != nil { + drain3Log.Printf("Failed to load default drain3 weights: %v", err) + } + + events := buildAgentEventsFromProcessedRun(processedRun, metrics, toolUsage) + if len(events) == 0 { + return nil + } + + var anomalies []struct { + evt agentdrain.AgentEvent + result *agentdrain.MatchResult + report *agentdrain.AnomalyReport + } + + for _, evt := range events { + result, report, err := coordinator.AnalyzeEvent(evt) + if err != nil { + // Unknown stage – skip gracefully. 
+ drain3Log.Printf("AnalyzeEvent failed for stage=%s: %v", evt.Stage, err) + continue + } + if report != nil && report.AnomalyScore > 0.5 { + anomalies = append(anomalies, struct { + evt agentdrain.AgentEvent + result *agentdrain.MatchResult + report *agentdrain.AnomalyReport + }{evt, result, report}) + } + } + + return buildInsightsFromDrain3Analysis(coordinator, anomalies, events) +} + +// buildDrain3InsightsMultiRun analyses multiple ProcessedRuns using a shared +// Drain3 coordinator, which allows cross-run pattern detection. It returns +// additional ObservabilityInsights. +func buildDrain3InsightsMultiRun(processedRuns []ProcessedRun) []ObservabilityInsight { + if len(processedRuns) == 0 { + return nil + } + drain3Log.Printf("Building drain3 multi-run insights: runs=%d", len(processedRuns)) + + cfg := agentdrain.DefaultConfig() + coordinator, err := agentdrain.NewCoordinator(cfg, defaultAgentDrainStages) + if err != nil { + drain3Log.Printf("Failed to create drain3 coordinator: %v", err) + return nil + } + if err := coordinator.LoadDefaultWeights(); err != nil { + drain3Log.Printf("Failed to load default drain3 weights: %v", err) + } + + totalEvents := 0 + var highAnomalies []struct { + evt agentdrain.AgentEvent + result *agentdrain.MatchResult + report *agentdrain.AnomalyReport + } + + for _, pr := range processedRuns { + events := buildAgentEventsFromProcessedRun(pr, MetricsData{ + Turns: pr.Run.Turns, + TokenUsage: pr.Run.TokenUsage, + EstimatedCost: pr.Run.EstimatedCost, + ErrorCount: pr.Run.ErrorCount, + WarningCount: pr.Run.WarningCount, + }, nil) + totalEvents += len(events) + + for _, evt := range events { + result, report, err := coordinator.AnalyzeEvent(evt) + if err != nil { + continue + } + if report != nil && report.AnomalyScore > 0.6 { + highAnomalies = append(highAnomalies, struct { + evt agentdrain.AgentEvent + result *agentdrain.MatchResult + report *agentdrain.AnomalyReport + }{evt, result, report}) + } + } + } + + if totalEvents == 0 { + 
return nil + } + + return buildMultiRunInsightsFromDrain3(coordinator, highAnomalies, len(processedRuns), totalEvents) +} + +// buildAgentEventsFromProcessedRun converts the structured data in a ProcessedRun +// into a slice of AgentEvents suitable for drain3 ingestion. +func buildAgentEventsFromProcessedRun(processedRun ProcessedRun, metrics MetricsData, toolUsage []ToolUsageInfo) []agentdrain.AgentEvent { + var events []agentdrain.AgentEvent + + // Synthesise a planning event from overall metrics. + if metrics.Turns > 0 { + events = append(events, agentdrain.AgentEvent{ + Stage: "plan", + Fields: map[string]string{ + "turns": strconv.Itoa(metrics.Turns), + "errors": strconv.Itoa(metrics.ErrorCount), + }, + }) + } + + // Tool-call events from the per-tool usage summary. + for _, tu := range toolUsage { + events = append(events, agentdrain.AgentEvent{ + Stage: "tool_call", + Fields: map[string]string{ + "tool": tu.Name, + "calls": strconv.Itoa(tu.CallCount), + }, + }) + } + + // MCP failures become error-stage events. + for _, f := range processedRun.MCPFailures { + events = append(events, agentdrain.AgentEvent{ + Stage: "error", + Fields: map[string]string{ + "type": "mcp_failure", + "server": f.ServerName, + "status": f.Status, + }, + }) + } + + // Missing tools are capability-friction errors. + for _, mt := range processedRun.MissingTools { + events = append(events, agentdrain.AgentEvent{ + Stage: "error", + Fields: map[string]string{ + "type": "missing_tool", + "tool": mt.Tool, + "reason": mt.Reason, + }, + }) + } + + // Missing data is a different error class. + for _, md := range processedRun.MissingData { + events = append(events, agentdrain.AgentEvent{ + Stage: "error", + Fields: map[string]string{ + "type": "missing_data", + "data_type": md.DataType, + "reason": md.Reason, + }, + }) + } + + // No-ops map to tool_result stage. 
+ for _, n := range processedRun.Noops { + events = append(events, agentdrain.AgentEvent{ + Stage: "tool_result", + Fields: map[string]string{ + "status": "noop", + "message": n.Message, + }, + }) + } + + // Synthesise a finish event only when the run has a meaningful conclusion. + conclusion := processedRun.Run.Conclusion + if conclusion == "" { + conclusion = processedRun.Run.Status + } + if conclusion != "" || metrics.TokenUsage > 0 { + events = append(events, agentdrain.AgentEvent{ + Stage: "finish", + Fields: map[string]string{ + "status": conclusion, + "tokens": strconv.Itoa(metrics.TokenUsage), + }, + }) + } + + return events +} + +// buildInsightsFromDrain3Analysis converts drain3 coordinator analysis into +// ObservabilityInsights for a single run. +func buildInsightsFromDrain3Analysis( + coordinator *agentdrain.Coordinator, + anomalies []struct { + evt agentdrain.AgentEvent + result *agentdrain.MatchResult + report *agentdrain.AnomalyReport + }, + events []agentdrain.AgentEvent, +) []ObservabilityInsight { + var insights []ObservabilityInsight + + // Cluster summary insight. + allClusters := coordinator.AllClusters() + totalClusters := 0 + for _, cs := range allClusters { + totalClusters += len(cs) + } + if totalClusters > 0 { + stageBreakdown := buildStageBreakdown(allClusters) + insights = append(insights, ObservabilityInsight{ + Category: "execution", + Severity: "info", + Title: "Log template patterns mined", + Summary: fmt.Sprintf( + "Analysis identified %d distinct event templates across %d pipeline stages from %d events.", + totalClusters, len(allClusters), len(events), + ), + Evidence: stageBreakdown, + }) + } + + // Anomaly insight. 
+ if len(anomalies) > 0 { + severity := "low" + if len(anomalies) >= 3 { + severity = "high" + } else if len(anomalies) >= 2 { + severity = "medium" + } + reasons := buildAnomalyReasons(anomalies) + insights = append(insights, ObservabilityInsight{ + Category: "reliability", + Severity: severity, + Title: fmt.Sprintf("%d anomalous event pattern(s) detected", len(anomalies)), + Summary: fmt.Sprintf( + "Anomaly detection flagged %d event(s) as unusual based on template similarity and cluster rarity.", + len(anomalies), + ), + Evidence: reasons, + }) + } + + // Stage sequence insight. + sequence := agentdrain.StageSequence(events) + if sequence != "" { + insights = append(insights, ObservabilityInsight{ + Category: "execution", + Severity: "info", + Title: "Agent stage sequence", + Summary: "The observed pipeline stage sequence for this run.", + Evidence: sequence, + }) + } + + return insights +} + +// buildMultiRunInsightsFromDrain3 converts cross-run drain3 analysis into insights. +func buildMultiRunInsightsFromDrain3( + coordinator *agentdrain.Coordinator, + highAnomalies []struct { + evt agentdrain.AgentEvent + result *agentdrain.MatchResult + report *agentdrain.AnomalyReport + }, + runCount, totalEvents int, +) []ObservabilityInsight { + var insights []ObservabilityInsight + + allClusters := coordinator.AllClusters() + totalClusters := 0 + for _, cs := range allClusters { + totalClusters += len(cs) + } + + if totalClusters > 0 { + stageBreakdown := buildStageBreakdown(allClusters) + insights = append(insights, ObservabilityInsight{ + Category: "execution", + Severity: "info", + Title: "Cross-run log template patterns", + Summary: fmt.Sprintf( + "Mined %d distinct event templates across %d pipeline stages from %d events in %d runs.", + totalClusters, len(allClusters), totalEvents, runCount, + ), + Evidence: stageBreakdown, + }) + } + + if len(highAnomalies) > 0 { + severity := "medium" + if len(highAnomalies) >= 5 { + severity = "high" + } + reasons := 
buildAnomalyReasons(highAnomalies) + insights = append(insights, ObservabilityInsight{ + Category: "reliability", + Severity: severity, + Title: fmt.Sprintf("%d high-anomaly events across %d runs", len(highAnomalies), runCount), + Summary: fmt.Sprintf( + "Cross-run analysis flagged %d events with anomaly score > 0.6, indicating unusual patterns relative to the learned templates.", + len(highAnomalies), + ), + Evidence: reasons, + }) + } + + return insights +} + +// buildStageBreakdown builds a human-readable stage → cluster-count string. +func buildStageBreakdown(allClusters map[string][]agentdrain.Cluster) string { + if len(allClusters) == 0 { + return "" + } + parts := make([]string, 0, len(allClusters)) + for stage, clusters := range allClusters { + if len(clusters) > 0 { + parts = append(parts, fmt.Sprintf("%s=%d", stage, len(clusters))) + } + } + return strings.Join(parts, " ") +} + +// buildAnomalyReasons summarises anomaly reasons into a compact evidence string. +func buildAnomalyReasons(anomalies []struct { + evt agentdrain.AgentEvent + result *agentdrain.MatchResult + report *agentdrain.AnomalyReport +}) string { + reasons := make([]string, 0, len(anomalies)) + seen := make(map[string]bool) + for _, a := range anomalies { + r := fmt.Sprintf("stage=%s score=%.2f: %s", a.evt.Stage, a.report.AnomalyScore, a.report.Reason) + if !seen[r] { + reasons = append(reasons, r) + seen[r] = true + } + if len(reasons) >= 5 { + break + } + } + return strings.Join(reasons, "; ") +} diff --git a/pkg/cli/drain3_integration_test.go b/pkg/cli/drain3_integration_test.go new file mode 100644 index 0000000000..0051f7e4ff --- /dev/null +++ b/pkg/cli/drain3_integration_test.go @@ -0,0 +1,222 @@ +//go:build !integration + +package cli + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBuildDrain3Insights_NoEvents(t *testing.T) { + // A ProcessedRun with no meaningful events should return no insights. 
+ processedRun := ProcessedRun{} + metrics := MetricsData{} + toolUsage := []ToolUsageInfo{} + + insights := buildDrain3Insights(processedRun, metrics, toolUsage) + assert.Empty(t, insights, "expected no insights when run has no events") +} + +func TestBuildDrain3Insights_BasicRun(t *testing.T) { + processedRun := ProcessedRun{ + Run: WorkflowRun{ + DatabaseID: 42, + Conclusion: "success", + Turns: 5, + TokenUsage: 1200, + }, + } + metrics := MetricsData{ + Turns: 5, + TokenUsage: 1200, + ErrorCount: 0, + } + toolUsage := []ToolUsageInfo{ + {Name: "bash", CallCount: 3}, + {Name: "github_issue_read", CallCount: 2}, + } + + insights := buildDrain3Insights(processedRun, metrics, toolUsage) + require.NotEmpty(t, insights, "expected at least one drain3 insight for a run with events") + + // Verify all insights have required fields. + for _, ins := range insights { + assert.NotEmpty(t, ins.Category, "insight Category must not be empty") + assert.NotEmpty(t, ins.Severity, "insight Severity must not be empty") + assert.NotEmpty(t, ins.Title, "insight Title must not be empty") + assert.NotEmpty(t, ins.Summary, "insight Summary must not be empty") + } +} + +func TestBuildDrain3Insights_WithErrors(t *testing.T) { + processedRun := ProcessedRun{ + Run: WorkflowRun{ + DatabaseID: 99, + Conclusion: "failure", + Turns: 8, + ErrorCount: 2, + }, + MCPFailures: []MCPFailureReport{ + {ServerName: "github", Status: "connection_refused"}, + {ServerName: "search", Status: "timeout"}, + }, + MissingTools: []MissingToolReport{ + {Tool: "terraform", Reason: "not installed"}, + }, + } + metrics := MetricsData{ + Turns: 8, + ErrorCount: 2, + } + + insights := buildDrain3Insights(processedRun, metrics, nil) + require.NotEmpty(t, insights, "expected insights when run has MCP failures and missing tools") + + categories := make([]string, 0, len(insights)) + for _, ins := range insights { + categories = append(categories, ins.Category) + } + // Should have execution and/or reliability categories. 
+ assert.Contains(t, categories, "execution", "expected an 'execution' category insight") +} + +func TestBuildDrain3Insights_StageSequenceEvidence(t *testing.T) { + processedRun := ProcessedRun{ + Run: WorkflowRun{ + DatabaseID: 7, + Conclusion: "success", + Turns: 3, + }, + } + metrics := MetricsData{Turns: 3, TokenUsage: 500} + toolUsage := []ToolUsageInfo{ + {Name: "search", CallCount: 1}, + } + + insights := buildDrain3Insights(processedRun, metrics, toolUsage) + require.NotEmpty(t, insights, "expected insights to be generated") + + // One insight should carry the stage-sequence evidence. + var found bool + for _, ins := range insights { + if ins.Title == "Agent stage sequence" { + assert.NotEmpty(t, ins.Evidence, "stage sequence insight should have evidence") + found = true + break + } + } + assert.True(t, found, "expected a 'Agent stage sequence' insight") +} + +func TestBuildDrain3InsightsMultiRun_Empty(t *testing.T) { + insights := buildDrain3InsightsMultiRun(nil) + assert.Empty(t, insights, "expected no insights for nil runs slice") + + insights = buildDrain3InsightsMultiRun([]ProcessedRun{}) + assert.Empty(t, insights, "expected no insights for empty runs slice") +} + +func TestBuildDrain3InsightsMultiRun_MultipleRuns(t *testing.T) { + runs := []ProcessedRun{ + { + Run: WorkflowRun{ + DatabaseID: 1, + Conclusion: "success", + Turns: 5, + TokenUsage: 1000, + ErrorCount: 0, + }, + MCPFailures: []MCPFailureReport{}, + MissingTools: []MissingToolReport{}, + }, + { + Run: WorkflowRun{ + DatabaseID: 2, + Conclusion: "failure", + Turns: 10, + TokenUsage: 2000, + ErrorCount: 3, + }, + MCPFailures: []MCPFailureReport{ + {ServerName: "github", Status: "error"}, + }, + MissingTools: []MissingToolReport{ + {Tool: "docker", Reason: "not found"}, + }, + }, + { + Run: WorkflowRun{ + DatabaseID: 3, + Conclusion: "success", + Turns: 4, + TokenUsage: 800, + }, + }, + } + + insights := buildDrain3InsightsMultiRun(runs) + require.NotEmpty(t, insights, "expected insights from 
multi-run analysis") + + for _, ins := range insights { + assert.NotEmpty(t, ins.Category, "insight Category must not be empty") + assert.NotEmpty(t, ins.Severity, "insight Severity must not be empty") + assert.NotEmpty(t, ins.Title, "insight Title must not be empty") + } +} + +func TestBuildAgentEventsFromProcessedRun(t *testing.T) { + pr := ProcessedRun{ + Run: WorkflowRun{ + DatabaseID: 5, + Conclusion: "success", + Turns: 4, + TokenUsage: 900, + }, + MCPFailures: []MCPFailureReport{{ServerName: "s3", Status: "timeout"}}, + MissingTools: []MissingToolReport{{Tool: "kubectl", Reason: "missing"}}, + MissingData: []MissingDataReport{{DataType: "env_var", Reason: "undefined"}}, + Noops: []NoopReport{{Message: "already up to date"}}, + } + metrics := MetricsData{Turns: 4, TokenUsage: 900} + toolUsage := []ToolUsageInfo{ + {Name: "bash", CallCount: 2}, + } + + events := buildAgentEventsFromProcessedRun(pr, metrics, toolUsage) + require.NotEmpty(t, events, "expected events to be generated") + + stages := make(map[string]int) + for _, e := range events { + stages[e.Stage]++ + } + + assert.Positive(t, stages["plan"], "expected at least one plan event") + assert.Positive(t, stages["tool_call"], "expected at least one tool_call event") + assert.Positive(t, stages["error"], "expected at least one error event") + assert.Positive(t, stages["tool_result"], "expected at least one tool_result event (noop)") + assert.Positive(t, stages["finish"], "expected at least one finish event") +} + +func TestBuildDrain3Insights_IncludedInAuditData(t *testing.T) { + // Verify that buildAuditData appends drain3 insights to ObservabilityInsights. + processedRun := ProcessedRun{ + Run: WorkflowRun{ + DatabaseID: 10, + Conclusion: "success", + Turns: 3, + TokenUsage: 500, + }, + } + metrics := MetricsData{Turns: 3, TokenUsage: 500} + toolUsage := []ToolUsageInfo{{Name: "bash", CallCount: 1}} + + // Combine the two pipelines as done in buildAuditData. 
+ existing := buildAuditObservabilityInsights(processedRun, metrics, toolUsage, nil) + drain3 := buildDrain3Insights(processedRun, metrics, toolUsage) + all := append(existing, drain3...) + + // We should have at least the drain3 insights. + assert.GreaterOrEqual(t, len(all), len(drain3), "combined insights should include drain3 results") +} diff --git a/pkg/cli/drain3_train.go b/pkg/cli/drain3_train.go new file mode 100644 index 0000000000..3737c35652 --- /dev/null +++ b/pkg/cli/drain3_train.go @@ -0,0 +1,95 @@ +package cli + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "github.com/github/gh-aw/pkg/agentdrain" + "github.com/github/gh-aw/pkg/console" + "github.com/github/gh-aw/pkg/logger" +) + +var drain3TrainLog = logger.New("cli:drain3_train") + +// drain3WeightsFilename is the output filename for the trained weights. +const drain3WeightsFilename = "drain3_weights.json" + +// TrainDrain3Weights trains a Drain3 coordinator across all processed runs, +// serialises the resulting weights to drain3_weights.json in outputDir, and +// prints instructions on how to embed the file as default weights. +// +// This function is invoked when the user passes --train to the logs command. 
+func TrainDrain3Weights(processedRuns []ProcessedRun, outputDir string, verbose bool) error { + if len(processedRuns) == 0 { + fmt.Fprintln(os.Stderr, console.FormatWarningMessage("No processed runs available for log pattern training")) + return nil + } + + fmt.Fprintln(os.Stderr, console.FormatInfoMessage(fmt.Sprintf("Training log pattern weights from %d run(s)...", len(processedRuns)))) + + cfg := agentdrain.DefaultConfig() + coordinator, err := agentdrain.NewCoordinator(cfg, defaultAgentDrainStages) + if err != nil { + return fmt.Errorf("log pattern training: create coordinator: %w", err) + } + + totalEvents := 0 + for _, pr := range processedRuns { + events := buildAgentEventsFromProcessedRun(pr, MetricsData{ + Turns: pr.Run.Turns, + TokenUsage: pr.Run.TokenUsage, + EstimatedCost: pr.Run.EstimatedCost, + ErrorCount: pr.Run.ErrorCount, + WarningCount: pr.Run.WarningCount, + }, nil) + totalEvents += len(events) + for _, evt := range events { + if _, err := coordinator.TrainEvent(evt); err != nil { + drain3TrainLog.Printf("TrainEvent skipped: stage=%s err=%v", evt.Stage, err) + } + } + } + + if verbose { + allClusters := coordinator.AllClusters() + total := 0 + for _, cs := range allClusters { + total += len(cs) + } + fmt.Fprintln(os.Stderr, console.FormatInfoMessage(fmt.Sprintf( + "Trained %d events → %d clusters across %d stages", + totalEvents, total, len(allClusters), + ))) + } + + weightsData, err := coordinator.SaveWeightsJSON() + if err != nil { + return fmt.Errorf("log pattern training: serialize weights: %w", err) + } + + // Pretty-print the weights for readability. 
+ var raw map[string]any + if unmarshalErr := json.Unmarshal(weightsData, &raw); unmarshalErr != nil { + drain3TrainLog.Printf("Could not unmarshal weights for pretty-printing: %v", unmarshalErr) + } else if pretty, marshalErr := json.MarshalIndent(raw, "", " "); marshalErr != nil { + drain3TrainLog.Printf("Could not indent weights JSON: %v", marshalErr) + } else { + weightsData = pretty + } + + outputPath := filepath.Join(outputDir, drain3WeightsFilename) + if err := os.WriteFile(outputPath, weightsData, 0o644); err != nil { + return fmt.Errorf("log pattern training: write weights file: %w", err) + } + + fmt.Fprintln(os.Stderr, console.FormatSuccessMessage("Log pattern weights written to: "+outputPath)) + fmt.Fprintln(os.Stderr, console.FormatInfoMessage( + "To embed these weights as default, copy the file and rebuild:\n"+ + " cp "+outputPath+" pkg/agentdrain/data/default_weights.json\n"+ + " make build", + )) + + return nil +} diff --git a/pkg/cli/drain3_train_test.go b/pkg/cli/drain3_train_test.go new file mode 100644 index 0000000000..e9fb7081f2 --- /dev/null +++ b/pkg/cli/drain3_train_test.go @@ -0,0 +1,113 @@ +//go:build !integration + +package cli + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTrainDrain3Weights_NoRuns(t *testing.T) { + tmpDir := t.TempDir() + err := TrainDrain3Weights(nil, tmpDir, false) + require.NoError(t, err, "should not error when no runs provided") + + // No weights file should be written. 
+ _, statErr := os.Stat(filepath.Join(tmpDir, drain3WeightsFilename)) + assert.True(t, os.IsNotExist(statErr), "weights file should not be created for empty run list") +} + +func TestTrainDrain3Weights_WithRuns(t *testing.T) { + tmpDir := t.TempDir() + + runs := []ProcessedRun{ + { + Run: WorkflowRun{ + DatabaseID: 1, + Conclusion: "success", + Turns: 5, + TokenUsage: 1000, + }, + MCPFailures: []MCPFailureReport{{ServerName: "github", Status: "ok"}}, + MissingTools: []MissingToolReport{{Tool: "terraform", Reason: "missing"}}, + }, + { + Run: WorkflowRun{ + DatabaseID: 2, + Conclusion: "failure", + Turns: 8, + ErrorCount: 2, + }, + MCPFailures: []MCPFailureReport{{ServerName: "search", Status: "timeout"}}, + }, + } + + err := TrainDrain3Weights(runs, tmpDir, true) + require.NoError(t, err, "training should succeed with valid runs") + + // Weights file should be written. + weightsPath := filepath.Join(tmpDir, drain3WeightsFilename) + data, err := os.ReadFile(weightsPath) + require.NoError(t, err, "weights file should exist after training") + require.NotEmpty(t, data, "weights file should not be empty") + + // File should be valid JSON with stage keys. + var weights map[string]any + err = json.Unmarshal(data, &weights) + require.NoError(t, err, "weights file should be valid JSON") + assert.NotEmpty(t, weights, "weights JSON should have stage keys") + + // Should contain at least one of the expected stage keys. 
+ expectedStages := []string{"plan", "tool_call", "error", "finish"} + for _, stage := range expectedStages { + assert.Contains(t, weights, stage, "weights should contain stage %q", stage) + } +} + +func TestTrainDrain3Weights_JSONStructure(t *testing.T) { + tmpDir := t.TempDir() + + runs := []ProcessedRun{ + { + Run: WorkflowRun{ + DatabaseID: 1, + Conclusion: "success", + Turns: 3, + TokenUsage: 500, + }, + }, + } + + err := TrainDrain3Weights(runs, tmpDir, false) + require.NoError(t, err, "training should not error") + + weightsPath := filepath.Join(tmpDir, drain3WeightsFilename) + data, err := os.ReadFile(weightsPath) + require.NoError(t, err, "weights file should exist") + + // Should be a map of stage → snapshot objects. + var weights map[string]json.RawMessage + err = json.Unmarshal(data, &weights) + require.NoError(t, err, "should unmarshal to map of raw JSON") + + // Each value should itself be a valid JSON object. + for stage, raw := range weights { + var snap map[string]any + require.NoError(t, json.Unmarshal(raw, &snap), "stage %q snapshot should be valid JSON", stage) + assert.Contains(t, snap, "clusters", "stage %q snapshot should have clusters key", stage) + assert.Contains(t, snap, "config", "stage %q snapshot should have config key", stage) + } +} + +func TestLogsCommandHasTrainFlag(t *testing.T) { + cmd := NewLogsCommand() + flag := cmd.Flags().Lookup("train") + require.NotNil(t, flag, "logs command should have --train flag") + assert.Equal(t, "bool", flag.Value.Type(), "--train flag should be a bool") + assert.Equal(t, "false", flag.DefValue, "--train flag should default to false") +} diff --git a/pkg/cli/logs_ci_scenario_test.go b/pkg/cli/logs_ci_scenario_test.go index 766c95a95d..b648b5638e 100644 --- a/pkg/cli/logs_ci_scenario_test.go +++ b/pkg/cli/logs_ci_scenario_test.go @@ -52,6 +52,7 @@ func TestLogsJSONOutputWithNoRuns(t *testing.T) { "summary.json", // summaryFile "", // safeOutputType false, // filteredIntegrity + false, // train ) // 
Restore stdout and read output diff --git a/pkg/cli/logs_command.go b/pkg/cli/logs_command.go index e8564953d6..fa5558e35d 100644 --- a/pkg/cli/logs_command.go +++ b/pkg/cli/logs_command.go @@ -140,6 +140,7 @@ Examples: summaryFile, _ := cmd.Flags().GetString("summary-file") safeOutputType, _ := cmd.Flags().GetString("safe-output") filteredIntegrity, _ := cmd.Flags().GetBool("filtered-integrity") + train, _ := cmd.Flags().GetBool("train") // Resolve relative dates to absolute dates for GitHub CLI now := time.Now() @@ -172,9 +173,9 @@ Examples: } } - logsCommandLog.Printf("Executing logs download: workflow=%s, count=%d, engine=%s", workflowName, count, engine) + logsCommandLog.Printf("Executing logs download: workflow=%s, count=%d, engine=%s, train=%v", workflowName, count, engine, train) - return DownloadWorkflowLogs(cmd.Context(), workflowName, count, startDate, endDate, outputDir, engine, ref, beforeRunID, afterRunID, repoOverride, verbose, toolGraph, noStaged, firewallOnly, noFirewall, parse, jsonOutput, timeout, summaryFile, safeOutputType, filteredIntegrity) + return DownloadWorkflowLogs(cmd.Context(), workflowName, count, startDate, endDate, outputDir, engine, ref, beforeRunID, afterRunID, repoOverride, verbose, toolGraph, noStaged, firewallOnly, noFirewall, parse, jsonOutput, timeout, summaryFile, safeOutputType, filteredIntegrity, train) }, } @@ -198,6 +199,7 @@ Examples: addJSONFlag(logsCmd) logsCmd.Flags().Int("timeout", 0, "Download timeout in minutes (0 = no timeout)") logsCmd.Flags().String("summary-file", "summary.json", "Path to write the summary JSON file relative to output directory (use empty string to disable)") + logsCmd.Flags().Bool("train", false, "Train drain3 log template weights from downloaded runs and write drain3_weights.json to the output directory") logsCmd.MarkFlagsMutuallyExclusive("firewall", "no-firewall") // Register completions for logs command diff --git a/pkg/cli/logs_download_test.go b/pkg/cli/logs_download_test.go index 
117f0870ab..a7e67407b1 100644 --- a/pkg/cli/logs_download_test.go +++ b/pkg/cli/logs_download_test.go @@ -21,7 +21,7 @@ func TestDownloadWorkflowLogs(t *testing.T) { // Test the DownloadWorkflowLogs function // This should either fail with auth error (if not authenticated) // or succeed with no results (if authenticated but no workflows match) - err := DownloadWorkflowLogs(context.Background(), "", 1, "", "", "./test-logs", "", "", 0, 0, "", false, false, false, false, false, false, false, 0, "summary.json", "", false) + err := DownloadWorkflowLogs(context.Background(), "", 1, "", "", "./test-logs", "", "", 0, 0, "", false, false, false, false, false, false, false, 0, "summary.json", "", false, false) // If GitHub CLI is authenticated, the function may succeed but find no results // If not authenticated, it should return an auth error @@ -308,7 +308,7 @@ func TestDownloadWorkflowLogsWithEngineFilter(t *testing.T) { if !tt.expectError { // For valid engines, test that the function can be called without panic // It may still fail with auth errors, which is expected - err := DownloadWorkflowLogs(context.Background(), "", 1, "", "", "./test-logs", tt.engine, "", 0, 0, "", false, false, false, false, false, false, false, 0, "summary.json", "", false) + err := DownloadWorkflowLogs(context.Background(), "", 1, "", "", "./test-logs", tt.engine, "", 0, 0, "", false, false, false, false, false, false, false, 0, "summary.json", "", false, false) // Clean up any created directories os.RemoveAll("./test-logs") diff --git a/pkg/cli/logs_json_stderr_order_test.go b/pkg/cli/logs_json_stderr_order_test.go index b3f77c9f7b..c0c7c59a2a 100644 --- a/pkg/cli/logs_json_stderr_order_test.go +++ b/pkg/cli/logs_json_stderr_order_test.go @@ -59,6 +59,7 @@ func TestLogsJSONOutputBeforeStderr(t *testing.T) { "summary.json", // summaryFile "", // safeOutputType false, // filteredIntegrity + false, // train ) // Close writers first @@ -180,6 +181,7 @@ func TestLogsJSONAndStderrRedirected(t 
*testing.T) { "summary.json", "", // safeOutputType false, // filteredIntegrity + false, // train ) // Close the writer diff --git a/pkg/cli/logs_orchestrator.go b/pkg/cli/logs_orchestrator.go index 55c9d80113..09ddd92e32 100644 --- a/pkg/cli/logs_orchestrator.go +++ b/pkg/cli/logs_orchestrator.go @@ -42,8 +42,8 @@ func getMaxConcurrentDownloads() int { } // DownloadWorkflowLogs downloads and analyzes workflow logs with metrics -func DownloadWorkflowLogs(ctx context.Context, workflowName string, count int, startDate, endDate, outputDir, engine, ref string, beforeRunID, afterRunID int64, repoOverride string, verbose bool, toolGraph bool, noStaged bool, firewallOnly bool, noFirewall bool, parse bool, jsonOutput bool, timeout int, summaryFile string, safeOutputType string, filteredIntegrity bool) error { - logsOrchestratorLog.Printf("Starting workflow log download: workflow=%s, count=%d, startDate=%s, endDate=%s, outputDir=%s, summaryFile=%s, safeOutputType=%s, filteredIntegrity=%v", workflowName, count, startDate, endDate, outputDir, summaryFile, safeOutputType, filteredIntegrity) +func DownloadWorkflowLogs(ctx context.Context, workflowName string, count int, startDate, endDate, outputDir, engine, ref string, beforeRunID, afterRunID int64, repoOverride string, verbose bool, toolGraph bool, noStaged bool, firewallOnly bool, noFirewall bool, parse bool, jsonOutput bool, timeout int, summaryFile string, safeOutputType string, filteredIntegrity bool, train bool) error { + logsOrchestratorLog.Printf("Starting workflow log download: workflow=%s, count=%d, startDate=%s, endDate=%s, outputDir=%s, summaryFile=%s, safeOutputType=%s, filteredIntegrity=%v, train=%v", workflowName, count, startDate, endDate, outputDir, summaryFile, safeOutputType, filteredIntegrity, train) // Ensure .github/aw/logs/.gitignore exists on every invocation if err := ensureLogsGitignore(); err != nil { @@ -520,6 +520,13 @@ func DownloadWorkflowLogs(ctx context.Context, workflowName string, count int, 
s } } + // Train drain3 weights if requested. + if train { + if err := TrainDrain3Weights(processedRuns, outputDir, verbose); err != nil { + return fmt.Errorf("log pattern training: %w", err) + } + } + // Render output based on format preference if jsonOutput { if err := renderLogsJSON(logsData); err != nil { diff --git a/pkg/cli/logs_report.go b/pkg/cli/logs_report.go index 2565818d41..1c36afca46 100644 --- a/pkg/cli/logs_report.go +++ b/pkg/cli/logs_report.go @@ -310,6 +310,7 @@ func buildLogsData(processedRuns []ProcessedRun, outputDir string, continuation redactedDomains := buildRedactedDomainsSummary(processedRuns) observability := buildLogsObservabilityInsights(processedRuns, toolUsage) + observability = append(observability, buildDrain3InsightsMultiRun(processedRuns)...) absOutputDir, _ := filepath.Abs(outputDir)