diff --git a/docs/src/content/docs/reference/frontmatter-full.md b/docs/src/content/docs/reference/frontmatter-full.md index aa4c5451a2..7c5e30dff8 100644 --- a/docs/src/content/docs/reference/frontmatter-full.md +++ b/docs/src/content/docs/reference/frontmatter-full.md @@ -5620,10 +5620,14 @@ safe-outputs: # Option 2: undefined - # Array of extra job steps to run after detection + # Array of extra job steps to run before engine execution # (optional) steps: [] + # Array of extra job steps to run after engine execution + # (optional) + post-steps: [] + # Runner specification for the detection job. Overrides agent.runs-on for the # detection job only. Defaults to agent.runs-on. # (optional) diff --git a/docs/src/content/docs/reference/threat-detection.md b/docs/src/content/docs/reference/threat-detection.md index f1a43b9813..ca7fab40fb 100644 --- a/docs/src/content/docs/reference/threat-detection.md +++ b/docs/src/content/docs/reference/threat-detection.md @@ -77,7 +77,10 @@ safe-outputs: threat-detection: enabled: true # Enable/disable detection prompt: "Focus on SQL injection" # Additional analysis instructions - steps: # Custom detection steps + steps: # Custom steps run before engine execution + - name: Setup Security Gateway + run: echo "Connecting to security gateway..." 
+ post-steps: # Custom steps run after engine execution - name: Custom Security Check run: echo "Running additional checks" ``` @@ -90,7 +93,8 @@ safe-outputs: | `prompt` | string | Custom instructions appended to default detection prompt | | `engine` | string/object/false | AI engine config (`"copilot"`, full config object, or `false` for no AI) | | `runs-on` | string/array/object | Runner for the detection job (default: inherits from workflow `runs-on`) | -| `steps` | array | Additional GitHub Actions steps to run after AI analysis | +| `steps` | array | Additional GitHub Actions steps to run **before** AI analysis (pre-steps) | +| `post-steps` | array | Additional GitHub Actions steps to run **after** AI analysis (post-steps) | ## AI-Based Detection (Default) @@ -186,13 +190,32 @@ safe-outputs: ## Custom Detection Steps -Add specialized security scanning tools alongside or instead of AI detection: +Add specialized security scanning tools alongside or instead of AI detection. You can run steps **before** the AI engine (for setup, gateway connections, etc.) and steps **after** (for additional scanning based on AI results). + +### Pre-Steps (`steps:`) + +Steps defined under `steps:` run **before** the AI engine executes. Use these for setup tasks such as connecting to a private AI gateway, installing security tools, or preparing artifacts. ```yaml wrap safe-outputs: create-pull-request: threat-detection: steps: + - name: Connect to Security Gateway + run: | + echo "Setting up secure connection to analysis gateway..." + # Authentication and connection setup +``` + +### Post-Steps (`post-steps:`) + +Steps defined under `post-steps:` run **after** the AI engine completes its analysis. Use these for additional security scanning, reporting, or cleanup. + +```yaml wrap +safe-outputs: + create-pull-request: + threat-detection: + post-steps: - name: Run Security Scanner run: | echo "Scanning agent output for threats..." 
@@ -206,11 +229,11 @@ safe-outputs: **Available Artifacts:** Custom steps have access to `/tmp/gh-aw/threat-detection/prompt.txt` (workflow prompt), `agent_output.json` (safe output items), and `aw.patch` (git patch file). -**Execution Order:** Download artifacts → Run AI analysis (if enabled) → Execute custom steps → Upload detection log. +**Execution Order:** Download artifacts → Execute pre-steps (`steps:`) → Run AI analysis (if enabled) → Execute post-steps (`post-steps:`) → Upload detection log. ## Example: LlamaGuard Integration -Use Ollama with LlamaGuard 3 for specialized threat detection: +Use Ollama with LlamaGuard 3 as a post-step for specialized threat detection that runs after AI analysis: ```yaml wrap --- @@ -219,7 +242,7 @@ engine: copilot safe-outputs: create-pull-request: threat-detection: - steps: + post-steps: - name: Ollama LlamaGuard 3 Scan uses: actions/github-script@v8 with: @@ -261,7 +284,7 @@ safe-outputs: threat-detection: prompt: "Check for authentication bypass vulnerabilities" engine: copilot - steps: + post-steps: - name: Static Analysis run: | # Run static analysis tool @@ -273,6 +296,24 @@ safe-outputs: path: /tmp/gh-aw/threat-detection/aw.patch ``` +## Example: Private AI Gateway + +Connect to a private AI gateway before running the detection engine: + +```yaml wrap +safe-outputs: + create-pull-request: + threat-detection: + steps: + - name: Connect to AI Gateway + run: | + # Authenticate and set up connection to private AI gateway + echo "Setting up gateway connection..." 
+ ./scripts/setup-gateway.sh + engine: + id: copilot +``` + ## Error Handling **When Threats Are Detected:** diff --git a/pkg/parser/schemas/main_workflow_schema.json b/pkg/parser/schemas/main_workflow_schema.json index 17af615ce6..09794f5d5d 100644 --- a/pkg/parser/schemas/main_workflow_schema.json +++ b/pkg/parser/schemas/main_workflow_schema.json @@ -7756,7 +7756,14 @@ }, "steps": { "type": "array", - "description": "Array of extra job steps to run after detection", + "description": "Array of extra job steps to run before engine execution", + "items": { + "$ref": "#/$defs/githubActionsStep" + } + }, + "post-steps": { + "type": "array", + "description": "Array of extra job steps to run after engine execution", "items": { "$ref": "#/$defs/githubActionsStep" } diff --git a/pkg/workflow/threat_detection.go b/pkg/workflow/threat_detection.go index cf8136239f..a83795ba2d 100644 --- a/pkg/workflow/threat_detection.go +++ b/pkg/workflow/threat_detection.go @@ -3,6 +3,7 @@ package workflow import ( "encoding/json" "fmt" + "maps" "strings" "github.com/github/gh-aw/pkg/constants" @@ -14,7 +15,8 @@ var threatLog = logger.New("workflow:threat_detection") // ThreatDetectionConfig holds configuration for threat detection in agent output type ThreatDetectionConfig struct { Prompt string `yaml:"prompt,omitempty"` // Additional custom prompt instructions to append - Steps []any `yaml:"steps,omitempty"` // Array of extra job steps + Steps []any `yaml:"steps,omitempty"` // Array of extra job steps to run before engine execution + PostSteps []any `yaml:"post-steps,omitempty"` // Array of extra job steps to run after engine execution EngineConfig *EngineConfig `yaml:"engine-config,omitempty"` // Extended engine configuration for threat detection EngineDisabled bool `yaml:"-"` // Internal flag: true when engine is explicitly set to false RunsOn string `yaml:"runs-on,omitempty"` // Runner override for the detection job @@ -24,7 +26,7 @@ type ThreatDetectionConfig struct { // that 
actually executes. Returns false when the engine is disabled and no // custom steps are configured, since the job would have nothing to run. func (td *ThreatDetectionConfig) HasRunnableDetection() bool { - return !td.EngineDisabled || len(td.Steps) > 0 + return !td.EngineDisabled || len(td.Steps) > 0 || len(td.PostSteps) > 0 } // IsDetectionJobEnabled reports whether a detection job should be created for @@ -108,13 +110,20 @@ func (c *Compiler) parseThreatDetectionConfig(outputMap map[string]any) *ThreatD } } - // Parse steps field + // Parse steps field (pre-execution steps, run before engine execution) if steps, exists := configMap["steps"]; exists { if stepsArray, ok := steps.([]any); ok { threatConfig.Steps = stepsArray } } + // Parse post-steps field (post-execution steps, run after engine execution) + if postSteps, exists := configMap["post-steps"]; exists { + if postStepsArray, ok := postSteps.([]any); ok { + threatConfig.PostSteps = postStepsArray + } + } + // Parse runs-on field if runOn, exists := configMap["runs-on"]; exists { if runOnStr, ok := runOn.(string); ok { @@ -144,7 +153,7 @@ func (c *Compiler) parseThreatDetectionConfig(outputMap map[string]any) *ThreatD } } - threatLog.Printf("Threat detection configured with custom prompt: %v, custom steps: %v", threatConfig.Prompt != "", len(threatConfig.Steps) > 0) + threatLog.Printf("Threat detection configured with custom prompt: %v, custom pre-steps: %v, custom post-steps: %v", threatConfig.Prompt != "", len(threatConfig.Steps) > 0, len(threatConfig.PostSteps) > 0) return threatConfig } } @@ -186,21 +195,26 @@ func (c *Compiler) buildDetectionJobSteps(data *WorkflowData) []string { // Step 3: Prepare files - copies agent output files to expected paths steps = append(steps, c.buildPrepareDetectionFilesStep()...) 
- // Step 4: Setup threat detection (github-script) + // Step 4: Custom pre-steps if configured (run before engine execution) + if len(data.SafeOutputs.ThreatDetection.Steps) > 0 { + steps = append(steps, c.buildCustomThreatDetectionSteps(data.SafeOutputs.ThreatDetection.Steps)...) + } + + // Step 5: Setup threat detection (github-script) steps = append(steps, c.buildThreatDetectionAnalysisStep(data)...) - // Step 5: Engine execution (AWF, no network) + // Step 6: Engine execution (AWF, no network) steps = append(steps, c.buildDetectionEngineExecutionStep(data)...) - // Step 6: Custom steps if configured - if len(data.SafeOutputs.ThreatDetection.Steps) > 0 { - steps = append(steps, c.buildCustomThreatDetectionSteps(data.SafeOutputs.ThreatDetection.Steps)...) + // Step 7: Custom post-steps if configured (run after engine execution) + if len(data.SafeOutputs.ThreatDetection.PostSteps) > 0 { + steps = append(steps, c.buildCustomThreatDetectionSteps(data.SafeOutputs.ThreatDetection.PostSteps)...) } - // Step 7: Upload detection-artifact + // Step 8: Upload detection-artifact steps = append(steps, c.buildUploadDetectionLogStep(data)...) - // Step 8: Parse results, log extensively, and set job conclusion (single JS step) + // Step 9: Parse results, log extensively, and set job conclusion (single JS step) steps = append(steps, c.buildDetectionConclusionStep()...) threatLog.Printf("Generated %d detection job step lines", len(steps)) @@ -554,10 +568,21 @@ await main();` } // buildCustomThreatDetectionSteps builds YAML steps from user-configured threat detection steps. +// It injects the detection guard condition into each step unless an explicit if: condition is +// already set, ensuring custom steps only run when the detection_guard determines that detection +// should proceed and preventing unexpected side effects in runs with no agent outputs to analyze. 
func (c *Compiler) buildCustomThreatDetectionSteps(steps []any) []string { var result []string for _, step := range steps { if stepMap, ok := step.(map[string]any); ok { + // Inject the detection guard condition unless the user already provided an if: condition. + if _, hasIf := stepMap["if"]; !hasIf { + // Clone the map to avoid mutating the original config. + injected := make(map[string]any, len(stepMap)+1) + maps.Copy(injected, stepMap) + injected["if"] = detectionStepCondition + stepMap = injected + } if stepYAML, err := ConvertStepToYAML(stepMap); err == nil { result = append(result, stepYAML) } diff --git a/pkg/workflow/threat_detection_test.go b/pkg/workflow/threat_detection_test.go index ad5a0b8f12..7b15305565 100644 --- a/pkg/workflow/threat_detection_test.go +++ b/pkg/workflow/threat_detection_test.go @@ -76,6 +76,27 @@ func TestParseThreatDetectionConfig(t *testing.T) { }, }, }, + { + name: "object with custom post-steps", + outputMap: map[string]any{ + "threat-detection": map[string]any{ + "post-steps": []any{ + map[string]any{ + "name": "Custom post validation", + "run": "echo 'Post validating...'", + }, + }, + }, + }, + expectedConfig: &ThreatDetectionConfig{ + PostSteps: []any{ + map[string]any{ + "name": "Custom post validation", + "run": "echo 'Post validating...'", + }, + }, + }, + }, { name: "object with custom prompt", outputMap: map[string]any{ @@ -88,15 +109,21 @@ func TestParseThreatDetectionConfig(t *testing.T) { }, }, { - name: "object with all overrides", + name: "object with all overrides including pre and post steps", outputMap: map[string]any{ "threat-detection": map[string]any{ "enabled": true, "prompt": "Check for backdoor installations.", "steps": []any{ map[string]any{ - "name": "Extra step", - "uses": "actions/custom@v1", + "name": "Pre step", + "uses": "actions/setup@v1", + }, + }, + "post-steps": []any{ + map[string]any{ + "name": "Post step", + "uses": "actions/cleanup@v1", }, }, }, @@ -105,8 +132,14 @@ func 
TestParseThreatDetectionConfig(t *testing.T) { Prompt: "Check for backdoor installations.", Steps: []any{ map[string]any{ - "name": "Extra step", - "uses": "actions/custom@v1", + "name": "Pre step", + "uses": "actions/setup@v1", + }, + }, + PostSteps: []any{ + map[string]any{ + "name": "Post step", + "uses": "actions/cleanup@v1", }, }, }, @@ -146,6 +179,10 @@ func TestParseThreatDetectionConfig(t *testing.T) { t.Errorf("Expected %d steps, got %d", len(tt.expectedConfig.Steps), len(result.Steps)) } + if len(result.PostSteps) != len(tt.expectedConfig.PostSteps) { + t.Errorf("Expected %d post-steps, got %d", len(tt.expectedConfig.PostSteps), len(result.PostSteps)) + } + if result.RunsOn != tt.expectedConfig.RunsOn { t.Errorf("Expected RunsOn %q, got %q", tt.expectedConfig.RunsOn, result.RunsOn) } @@ -334,58 +371,202 @@ func TestThreatDetectionWithEngineConfig(t *testing.T) { func TestThreatDetectionStepsOrdering(t *testing.T) { compiler := NewCompiler() - data := &WorkflowData{ - SafeOutputs: &SafeOutputsConfig{ - ThreatDetection: &ThreatDetectionConfig{ - Steps: []any{ - map[string]any{ - "name": "Custom Threat Scan", - "run": "echo 'Custom scanning...'", + t.Run("pre-steps come before engine execution", func(t *testing.T) { + data := &WorkflowData{ + SafeOutputs: &SafeOutputsConfig{ + ThreatDetection: &ThreatDetectionConfig{ + Steps: []any{ + map[string]any{ + "name": "Custom Pre Scan", + "run": "echo 'Custom pre-scanning...'", + }, }, }, }, - }, - } + } - steps := compiler.buildDetectionJobSteps(data) + steps := compiler.buildDetectionJobSteps(data) - if len(steps) == 0 { - t.Fatal("Expected non-empty steps") - } + if len(steps) == 0 { + t.Fatal("Expected non-empty steps") + } - // Join all steps into a single string for easier verification - stepsString := strings.Join(steps, "") + // Join all steps into a single string for easier verification + stepsString := strings.Join(steps, "") - // Find the positions of key steps - customStepPos := 
strings.Index(stepsString, "Custom Threat Scan") - concludeStepPos := strings.Index(stepsString, "Parse and conclude threat detection") - uploadStepPos := strings.Index(stepsString, "Upload threat detection log") + // Find the positions of key steps + preStepPos := strings.Index(stepsString, "Custom Pre Scan") + setupStepPos := strings.Index(stepsString, "Setup threat detection") + uploadStepPos := strings.Index(stepsString, "Upload threat detection log") - // Verify all steps exist - if customStepPos == -1 { - t.Error("Expected to find 'Custom Threat Scan' step") - } - if concludeStepPos == -1 { - t.Error("Expected to find 'Parse and conclude threat detection' step") - } - if uploadStepPos == -1 { - t.Error("Expected to find 'Upload threat detection log' step") - } + // Verify all steps exist + if preStepPos == -1 { + t.Error("Expected to find 'Custom Pre Scan' step") + } + if setupStepPos == -1 { + t.Error("Expected to find 'Setup threat detection' step") + } + if uploadStepPos == -1 { + t.Error("Expected to find 'Upload threat detection log' step") + } + if !strings.Contains(stepsString, "Parse and conclude threat detection") { + t.Error("Expected to find 'Parse and conclude threat detection' step") + } - // Verify ordering: custom steps should come before upload step - if customStepPos > uploadStepPos { - t.Errorf("Custom threat detection steps should come before 'Upload threat detection log' step. Got custom at position %d, upload at position %d", customStepPos, uploadStepPos) - } + // Verify ordering: pre-steps should come before setup threat detection + if preStepPos > setupStepPos { + t.Errorf("Custom pre-steps should come before 'Setup threat detection'. Got pre-step at position %d, setup at position %d", preStepPos, setupStepPos) + } - // Verify ordering: upload step should come before the conclude step - if uploadStepPos > concludeStepPos { - t.Errorf("'Upload threat detection log' step should come before 'Parse and conclude threat detection' step. 
Got upload at position %d, conclude at position %d", uploadStepPos, concludeStepPos) - } + // Verify ordering: pre-steps should come before the upload step + if preStepPos > uploadStepPos { + t.Errorf("Custom pre-steps should come before 'Upload threat detection log'. Got pre-step at position %d, upload at position %d", preStepPos, uploadStepPos) + } + }) - // Verify the expected order: custom -> upload -> conclude - if customStepPos >= uploadStepPos || uploadStepPos >= concludeStepPos { - t.Errorf("Expected step order: custom steps < upload log < conclude. Got positions: custom=%d, upload=%d, conclude=%d", customStepPos, uploadStepPos, concludeStepPos) - } + t.Run("post-steps come after engine execution and before upload", func(t *testing.T) { + data := &WorkflowData{ + SafeOutputs: &SafeOutputsConfig{ + ThreatDetection: &ThreatDetectionConfig{ + PostSteps: []any{ + map[string]any{ + "name": "Custom Post Scan", + "run": "echo 'Custom post-scanning...'", + }, + }, + }, + }, + } + + steps := compiler.buildDetectionJobSteps(data) + + if len(steps) == 0 { + t.Fatal("Expected non-empty steps") + } + + stepsString := strings.Join(steps, "") + + postStepPos := strings.Index(stepsString, "Custom Post Scan") + // Use the engine execution step ID as the stable marker for the engine step boundary + engineStepPos := strings.Index(stepsString, "id: detection_agentic_execution") + uploadStepPos := strings.Index(stepsString, "Upload threat detection log") + concludeStepPos := strings.Index(stepsString, "Parse and conclude threat detection") + + if postStepPos == -1 { + t.Error("Expected to find 'Custom Post Scan' step") + } + if engineStepPos == -1 { + t.Error("Expected to find 'id: detection_agentic_execution' engine step") + } + if uploadStepPos == -1 { + t.Error("Expected to find 'Upload threat detection log' step") + } + if concludeStepPos == -1 { + t.Error("Expected to find 'Parse and conclude threat detection' step") + } + + // Verify ordering: post-steps should come 
after the engine execution step + if postStepPos < engineStepPos { + t.Errorf("Custom post-steps should come after engine execution step. Got post-step at position %d, engine at position %d", postStepPos, engineStepPos) + } + if postStepPos > uploadStepPos { + t.Errorf("Custom post-steps should come before 'Upload threat detection log'. Got post-step at position %d, upload at position %d", postStepPos, uploadStepPos) + } + if postStepPos > concludeStepPos { + t.Errorf("Custom post-steps should come before 'Parse and conclude threat detection'. Got post-step at position %d, conclude at position %d", postStepPos, concludeStepPos) + } + }) + + t.Run("pre-steps and post-steps both present in correct order", func(t *testing.T) { + data := &WorkflowData{ + SafeOutputs: &SafeOutputsConfig{ + ThreatDetection: &ThreatDetectionConfig{ + Steps: []any{ + map[string]any{ + "name": "Custom Pre Step", + "run": "echo 'pre'", + }, + }, + PostSteps: []any{ + map[string]any{ + "name": "Custom Post Step", + "run": "echo 'post'", + }, + }, + }, + }, + } + + steps := compiler.buildDetectionJobSteps(data) + stepsString := strings.Join(steps, "") + + preStepPos := strings.Index(stepsString, "Custom Pre Step") + postStepPos := strings.Index(stepsString, "Custom Post Step") + engineStepPos := strings.Index(stepsString, "id: detection_agentic_execution") + uploadStepPos := strings.Index(stepsString, "Upload threat detection log") + + if preStepPos == -1 { + t.Error("Expected to find 'Custom Pre Step'") + } + if postStepPos == -1 { + t.Error("Expected to find 'Custom Post Step'") + } + if engineStepPos == -1 { + t.Error("Expected to find 'id: detection_agentic_execution' engine step") + } + + // pre-steps before engine, post-steps after engine but before upload + if preStepPos > engineStepPos { + t.Errorf("Pre-steps should come before engine execution step. 
Got pre=%d, engine=%d", preStepPos, engineStepPos) + } + if postStepPos < engineStepPos { + t.Errorf("Post-steps should come after engine execution step. Got post=%d, engine=%d", postStepPos, engineStepPos) + } + if postStepPos > uploadStepPos { + t.Errorf("Post-steps should come before 'Upload threat detection log'. Got post=%d, upload=%d", postStepPos, uploadStepPos) + } + // pre-steps before post-steps + if preStepPos > postStepPos { + t.Errorf("Pre-steps should come before post-steps. Got pre=%d, post=%d", preStepPos, postStepPos) + } + }) +} + +func TestCustomThreatDetectionStepsGuardCondition(t *testing.T) { + compiler := NewCompiler() + + t.Run("injects detection guard condition when no if: present", func(t *testing.T) { + steps := []any{ + map[string]any{ + "name": "No If Step", + "run": "echo hello", + }, + } + result := compiler.buildCustomThreatDetectionSteps(steps) + stepsStr := strings.Join(result, "") + if !strings.Contains(stepsStr, detectionStepCondition) { + t.Errorf("Expected detection guard condition to be injected, got:\n%s", stepsStr) + } + }) + + t.Run("preserves user-provided if: condition", func(t *testing.T) { + userCondition := "always()" + steps := []any{ + map[string]any{ + "name": "User If Step", + "if": userCondition, + "run": "echo hello", + }, + } + result := compiler.buildCustomThreatDetectionSteps(steps) + stepsStr := strings.Join(result, "") + if strings.Contains(stepsStr, detectionStepCondition) { + t.Error("Expected detection guard condition NOT to be injected when user provides if:") + } + if !strings.Contains(stepsStr, userCondition) { + t.Errorf("Expected user if: condition %q to be preserved, got:\n%s", userCondition, stepsStr) + } + }) } func TestBuildDetectionEngineExecutionStepWithThreatDetectionEngine(t *testing.T) {