diff --git a/.gitignore b/.gitignore index f60e99f..adaff6e 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,9 @@ Thumbs.db *.key credentials.json secrets.yaml + +# Generated policy files +*-full-access-policy.json +*-iam-reference.json +*-policy-documentation.md +*-scp.json diff --git a/README.md b/README.md index 028396e..fb7b00e 100644 --- a/README.md +++ b/README.md @@ -98,6 +98,98 @@ A single-binary Go tool for testing AWS IAM policies using scenario-based YAML c - Displays full statement JSON from source for failed tests - Optional --show-matched-success flag for passing tests +- **AI-powered policy generation** (NEW) + - Generate security-focused IAM policies from AWS documentation + - Scrapes actions, conditions, and resource types automatically + - Uses OpenAI-compatible LLM APIs to create compliant policies + - Produces documentation explaining each policy statement + - 24-hour caching of AWS documentation pages + +## Generate Command + +The `generate` command creates security-focused IAM policies by scraping AWS service authorization documentation and using an LLM to generate appropriate policy statements. 
+ +### Features + +- **Automatic scraping** of IAM actions, condition keys, and resource types from AWS docs +- **Parallel batch processing** for faster policy generation +- **Customizable prompts** to specify your security requirements +- **24-hour caching** of documentation pages to reduce network requests +- **Three output files**: + - `{service}-iam-reference.json` - Scraped IAM data + - `{service}-full-access-policy.json` - Generated policy + - `{service}-policy-documentation.md` - Human-readable documentation + +### Usage + +```bash +politest generate [flags] + +Flags: + --url string AWS IAM documentation URL (required) + --base-url string OpenAI-compatible API base URL (required) + --model string LLM model name (required) + --api-key string API key for LLM service + --output string Output directory (default ".") + --prompt string Custom requirements/constraints for policy generation + --concurrency int Number of parallel batch requests (default 3) + --no-enrich Skip action description enrichment + --quiet Suppress progress output +``` + +### Example + +```bash +./politest generate \ + --url "https://docs.aws.amazon.com/service-authorization/latest/reference/list_amazonbedrock.html" \ + --base-url "https://your-llm-api.example.com/api" \ + --model "claude-3-sonnet" \ + --api-key "$LLM_API_KEY" \ + --output "./output" \ + --concurrency 4 \ + --prompt "Require VPC endpoint conditions where supported. Deny access to \ +foundation models which require cross region inference. Include resource-level \ +permissions for all model invocation actions." +``` + +### Output + +The command generates three files (plus `bedrock-scp.json` when `--scp` is passed): + +1. **IAM Reference** (`bedrock-iam-reference.json`) + - All scraped actions with descriptions and access levels + - Available condition keys and their types + - Resource types and ARN patterns + +2. 
**Generated Policy** (`bedrock-full-access-policy.json`) + - Security-focused IAM policy with grouped statements + - Includes conditions like `aws:SecureTransport` and MFA requirements + - Uses placeholder variables (e.g., `${AWS::AccountId}`, `${VpcEndpointId}`) + +3. **Documentation** (`bedrock-policy-documentation.md`) + - Explanation of each policy statement + - Variables that need to be configured + - Security summary and compliance considerations + - Usage recommendations + +### Placeholder Variables + +The generated policy uses consistent placeholder variables that you must replace: + +| Variable | Description | +|----------|-------------| +| `${AWS::AccountId}` | Your AWS account ID | +| `${AWS::Region}` | Target AWS region | +| `${VpcEndpointId}` | VPC endpoint ID for endpoint conditions | +| `${VpcId}` | VPC ID | +| `${OrgId}` | AWS Organization ID | +| `${PrincipalTag/Department}` | Principal tag values | +| `${ResourceTag/Environment}` | Resource tag values | + +### Caching + +AWS documentation pages are cached locally in `~/.cache/politest/` for 24 hours to reduce network requests during iterative policy development. + ## ⚠️ Understanding What politest Tests **politest is a pre-deployment validation tool that helps you catch IAM policy issues early, but it is NOT a replacement for integration testing in real AWS environments.** diff --git a/examples.sh b/examples.sh new file mode 100644 index 0000000..6fb713d --- /dev/null +++ b/examples.sh @@ -0,0 +1,41 @@ +#!/bin/bash +# Example politest generate commands + +# Bedrock - Invoke-only policy (least privilege for model invocation) +# With --scp flag, generates both identity policy AND Service Control Policy + +# Allowed models (update this list as needed) +ALLOWED_MODELS=( + "arn:aws:bedrock:eu-west-2::foundation-model/anthropic.claude-3-7-sonnet-20250219-v1:0" +) + +MODELS_LIST=$(IFS=', '; echo "${ALLOWED_MODELS[*]}") +ALLOWED_REGION="eu-west-2" + +go run . 
generate \ + --url "https://docs.aws.amazon.com/service-authorization/latest/reference/list_amazonbedrock.html" \ + --base-url "https://api.openai.com" \ + --model "gpt-4o" \ + --api-key "$(cat ~/.ssh/openai_api_key)" \ + --scp \ + --prompt "I need a policy for a developer who can ONLY call specific Bedrock foundation models - nothing else. + +ALLOWED MODELS: ${MODELS_LIST} +ALLOWED REGION: ${ALLOWED_REGION} + +THE DEVELOPER SHOULD BE ABLE TO: +- Call the allowed models using InvokeModel and InvokeModelWithResponseStream +- Discover what models exist using Get*, Describe*, List* wildcards only (do NOT list individual Get/Describe/List actions) + +THE DEVELOPER MUST NOT BE ABLE TO: +- Call any models other than those listed above +- Use cross-region inference or global cross-region inference (CRIS) +- Create, update, delete, or manage any Bedrock resources +- Invoke agents, flows, or anything other than foundation models +- Do anything administrative + +POLICY SPLIT (SCP will be generated separately): +- Identity policy: Only ALLOW statements with actions and resources - no conditions like SecureTransport +- SCP: All security guardrails (SecureTransport, region locks, NotAction allowlist) + +For the identity policy: Keep it minimal - just actions and resources. Do NOT include SecureTransport conditions or any deny statements - those all go in the SCP. Use ONLY wildcards for read operations - do not list individual actions that are already covered by wildcards." 
diff --git a/go.mod b/go.mod index 0d7744f..5e4882e 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,13 @@ module politest -go 1.24 +go 1.24.0 + +toolchain go1.24.4 require ( github.com/aws/aws-sdk-go-v2/config v1.31.15 github.com/aws/aws-sdk-go-v2/service/iam v1.48.1 + golang.org/x/net v0.47.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -21,4 +24,9 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.3 // indirect github.com/aws/aws-sdk-go-v2/service/sts v1.38.9 // indirect github.com/aws/smithy-go v1.23.1 // indirect + github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect + github.com/rivo/uniseg v0.4.7 // indirect + github.com/schollz/progressbar/v3 v3.18.0 // indirect + golang.org/x/sys v0.38.0 // indirect + golang.org/x/term v0.37.0 // indirect ) diff --git a/go.sum b/go.sum index b062ab0..36ec190 100644 --- a/go.sum +++ b/go.sum @@ -26,6 +26,18 @@ github.com/aws/aws-sdk-go-v2/service/sts v1.38.9 h1:Ekml5vGg6sHSZLZJQJagefnVe6Pm github.com/aws/aws-sdk-go-v2/service/sts v1.38.9/go.mod h1:/e15V+o1zFHWdH3u7lpI3rVBcxszktIKuHKCY2/py+k= github.com/aws/smithy-go v1.23.1 h1:sLvcH6dfAFwGkHLZ7dGiYF7aK6mg4CgKA/iDKjLDt9M= github.com/aws/smithy-go v1.23.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db h1:62I3jR2EmQ4l5rM/4FEfDWcRD+abF5XlKShorW5LRoQ= +github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db/go.mod h1:l0dey0ia/Uv7NcFFVbCLtqEBQbrT4OCwCSKTEv6enCw= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/schollz/progressbar/v3 v3.18.0 h1:uXdoHABRFmNIjUfte/Ex7WtuyVslrw2wVPQmCN62HpA= +github.com/schollz/progressbar/v3 v3.18.0/go.mod h1:IsO3lpbaGuzh8zIMzgY3+J8l4C8GjO0Y9S69eFvNsec= +golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= +golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= 
+golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU= +golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/internal/generate.go b/internal/generate.go new file mode 100644 index 0000000..3e138a2 --- /dev/null +++ b/internal/generate.go @@ -0,0 +1,241 @@ +package internal + +import ( + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "strings" +) + +// GenerateConfig holds configuration for the generate command +type GenerateConfig struct { + URL string // AWS documentation URL + BaseURL string // LLM API base URL + APIKey string // LLM API key + Model string // LLM model name + OutputDir string // Output directory for generated files + NoEnrich bool // Skip enrichment step + Quiet bool // Suppress progress output + UserPrompt string // User's custom requirements/constraints + Concurrency int // Number of parallel batch requests (default 3) + GenerateSCP bool // Generate companion SCP for org-wide guardrails +} + +// GenerateOutput holds the output of the generate command +type GenerateOutput struct { + ScrapedData *ScrapedIAMData `json:"scraped_data"` + Policy json.RawMessage `json:"policy"` + PolicyFile string `json:"policy_file"` + SCPPolicy json.RawMessage `json:"scp_policy,omitempty"` + SCPFile string `json:"scp_file,omitempty"` + ScrapedFile string `json:"scraped_file"` + DocsFile string `json:"docs_file"` + ServiceName string `json:"service_name"` +} + +// RunGenerate executes the generate command +func RunGenerate(cfg GenerateConfig, writer io.Writer) 
(*GenerateOutput, error) { + // Create progress reporter + var progress ProgressReporter + var consoleProgress *ConsoleProgress + if !cfg.Quiet { + consoleProgress = NewConsoleProgress(writer) + consoleProgress.Start() + progress = consoleProgress + } + + // Step 1: Scrape the AWS documentation + scrapedData, err := ScrapeIAMDocumentation(cfg.URL, progress) + if err != nil { + if consoleProgress != nil { + consoleProgress.Error(fmt.Sprintf("Scraping failed: %v", err)) + } + return nil, fmt.Errorf("failed to scrape documentation: %w", err) + } + + if progress != nil { + progress.SetStatus(fmt.Sprintf("Found %d actions for %s", len(scrapedData.Actions), scrapedData.ServicePrefix)) + } + + // Step 2: Create LLM client and generate policy + llmClient := NewLLMClient(cfg.BaseURL, cfg.APIKey, cfg.Model) + + // Optional: Enrich action descriptions with security context + if !cfg.NoEnrich { + _ = llmClient.EnrichActionDescriptions(scrapedData, progress) + } + + // Set concurrency (default to 3 if not specified) + concurrency := cfg.Concurrency + if concurrency <= 0 { + concurrency = 3 + } + + // Generate the security-focused policy + policyJSON, err := llmClient.GenerateSecurityPolicy(scrapedData, progress, cfg.UserPrompt, concurrency) + if err != nil { + if consoleProgress != nil { + consoleProgress.Error(fmt.Sprintf("Policy generation failed: %v", err)) + } + return nil, fmt.Errorf("failed to generate policy: %w", err) + } + + // Step 3: Write output files + if progress != nil { + progress.SetStatus("Writing output files...") + } + + output := &GenerateOutput{ + ScrapedData: scrapedData, + ServiceName: scrapedData.ServicePrefix, + } + + // Parse policy JSON for pretty-printing + var policyData any + if err := json.Unmarshal([]byte(policyJSON), &policyData); err != nil { + return nil, fmt.Errorf("invalid policy JSON: %w", err) + } + prettyPolicy, _ := json.MarshalIndent(policyData, "", " ") + output.Policy = prettyPolicy + + // Determine output directory + outputDir := 
cfg.OutputDir + if outputDir == "" { + outputDir = "." + } + if err := os.MkdirAll(outputDir, 0750); err != nil { + return nil, fmt.Errorf("failed to create output directory: %w", err) + } + + // Write scraped data file + scrapedFileName := fmt.Sprintf("%s-iam-reference.json", scrapedData.ServicePrefix) + scrapedFilePath := filepath.Join(outputDir, scrapedFileName) + scrapedJSON, _ := json.MarshalIndent(scrapedData, "", " ") + if err := os.WriteFile(scrapedFilePath, scrapedJSON, 0600); err != nil { + return nil, fmt.Errorf("failed to write scraped data: %w", err) + } + output.ScrapedFile = scrapedFilePath + + // Write policy file (always same name) + policyFileName := fmt.Sprintf("%s-full-access-policy.json", scrapedData.ServicePrefix) + policyFilePath := filepath.Join(outputDir, policyFileName) + if err := os.WriteFile(policyFilePath, prettyPolicy, 0600); err != nil { + return nil, fmt.Errorf("failed to write policy: %w", err) + } + output.PolicyFile = policyFilePath + + // Generate SCP if requested + var prettySCP []byte + var scpFilePath string + if cfg.GenerateSCP { + if progress != nil { + progress.SetStatus("Generating Service Control Policy (SCP)...") + } + scpJSON, err := llmClient.GenerateSCP(scrapedData, string(prettyPolicy), cfg.UserPrompt) + if err != nil { + if progress != nil { + progress.SetStatus(fmt.Sprintf("SCP generation failed: %v", err)) + } + // Non-fatal - continue without SCP + } else { + var scpData any + if err := json.Unmarshal([]byte(scpJSON), &scpData); err == nil { + prettySCP, _ = json.MarshalIndent(scpData, "", " ") + output.SCPPolicy = prettySCP + + scpFileName := fmt.Sprintf("%s-scp.json", scrapedData.ServicePrefix) + scpFilePath = filepath.Join(outputDir, scpFileName) + if err := os.WriteFile(scpFilePath, prettySCP, 0600); err != nil { + return nil, fmt.Errorf("failed to write SCP: %w", err) + } + output.SCPFile = scpFilePath + } + } + } + + // Generate documentation for the policy (and SCP if generated) + if progress != nil { + 
progress.SetStatus("Generating policy documentation...") + } + var docsMarkdown string + if cfg.GenerateSCP && len(prettySCP) > 0 { + docsMarkdown, err = llmClient.GenerateCombinedDocumentation(scrapedData, string(prettyPolicy), string(prettySCP)) + } else { + docsMarkdown, err = llmClient.GeneratePolicyDocumentation(scrapedData, string(prettyPolicy)) + } + if err != nil { + // Non-fatal - continue without docs + if progress != nil { + progress.SetStatus("Documentation generation skipped (error)") + } + } else { + docsFileName := fmt.Sprintf("%s-policy-documentation.md", scrapedData.ServicePrefix) + docsFilePath := filepath.Join(outputDir, docsFileName) + if err := os.WriteFile(docsFilePath, []byte(docsMarkdown), 0600); err != nil { + return nil, fmt.Errorf("failed to write documentation: %w", err) + } + output.DocsFile = docsFilePath + } + + // Complete + if consoleProgress != nil { + consoleProgress.Done(fmt.Sprintf("Generated policy for %s", scrapedData.ServiceName)) + } + + // Print summary + if !cfg.Quiet { + fmt.Fprintf(writer, "\n\033[1mGeneration Complete\033[0m\n") + fmt.Fprintf(writer, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n") + fmt.Fprintf(writer, "Service: %s\n", scrapedData.ServiceName) + fmt.Fprintf(writer, "Service Prefix: %s\n", scrapedData.ServicePrefix) + fmt.Fprintf(writer, "Actions Found: %d\n", len(scrapedData.Actions)) + fmt.Fprintf(writer, "Condition Keys: %d\n", len(scrapedData.ConditionKeys)) + fmt.Fprintf(writer, "Resource Types: %d\n", len(scrapedData.ResourceTypes)) + fmt.Fprintf(writer, "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\n") + fmt.Fprintf(writer, "\033[32mOutput Files:\033[0m\n") + fmt.Fprintf(writer, " IAM Reference: %s\n", scrapedFilePath) + fmt.Fprintf(writer, " Policy: %s\n", policyFilePath) + if cfg.GenerateSCP && scpFilePath != "" { + fmt.Fprintf(writer, " SCP: %s\n", scpFilePath) + } + if output.DocsFile != "" { + fmt.Fprintf(writer, " Documentation: %s\n", output.DocsFile) + } + fmt.Fprintf(writer, 
"\n") + + // Show action summary by access level + accessLevels := make(map[string]int) + for _, action := range scrapedData.Actions { + level := action.AccessLevel + if level == "" { + level = "Unknown" + } + accessLevels[level]++ + } + fmt.Fprintf(writer, "\033[1mActions by Access Level:\033[0m\n") + for level, count := range accessLevels { + fmt.Fprintf(writer, " %-20s %d\n", level+":", count) + } + } + + return output, nil +} + +// ValidateGenerateConfig validates the generate configuration +func ValidateGenerateConfig(cfg GenerateConfig) error { + if cfg.URL == "" { + return fmt.Errorf("--url is required: AWS documentation URL") + } + if !strings.Contains(cfg.URL, "docs.aws.amazon.com/service-authorization") { + return fmt.Errorf("invalid URL: must be an AWS service authorization reference page\nExample: https://docs.aws.amazon.com/service-authorization/latest/reference/list_amazonbedrock.html") + } + if cfg.BaseURL == "" { + return fmt.Errorf("--base-url is required: OpenAI-compatible API base URL") + } + if cfg.Model == "" { + return fmt.Errorf("--model is required: LLM model name") + } + return nil +} diff --git a/internal/generate_test.go b/internal/generate_test.go new file mode 100644 index 0000000..a9b9c8a --- /dev/null +++ b/internal/generate_test.go @@ -0,0 +1,159 @@ +package internal + +import ( + "bytes" + "encoding/json" + "strings" + "testing" +) + +func TestValidateGenerateConfig(t *testing.T) { + tests := []struct { + name string + cfg GenerateConfig + wantErr bool + errMsg string + }{ + { + name: "valid config", + cfg: GenerateConfig{ + URL: "https://docs.aws.amazon.com/service-authorization/latest/reference/list_amazons3.html", + BaseURL: "https://api.openai.com", + Model: "gpt-4", + }, + wantErr: false, + }, + { + name: "missing URL", + cfg: GenerateConfig{ + BaseURL: "https://api.openai.com", + Model: "gpt-4", + }, + wantErr: true, + errMsg: "--url is required", + }, + { + name: "invalid URL (not AWS)", + cfg: GenerateConfig{ + URL: 
"https://example.com/not-aws", + BaseURL: "https://api.openai.com", + Model: "gpt-4", + }, + wantErr: true, + errMsg: "invalid URL", + }, + { + name: "missing BaseURL", + cfg: GenerateConfig{ + URL: "https://docs.aws.amazon.com/service-authorization/latest/reference/list_amazons3.html", + Model: "gpt-4", + }, + wantErr: true, + errMsg: "--base-url is required", + }, + { + name: "missing Model", + cfg: GenerateConfig{ + URL: "https://docs.aws.amazon.com/service-authorization/latest/reference/list_amazons3.html", + BaseURL: "https://api.openai.com", + }, + wantErr: true, + errMsg: "--model is required", + }, + { + name: "all fields empty", + cfg: GenerateConfig{}, + wantErr: true, + errMsg: "--url is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateGenerateConfig(tt.cfg) + if tt.wantErr { + if err == nil { + t.Error("ValidateGenerateConfig() expected error, got nil") + return + } + if tt.errMsg != "" && !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("ValidateGenerateConfig() error = %v, want error containing %q", err, tt.errMsg) + } + } else { + if err != nil { + t.Errorf("ValidateGenerateConfig() unexpected error: %v", err) + } + } + }) + } +} + +func TestGenerateConfigDefaults(t *testing.T) { + cfg := GenerateConfig{ + URL: "https://docs.aws.amazon.com/service-authorization/latest/reference/list_amazons3.html", + BaseURL: "https://api.openai.com", + Model: "gpt-4", + } + + // Test default values + if cfg.Concurrency != 0 { + t.Errorf("Default Concurrency = %d, want 0", cfg.Concurrency) + } + if cfg.NoEnrich != false { + t.Error("Default NoEnrich should be false") + } + if cfg.Quiet != false { + t.Error("Default Quiet should be false") + } + if cfg.GenerateSCP != false { + t.Error("Default GenerateSCP should be false") + } +} + +func TestGenerateOutput(t *testing.T) { + output := GenerateOutput{ + ScrapedData: &ScrapedIAMData{ + ServiceName: "Amazon S3", + ServicePrefix: "s3", + }, + Policy: 
json.RawMessage(`{"Version": "2012-10-17"}`), + PolicyFile: "/tmp/policy.json", + ScrapedFile: "/tmp/scraped.json", + DocsFile: "/tmp/docs.md", + ServiceName: "s3", + } + + if output.ServiceName != "s3" { + t.Errorf("ServiceName = %q, want s3", output.ServiceName) + } + if output.ScrapedData == nil { + t.Error("ScrapedData should not be nil") + } + if output.ScrapedData.ServicePrefix != "s3" { + t.Errorf("ScrapedData.ServicePrefix = %q, want s3", output.ScrapedData.ServicePrefix) + } +} + +func TestRunGenerateInvalidURLValidation(t *testing.T) { + // RunGenerate should fail with invalid URL due to ScrapeIAMDocumentation validation + var buf bytes.Buffer + cfg := GenerateConfig{ + URL: "https://example.com/not-aws", + BaseURL: "http://unused", + Model: "test-model", + OutputDir: t.TempDir(), + Quiet: true, + } + + _, err := RunGenerate(cfg, &buf) + if err == nil { + t.Error("RunGenerate() expected error for invalid URL, got nil") + } + if err != nil && !strings.Contains(err.Error(), "invalid URL") { + t.Errorf("RunGenerate() error = %v, want error about invalid URL", err) + } +} + +// Note: Integration tests for RunGenerate require mocking the entire HTTP layer +// including the AWS documentation URL validation in ScrapeIAMDocumentation. +// For comprehensive testing, use integration tests with real AWS documentation pages. 
diff --git a/internal/llm.go b/internal/llm.go new file mode 100644 index 0000000..aaa2bc6 --- /dev/null +++ b/internal/llm.go @@ -0,0 +1,1078 @@ +package internal + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "sort" + "strings" + "sync" + "time" +) + +// LLMClient provides an interface to OpenAI-compatible LLM APIs +type LLMClient struct { + BaseURL string + APIKey string + Model string + Client *http.Client +} + +// ChatMessage represents a message in the chat completion format +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// ChatCompletionRequest represents the request body for chat completions +type ChatCompletionRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Stream bool `json:"stream"` +} + +// ChatCompletionResponse represents the response from chat completions +type ChatCompletionResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []struct { + Index int `json:"index"` + Message struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` + Error *struct { + Message string `json:"message"` + Type string `json:"type"` + Code string `json:"code"` + } `json:"error,omitempty"` +} + +// NewLLMClient creates a new LLM client +func NewLLMClient(baseURL, apiKey, model string) *LLMClient { + // Ensure base URL doesn't have trailing slash + baseURL = strings.TrimSuffix(baseURL, "/") + + return &LLMClient{ + BaseURL: baseURL, + APIKey: apiKey, + Model: model, + Client: &http.Client{ + Timeout: 5 * 
time.Minute, // LLM calls can take a while for large prompts + }, + } +} + +// batchResult holds the result of processing a single batch +type batchResult struct { + index int + statements []json.RawMessage + err error +} + +// GenerateSecurityPolicy generates a security-focused IAM policy using the LLM +// It processes actions in batches with parallel execution for speed +func (c *LLMClient) GenerateSecurityPolicy(data *ScrapedIAMData, progress ProgressReporter, userPrompt string, concurrency int) (string, error) { + const batchSize = 30 // Process 30 actions at a time + + if concurrency <= 0 { + concurrency = 3 + } + + // Group actions by access level for logical batching + actionsByLevel := make(map[string][]IAMAction) + for _, action := range data.Actions { + level := action.AccessLevel + if level == "" { + level = "Unknown" + } + actionsByLevel[level] = append(actionsByLevel[level], action) + } + + // Create batches + var batches [][]IAMAction + accessLevelOrder := []string{"Read", "List", "Write", "Permissions management", "Tagging", "Unknown"} + var currentBatch []IAMAction + + for _, level := range accessLevelOrder { + actions := actionsByLevel[level] + for _, action := range actions { + currentBatch = append(currentBatch, action) + if len(currentBatch) >= batchSize { + batches = append(batches, currentBatch) + currentBatch = nil + } + } + } + if len(currentBatch) > 0 { + batches = append(batches, currentBatch) + } + + if progress != nil { + progress.SetStatus(fmt.Sprintf("Generating policy in %d batches (concurrency: %d)...", len(batches), concurrency)) + } + + // Process batches in parallel with limited concurrency + results := make(chan batchResult, len(batches)) + sem := make(chan struct{}, concurrency) + var wg sync.WaitGroup + + var completedMu sync.Mutex + completed := 0 + + for i, batch := range batches { + wg.Add(1) + go func(idx int, b []IAMAction) { + defer wg.Done() + + // Acquire semaphore + sem <- struct{}{} + defer func() { <-sem }() + + 
prompt := buildBatchPolicyPrompt(data, b, idx+1, len(batches), userPrompt) + + // Note: Some APIs (like OpenWebUI) don't support system messages properly + // So we prepend the system instructions to the user message + fullPrompt := getBatchSystemPrompt(userPrompt) + "\n\n---\n\n" + prompt + + response, err := c.ChatCompletion([]ChatMessage{ + { + Role: "user", + Content: fullPrompt, + }, + }) + + if err != nil { + results <- batchResult{index: idx, err: fmt.Errorf("batch %d: %w", idx+1, err)} + return + } + + // Extract statements from response + statements := extractStatementsFromResponse(response) + + // Update progress + completedMu.Lock() + completed++ + if progress != nil { + progress.SetStatus(fmt.Sprintf("Completed %d of %d batches...", completed, len(batches))) + progress.SetProgress(completed, len(batches)) + } + completedMu.Unlock() + + results <- batchResult{index: idx, statements: statements} + }(i, batch) + } + + // Wait for all batches to complete + go func() { + wg.Wait() + close(results) + }() + + // Collect results in order + batchResults := make([]batchResult, len(batches)) + for result := range results { + if result.err != nil { + return "", result.err + } + batchResults[result.index] = result + } + + // Sort by index and assemble statements + sort.Slice(batchResults, func(i, j int) bool { + return batchResults[i].index < batchResults[j].index + }) + + var allStatements []json.RawMessage + for _, br := range batchResults { + allStatements = append(allStatements, br.statements...) 
+ } + + if progress != nil { + progress.SetProgress(len(batches), len(batches)) + progress.SetStatus("Consolidating policy statements...") + } + + // Consolidate: dedupe exact matches, then LLM-merge similar statements + consolidated, err := c.ConsolidatePolicy(allStatements, progress, concurrency) + if err != nil { + // Non-fatal: fall back to unconsolidated statements + consolidated = allStatements + } + + if progress != nil { + progress.SetStatus("Assembling final policy...") + } + + // Assemble final policy + policy := assembleFinalPolicy(consolidated) + return policy, nil +} + +func getBatchSystemPrompt(userPrompt string) string { + basePrompt := `You are an AWS IAM security expert generating IAM policy statements. + +CRITICAL: Output ONLY a valid JSON array of Statement objects - no markdown, no explanations, no code fences. +If no statements should be generated for this batch, output an empty array: []` + + // User requirements are PRIMARY - they override default behavior + if userPrompt != "" { + basePrompt += ` + +## USER REQUIREMENTS (PRIMARY - follow these first): +` + userPrompt + ` + +IMPORTANT: The user requirements above take precedence. If the user specifies: +- Which actions to include → ONLY generate statements for those actions, skip everything else +- Which actions to exclude → Do NOT generate statements for those actions +- Specific resources or ARNs → Use those exact resources, not wildcards +- To rely on implicit deny → Do NOT generate Allow statements for actions the user wants denied +- SCP or "SCP will handle" → Do NOT include SecureTransport conditions or deny statements - those go in the SCP + +If none of the actions in this batch match what the user wants, return an empty array [].` + } + + basePrompt += ` + +## POLICY GENERATION GUIDELINES: + +Efficiently group actions to minimize statement count: +1. Combine actions with the same access level and resource requirements into ONE statement +2. 
Use wildcards (e.g., "service:Get*", "service:List*", "service:Describe*") where actions share common prefixes +3. Only create separate statements when: + - Different resource types are required + - Different conditions are needed + - Logical security boundaries exist (read vs write vs admin) + +Statement requirements: +- Use descriptive Sid values (e.g., "AllowS3ReadOperations", "AllowBedrockModelInvocation") +- Apply aws:SecureTransport condition for network security +- Use MFA conditions for destructive/sensitive operations (Delete*, Update*, Put*) unless user says otherwise +- Use specific resource ARNs where the resource type is clear + +Use these placeholder variables where appropriate: +- ${AWS::AccountId} - AWS account ID +- ${AWS::Region} - AWS region +- ${VpcEndpointId} - VPC endpoint ID +- ${OrgId} - AWS Organization ID + +The policy should be suitable for regulated environments.` + + return basePrompt +} + +func buildBatchPolicyPrompt(data *ScrapedIAMData, batch []IAMAction, batchNum, totalBatches int, userPrompt string) string { + var sb strings.Builder + + sb.WriteString(fmt.Sprintf("Generate IAM policy statements for AWS service: %s (prefix: %s)\n", data.ServiceName, data.ServicePrefix)) + sb.WriteString(fmt.Sprintf("Batch %d of %d\n\n", batchNum, totalBatches)) + + // Group actions by access level to help LLM understand grouping + actionsByLevel := make(map[string][]IAMAction) + for _, action := range batch { + level := action.AccessLevel + if level == "" { + level = "Unknown" + } + actionsByLevel[level] = append(actionsByLevel[level], action) + } + + sb.WriteString("## Actions to include (grouped by access level):\n") + for level, actions := range actionsByLevel { + sb.WriteString(fmt.Sprintf("\n### %s Actions (%d):\n", level, len(actions))) + for _, action := range actions { + sb.WriteString(fmt.Sprintf("- %s:%s: %s\n", data.ServicePrefix, action.Name, action.Description)) + } + } + + if len(data.ConditionKeys) > 0 { + sb.WriteString("\n## 
// extractStatementsFromResponse pulls IAM policy statements out of a raw LLM
// reply.
//
// Accepted shapes, tried in this order:
//   - a JSON array of statement objects (the format the prompt asks for),
//   - a policy document whose "Statement" member holds the statements,
//   - a single statement object (recognised by its "Effect" member),
//   - any of the above wrapped in a markdown code fence or surrounded by prose.
//
// Every returned element is a JSON object. An array whose elements are not
// objects is rejected, so a lone statement such as
// {"Action":["s3:GetObject"],...} is never mistaken for a list made of its
// action strings. It returns nil when nothing usable is found.
func extractStatementsFromResponse(response string) []json.RawMessage {
	response = trimStatementFence(response)

	if stmts, ok := parseStatementPayload(response); ok {
		return stmts
	}

	// The model added prose around the JSON. Try the widest [...] span first
	// (the requested format), then the widest {...} span.
	for _, pair := range [][2]string{{"[", "]"}, {"{", "}"}} {
		start := strings.Index(response, pair[0])
		end := strings.LastIndex(response, pair[1])
		if start < 0 || end <= start {
			continue
		}
		if stmts, ok := parseStatementPayload(response[start : end+1]); ok {
			return stmts
		}
	}

	return nil
}

// trimStatementFence trims surrounding whitespace and, when the text starts
// with a markdown code fence ("```json" or "```"), removes that fence and a
// trailing "```".
func trimStatementFence(s string) string {
	s = strings.TrimSpace(s)
	switch {
	case strings.HasPrefix(s, "```json"):
		s = strings.TrimPrefix(s, "```json")
	case strings.HasPrefix(s, "```"):
		s = strings.TrimPrefix(s, "```")
	default:
		return s
	}
	return strings.TrimSpace(strings.TrimSuffix(s, "```"))
}

// parseStatementPayload decodes payload as a statement array, a policy
// document with a "Statement" member, or a single statement object. The
// boolean reports whether payload had one of those shapes.
func parseStatementPayload(payload string) ([]json.RawMessage, bool) {
	var list []json.RawMessage
	if err := json.Unmarshal([]byte(payload), &list); err == nil {
		for _, stmt := range list {
			if !strings.HasPrefix(strings.TrimSpace(string(stmt)), "{") {
				return nil, false
			}
		}
		return list, true
	}

	var obj map[string]json.RawMessage
	if err := json.Unmarshal([]byte(payload), &obj); err != nil {
		return nil, false
	}
	if inner, ok := obj["Statement"]; ok {
		// Policy document: "Statement" is either an array or a single object.
		return parseStatementPayload(string(inner))
	}
	if _, ok := obj["Effect"]; ok {
		return []json.RawMessage{json.RawMessage(payload)}, true
	}
	return nil, false
}

// assembleFinalPolicy wraps statements in an IAM policy document and returns
// it as indented JSON, with "Version" ("2012-10-17") emitted before
// "Statement". A nil slice is rendered as an empty "Statement" array rather
// than null. It returns "" when marshaling fails, which happens when a
// statement is not valid JSON.
func assembleFinalPolicy(statements []json.RawMessage) string {
	if statements == nil {
		statements = []json.RawMessage{}
	}

	policy := struct {
		Version   string            `json:"Version"`
		Statement []json.RawMessage `json:"Statement"`
	}{
		Version:   "2012-10-17",
		Statement: statements,
	}

	result, err := json.MarshalIndent(policy, "", " ")
	if err != nil {
		return ""
	}
	return string(result)
}

// dedupeStatements removes exact duplicate statements, keeping the first
// occurrence of each and preserving input order. Statements are compared after
// a decode/re-encode round trip, so differences in whitespace and object key
// order do not matter; array element order still does. Statements that cannot
// be decoded as a JSON object are always kept.
func dedupeStatements(statements []json.RawMessage) []json.RawMessage {
	seen := make(map[string]struct{}, len(statements))
	var result []json.RawMessage

	for _, stmt := range statements {
		var parsed map[string]any
		if err := json.Unmarshal(stmt, &parsed); err != nil {
			// Keep unparseable statements as-is.
			result = append(result, stmt)
			continue
		}

		// encoding/json writes map keys in sorted order, which gives a
		// canonical form to compare on.
		normalized, err := json.Marshal(parsed)
		if err != nil {
			result = append(result, stmt)
			continue
		}

		key := string(normalized)
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		result = append(result, stmt)
	}

	return result
}

// groupStatementsBySid buckets statements by their "Sid" value. Statements
// that fail to decode go to the "_unparseable" bucket; statements whose Sid is
// missing, empty, or not a string go to "_no_sid". Order within a bucket
// follows input order.
func groupStatementsBySid(statements []json.RawMessage) map[string][]json.RawMessage {
	groups := make(map[string][]json.RawMessage)

	for _, stmt := range statements {
		key := "_no_sid"

		var parsed map[string]any
		if err := json.Unmarshal(stmt, &parsed); err != nil {
			key = "_unparseable"
		} else if sid, ok := parsed["Sid"].(string); ok && sid != "" {
			key = sid
		}

		groups[key] = append(groups[key], stmt)
	}

	return groups
}
Output ONLY the single merged JSON statement object - no markdown, no explanation + +Merged statement:`, len(statements), sid, strings.Join(stmtStrings, "\n")) + + response, err := c.ChatCompletion([]ChatMessage{ + {Role: "user", Content: prompt}, + }) + if err != nil { + return nil, err + } + + // Clean up response + response = strings.TrimSpace(response) + if strings.HasPrefix(response, "```json") { + response = strings.TrimPrefix(response, "```json") + response = strings.TrimSuffix(response, "```") + response = strings.TrimSpace(response) + } else if strings.HasPrefix(response, "```") { + response = strings.TrimPrefix(response, "```") + response = strings.TrimSuffix(response, "```") + response = strings.TrimSpace(response) + } + + // Find JSON object + start := strings.Index(response, "{") + end := strings.LastIndex(response, "}") + if start >= 0 && end > start { + response = response[start : end+1] + } + + // Validate it's valid JSON + var validated json.RawMessage + if err := json.Unmarshal([]byte(response), &validated); err != nil { + return nil, fmt.Errorf("invalid JSON from LLM: %w", err) + } + + return validated, nil +} + +// ConsolidatePolicy deduplicates and consolidates policy statements +// Step 1: Remove exact duplicates (programmatic) +// Step 2: Group by Sid +// Step 3: Use LLM to merge groups with multiple statements +func (c *LLMClient) ConsolidatePolicy(statements []json.RawMessage, progress ProgressReporter, concurrency int) ([]json.RawMessage, error) { + if progress != nil { + progress.SetStatus("Consolidating policy statements...") + } + + // Step 1: Exact dedup + deduped := dedupeStatements(statements) + if progress != nil { + progress.SetStatus(fmt.Sprintf("Removed %d exact duplicates (%d → %d statements)", + len(statements)-len(deduped), len(statements), len(deduped))) + } + + // Step 2: Group by Sid + groups := groupStatementsBySid(deduped) + + // Count groups needing consolidation + var groupsToMerge []string + for sid, stmts := range 
groups { + if len(stmts) > 1 && sid != "_unparseable" && sid != "_no_sid" { + groupsToMerge = append(groupsToMerge, sid) + } + } + + if len(groupsToMerge) == 0 { + // No merging needed, return deduped statements + var result []json.RawMessage + for _, stmts := range groups { + result = append(result, stmts...) + } + return result, nil + } + + if progress != nil { + progress.SetStatus(fmt.Sprintf("Merging %d statement groups with LLM...", len(groupsToMerge))) + } + + // Step 3: Merge groups in parallel + if concurrency <= 0 { + concurrency = 3 + } + + type mergeResult struct { + sid string + stmt json.RawMessage + err error + } + + results := make(chan mergeResult, len(groupsToMerge)) + sem := make(chan struct{}, concurrency) + var wg sync.WaitGroup + + for _, sid := range groupsToMerge { + wg.Add(1) + go func(s string) { + defer wg.Done() + sem <- struct{}{} + defer func() { <-sem }() + + merged, err := c.ConsolidateStatementGroup(s, groups[s]) + results <- mergeResult{sid: s, stmt: merged, err: err} + }(sid) + } + + go func() { + wg.Wait() + close(results) + }() + + // Collect results + mergedGroups := make(map[string]json.RawMessage) + for result := range results { + if result.err != nil { + // On error, keep original statements for this group + continue + } + mergedGroups[result.sid] = result.stmt + } + + // Assemble final statement list + var finalStatements []json.RawMessage + for sid, stmts := range groups { + if merged, ok := mergedGroups[sid]; ok { + // Use merged statement + finalStatements = append(finalStatements, merged) + } else { + // Use original statements (single stmt groups, or merge failed) + finalStatements = append(finalStatements, stmts...) 
+ } + } + + if progress != nil { + progress.SetStatus(fmt.Sprintf("Consolidated to %d statements", len(finalStatements))) + } + + return finalStatements, nil +} + +// EnrichActionDescriptions uses the LLM to add security context to actions +func (c *LLMClient) EnrichActionDescriptions(data *ScrapedIAMData, progress ProgressReporter) error { + if progress != nil { + progress.SetStatus("Enriching action descriptions with security context...") + } + + // Build a prompt asking for security considerations for each action + prompt := buildEnrichmentPrompt(data) + + // Merge system instructions into user message (some APIs don't support system role) + fullPrompt := "You are an AWS security expert. Provide concise security considerations for IAM actions. Respond in JSON format.\n\n---\n\n" + prompt + + response, err := c.ChatCompletion([]ChatMessage{ + { + Role: "user", + Content: fullPrompt, + }, + }) + if err != nil { + // Non-fatal - we can continue without enrichment + return nil + } + + // Parse and merge enrichment data + parseEnrichmentResponse(response, data) + return nil +} + +// GeneratePolicyDocumentation generates a Markdown documentation file for the policy +func (c *LLMClient) GeneratePolicyDocumentation(data *ScrapedIAMData, policyJSON string) (string, error) { + prompt := fmt.Sprintf(`Generate comprehensive Markdown documentation for the following IAM policy. + +## Service Information +- Service Name: %s +- Service Prefix: %s +- Total Actions Available: %d + +## The Policy to Document +%s + +## Documentation Requirements + +Create a well-structured Markdown document that includes: + +1. **Overview Section**: Brief description of what this policy provides access to + +2. 
**Statement Documentation**: For EACH statement in the policy: + - **Statement ID (Sid)**: The identifier + - **Purpose**: Clear explanation of what this statement allows + - **Actions Covered**: List the actions and briefly explain what they do + - **Resource Scope**: Explain what resources are affected + - **Conditions**: Explain any conditions and their security implications + - **Security Notes**: Any security considerations for this statement + +3. **Variables to Configure**: List ALL placeholder variables found in the policy (like ${AWS::AccountId}, ${VpcEndpointId}, etc.) with: + - Variable name + - Description of what value to substitute + - Example value format + - Where to find/determine the correct value + +4. **Security Summary**: + - Key security controls in place + - Recommendations for further hardening + - Compliance considerations (mention relevance to UK Gov, financial sector) + +5. **Usage Notes**: + - When to use this policy + - What roles/users it's suitable for + - Any prerequisites + +Output ONLY the Markdown content, no code fences around it.`, data.ServiceName, data.ServicePrefix, len(data.Actions), policyJSON) + + response, err := c.ChatCompletion([]ChatMessage{ + { + Role: "user", + Content: prompt, + }, + }) + if err != nil { + return "", fmt.Errorf("failed to generate documentation: %w", err) + } + + // Clean up response - remove any markdown code fences if present + response = strings.TrimSpace(response) + if strings.HasPrefix(response, "```markdown") { + response = strings.TrimPrefix(response, "```markdown") + response = strings.TrimSuffix(response, "```") + response = strings.TrimSpace(response) + } else if strings.HasPrefix(response, "```md") { + response = strings.TrimPrefix(response, "```md") + response = strings.TrimSuffix(response, "```") + response = strings.TrimSpace(response) + } else if strings.HasPrefix(response, "```") { + response = strings.TrimPrefix(response, "```") + response = strings.TrimSuffix(response, "```") + 
response = strings.TrimSpace(response) + } + + return response, nil +} + +// GenerateSCP generates a Service Control Policy based on the identity policy and user requirements +func (c *LLMClient) GenerateSCP(data *ScrapedIAMData, identityPolicyJSON string, userPrompt string) (string, error) { + prompt := fmt.Sprintf(`Generate a Service Control Policy (SCP) to complement the following identity policy. + +## Service Information +- Service Name: %s +- Service Prefix: %s + +## Identity Policy (already created) +%s + +## User Requirements +%s + +## SCP Generation Guidelines + +Use the "deny all except allowlist" pattern with NotAction. This is the standard enterprise SCP pattern. + +CRITICAL STRUCTURE - Use NotAction to deny everything EXCEPT allowed actions: + +{ + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "DenyAllExceptAllowedActions", + "Effect": "Deny", + "NotAction": [ + "service:AllowedAction1", + "service:AllowedAction2", + "service:Get*", + "service:List*", + "service:Describe*" + ], + "Resource": "*" + }, + { + "Sid": "DenyInsecureTransport", + "Effect": "Deny", + "Action": "service:*", + "Resource": "*", + "Condition": { + "Bool": { + "aws:SecureTransport": "false" + } + } + }, + { + "Sid": "DenyOutsideAllowedRegions", + "Effect": "Deny", + "Action": "service:*", + "Resource": "*", + "Condition": { + "StringNotEqualsIfExists": { + "aws:RequestedRegion": ["eu-west-2"] + } + } + }, + { + "Sid": "DenyGlobalCrossRegionInference", + "Effect": "Deny", + "Action": ["service:InvokeModel", "service:InvokeModelWithResponseStream"], + "Resource": "*", + "Condition": { + "StringEquals": { + "aws:RequestedRegion": "unspecified" + } + } + } + ] +} + +Guidelines: +1. Extract the ALLOWED actions from the identity policy and put them in NotAction +2. Include Get*, List*, Describe* wildcards in NotAction for read operations +3. Add DenyInsecureTransport statement (deny when aws:SecureTransport = false) +4. 
Add region restriction statement using StringNotEqualsIfExists +5. Add global CRIS denial for invoke actions (aws:RequestedRegion = "unspecified") +6. Keep statements minimal - the NotAction pattern handles most denials in one statement + +Output ONLY valid JSON for an SCP policy document. No markdown, no explanations.`, data.ServiceName, data.ServicePrefix, identityPolicyJSON, userPrompt) + + response, err := c.ChatCompletion([]ChatMessage{ + {Role: "user", Content: prompt}, + }) + if err != nil { + return "", fmt.Errorf("failed to generate SCP: %w", err) + } + + // Clean up response + response = strings.TrimSpace(response) + if strings.HasPrefix(response, "```json") { + response = strings.TrimPrefix(response, "```json") + response = strings.TrimSuffix(response, "```") + response = strings.TrimSpace(response) + } else if strings.HasPrefix(response, "```") { + response = strings.TrimPrefix(response, "```") + response = strings.TrimSuffix(response, "```") + response = strings.TrimSpace(response) + } + + // Find JSON object + start := strings.Index(response, "{") + end := strings.LastIndex(response, "}") + if start >= 0 && end > start { + response = response[start : end+1] + } + + // Validate JSON + var js json.RawMessage + if err := json.Unmarshal([]byte(response), &js); err != nil { + return "", fmt.Errorf("invalid SCP JSON: %w", err) + } + + return response, nil +} + +// GenerateCombinedDocumentation generates documentation covering both identity policy and SCP +func (c *LLMClient) GenerateCombinedDocumentation(data *ScrapedIAMData, identityPolicyJSON, scpJSON string) (string, error) { + prompt := fmt.Sprintf(`Generate comprehensive Markdown documentation for the following IAM Identity Policy and Service Control Policy (SCP). 
+ +## Service Information +- Service Name: %s +- Service Prefix: %s +- Total Actions Available: %d + +## Identity Policy +%s + +## Service Control Policy (SCP) +%s + +## Documentation Requirements + +Create a well-structured Markdown document that includes: + +1. **Overview Section**: + - Brief description of the overall access control strategy + - Explain the separation between identity policy (what users CAN do) and SCP (org-wide guardrails) + +2. **Identity Policy Documentation**: + - Purpose: What this policy allows + - Statement-by-statement breakdown + - Actions, resources, and conditions explained + - Who should have this policy attached + +3. **SCP Documentation**: + - Purpose: What org-wide guardrails this provides + - Statement-by-statement breakdown + - Why these denies are at the SCP level (cannot be bypassed) + - Which OUs/accounts this should be applied to + +4. **Deployment Guide**: + - Step 1: Deploy SCP to AWS Organizations (specify OU or account targets) + - Step 2: Attach identity policy to IAM roles/users + - Testing recommendations + - Rollback procedures + +5. **Variables to Configure**: + - List ALL placeholder variables from BOTH policies + - Description, example values, and where to find them + +6. **Security Summary**: + - Defense in depth explanation (identity + SCP layers) + - Key security controls in place + - Compliance considerations (UK Gov, financial sector, regulated environments) + +7. 
**Troubleshooting**: + - Common access denied scenarios + - How to diagnose SCP vs identity policy denials + - CloudTrail event patterns to look for + +Output ONLY the Markdown content, no code fences around it.`, data.ServiceName, data.ServicePrefix, len(data.Actions), identityPolicyJSON, scpJSON) + + response, err := c.ChatCompletion([]ChatMessage{ + {Role: "user", Content: prompt}, + }) + if err != nil { + return "", fmt.Errorf("failed to generate combined documentation: %w", err) + } + + // Clean up response + response = strings.TrimSpace(response) + if strings.HasPrefix(response, "```markdown") { + response = strings.TrimPrefix(response, "```markdown") + response = strings.TrimSuffix(response, "```") + response = strings.TrimSpace(response) + } else if strings.HasPrefix(response, "```md") { + response = strings.TrimPrefix(response, "```md") + response = strings.TrimSuffix(response, "```") + response = strings.TrimSpace(response) + } else if strings.HasPrefix(response, "```") { + response = strings.TrimPrefix(response, "```") + response = strings.TrimSuffix(response, "```") + response = strings.TrimSpace(response) + } + + return response, nil +} + +// ChatCompletion sends a chat completion request and returns the response content +// It includes retry logic for transient network failures +func (c *LLMClient) ChatCompletion(messages []ChatMessage) (string, error) { + reqBody := ChatCompletionRequest{ + Model: c.Model, + Messages: messages, + Temperature: 0.3, + MaxTokens: 4096, + } + + jsonBody, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + url := c.BaseURL + "/v1/chat/completions" + + // Retry logic for transient failures + var lastErr error + maxAttempts := 5 + for attempt := 1; attempt <= maxAttempts; attempt++ { + if attempt > 1 { + backoff := time.Duration(attempt*attempt) * time.Second // Exponential backoff: 4s, 9s, 16s, 25s + time.Sleep(backoff) + } + + result, err := c.doRequest(url, 
jsonBody) + if err == nil { + return result, nil + } + lastErr = err + + // Retry on network errors and server errors (5xx) + errStr := err.Error() + isRetryable := strings.Contains(errStr, "connection reset") || + strings.Contains(errStr, "timeout") || + strings.Contains(errStr, "EOF") || + strings.Contains(errStr, "HTTP 500") || + strings.Contains(errStr, "HTTP 502") || + strings.Contains(errStr, "HTTP 503") || + strings.Contains(errStr, "HTTP 504") + + if isRetryable && attempt < maxAttempts { + continue + } + return "", err + } + return "", fmt.Errorf("after %d attempts: %w", maxAttempts, lastErr) +} + +// doRequest performs a single HTTP request +func (c *LLMClient) doRequest(url string, jsonBody []byte) (string, error) { + req, err := http.NewRequest("POST", url, bytes.NewReader(jsonBody)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + if c.APIKey != "" { + req.Header.Set("Authorization", "Bearer "+c.APIKey) + } + + resp, err := c.Client.Do(req) + if err != nil { + return "", fmt.Errorf("HTTP request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + // Check HTTP status + if resp.StatusCode != http.StatusOK { + bodyStr := string(body) + if len(bodyStr) > 500 { + bodyStr = bodyStr[:500] + "..." 
+ } + return "", fmt.Errorf("HTTP %d: %s", resp.StatusCode, bodyStr) + } + + var chatResp ChatCompletionResponse + if err := json.Unmarshal(body, &chatResp); err != nil { + return "", fmt.Errorf("failed to parse response: %w (body: %s)", err, string(body)) + } + + if chatResp.Error != nil { + return "", fmt.Errorf("API error: %s", chatResp.Error.Message) + } + + if len(chatResp.Choices) == 0 { + // Check if there's a different response structure (some APIs use different formats) + var altResp map[string]any + if err := json.Unmarshal(body, &altResp); err == nil { + // Check for OpenWebUI specific response formats + if msg, ok := altResp["message"].(map[string]any); ok { + if content, ok := msg["content"].(string); ok { + return content, nil + } + } + // Check for error field + if errMsg, ok := altResp["error"].(string); ok { + return "", fmt.Errorf("API error: %s", errMsg) + } + if detail, ok := altResp["detail"].(string); ok { + return "", fmt.Errorf("API error: %s", detail) + } + } + // Log truncated response for debugging + bodyStr := string(body) + if len(bodyStr) > 500 { + bodyStr = bodyStr[:500] + "..." + } + return "", fmt.Errorf("no choices in response (body: %s)", bodyStr) + } + + return chatResp.Choices[0].Message.Content, nil +} + +func buildEnrichmentPrompt(data *ScrapedIAMData) string { + var sb strings.Builder + + sb.WriteString("For each of the following IAM actions, provide a brief security risk level (Low/Medium/High/Critical) and a one-line security consideration. 
// extractJSONFromResponse returns the JSON carried by an LLM reply, or "" when
// none can be found.
//
// The reply is trimmed and a leading markdown fence ("```json" or "```") is
// removed together with a trailing "```". If the remaining text is valid JSON
// it is returned unchanged. Otherwise the span from the first "{" to the last
// "}" is returned, provided that span is valid JSON.
func extractJSONFromResponse(response string) string {
	text := strings.TrimSpace(response)

	for _, fence := range []string{"```json", "```"} {
		if !strings.HasPrefix(text, fence) {
			continue
		}
		text = strings.TrimPrefix(text, fence)
		text = strings.TrimSuffix(text, "```")
		text = strings.TrimSpace(text)
		break
	}

	if json.Valid([]byte(text)) {
		return text
	}

	// Not JSON as a whole: look for an object embedded in surrounding prose.
	first := strings.Index(text, "{")
	last := strings.LastIndex(text, "}")
	if first < 0 || last <= first {
		return ""
	}

	candidate := text[first : last+1]
	if !json.Valid([]byte(candidate)) {
		return ""
	}
	return candidate
}
enrichmentMap := make(map[string]struct { + Risk string + SecurityNote string + }) + for _, e := range enrichment.Actions { + enrichmentMap[e.Name] = struct { + Risk string + SecurityNote string + }{e.Risk, e.SecurityNote} + } + + // Note: In a full implementation, we'd add fields to IAMAction for this + // For now, we could append to the description or add new fields + _ = enrichmentMap +} diff --git a/internal/llm_test.go b/internal/llm_test.go new file mode 100644 index 0000000..de69bb6 --- /dev/null +++ b/internal/llm_test.go @@ -0,0 +1,1507 @@ +package internal + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestNewLLMClient(t *testing.T) { + tests := []struct { + name string + baseURL string + apiKey string + model string + wantURL string + }{ + { + name: "basic creation", + baseURL: "https://api.example.com", + apiKey: "test-key", + model: "gpt-4", + wantURL: "https://api.example.com", + }, + { + name: "strips trailing slash", + baseURL: "https://api.example.com/", + apiKey: "test-key", + model: "gpt-4", + wantURL: "https://api.example.com", + }, + { + name: "empty api key", + baseURL: "https://api.example.com", + apiKey: "", + model: "gpt-4", + wantURL: "https://api.example.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := NewLLMClient(tt.baseURL, tt.apiKey, tt.model) + if client.BaseURL != tt.wantURL { + t.Errorf("BaseURL = %v, want %v", client.BaseURL, tt.wantURL) + } + if client.APIKey != tt.apiKey { + t.Errorf("APIKey = %v, want %v", client.APIKey, tt.apiKey) + } + if client.Model != tt.model { + t.Errorf("Model = %v, want %v", client.Model, tt.model) + } + if client.Client == nil { + t.Error("HTTP client is nil") + } + }) + } +} + +func TestExtractStatementsFromResponse(t *testing.T) { + tests := []struct { + name string + response string + wantLen int + }{ + { + name: "plain JSON array", + response: `[{"Sid": "Test", "Effect": "Allow"}]`, + wantLen: 1, + 
}, + { + name: "JSON array with multiple statements", + response: `[{"Sid": "Test1"}, {"Sid": "Test2"}, {"Sid": "Test3"}]`, + wantLen: 3, + }, + { + name: "markdown code fence json", + response: "```json\n[{\"Sid\": \"Test\"}]\n```", + wantLen: 1, + }, + { + name: "markdown code fence without json tag", + response: "```\n[{\"Sid\": \"Test\"}]\n```", + wantLen: 1, + }, + { + name: "JSON array with surrounding text", + response: "Here is your policy:\n[{\"Sid\": \"Test\"}]\nEnd of policy.", + wantLen: 1, + }, + { + name: "empty array", + response: "[]", + wantLen: 0, + }, + { + name: "invalid JSON", + response: "not valid json", + wantLen: 0, + }, + { + name: "whitespace around JSON", + response: " \n [{\"Sid\": \"Test\"}] \n ", + wantLen: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractStatementsFromResponse(tt.response) + if len(result) != tt.wantLen { + t.Errorf("extractStatementsFromResponse() returned %d statements, want %d", len(result), tt.wantLen) + } + }) + } +} + +func TestAssembleFinalPolicy(t *testing.T) { + tests := []struct { + name string + statements []json.RawMessage + wantFields []string + }{ + { + name: "single statement", + statements: []json.RawMessage{json.RawMessage(`{"Sid": "Test", "Effect": "Allow"}`)}, + wantFields: []string{"Version", "Statement"}, + }, + { + name: "multiple statements", + statements: []json.RawMessage{json.RawMessage(`{"Sid": "Test1"}`), json.RawMessage(`{"Sid": "Test2"}`)}, + wantFields: []string{"Version", "Statement"}, + }, + { + name: "empty statements", + statements: []json.RawMessage{}, + wantFields: []string{"Version", "Statement"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := assembleFinalPolicy(tt.statements) + if result == "" { + t.Error("assembleFinalPolicy() returned empty string") + return + } + + var policy map[string]any + if err := json.Unmarshal([]byte(result), &policy); err != nil { + 
t.Errorf("assembleFinalPolicy() returned invalid JSON: %v", err) + return + } + + for _, field := range tt.wantFields { + if _, ok := policy[field]; !ok { + t.Errorf("assembleFinalPolicy() missing field %s", field) + } + } + + if policy["Version"] != "2012-10-17" { + t.Errorf("assembleFinalPolicy() Version = %v, want 2012-10-17", policy["Version"]) + } + }) + } +} + +func TestDedupeStatements(t *testing.T) { + tests := []struct { + name string + statements []json.RawMessage + wantLen int + }{ + { + name: "no duplicates", + statements: []json.RawMessage{ + json.RawMessage(`{"Sid": "Test1"}`), + json.RawMessage(`{"Sid": "Test2"}`), + }, + wantLen: 2, + }, + { + name: "exact duplicates", + statements: []json.RawMessage{ + json.RawMessage(`{"Sid": "Test1"}`), + json.RawMessage(`{"Sid": "Test1"}`), + }, + wantLen: 1, + }, + { + name: "duplicates with different whitespace", + statements: []json.RawMessage{ + json.RawMessage(`{"Sid":"Test1"}`), + json.RawMessage(`{ "Sid" : "Test1" }`), + }, + wantLen: 1, + }, + { + name: "mix of duplicates and unique", + statements: []json.RawMessage{ + json.RawMessage(`{"Sid": "Test1"}`), + json.RawMessage(`{"Sid": "Test2"}`), + json.RawMessage(`{"Sid": "Test1"}`), + json.RawMessage(`{"Sid": "Test3"}`), + }, + wantLen: 3, + }, + { + name: "empty input", + statements: []json.RawMessage{}, + wantLen: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := dedupeStatements(tt.statements) + if len(result) != tt.wantLen { + t.Errorf("dedupeStatements() returned %d statements, want %d", len(result), tt.wantLen) + } + }) + } +} + +func TestGroupStatementsBySid(t *testing.T) { + tests := []struct { + name string + statements []json.RawMessage + wantGroups map[string]int + }{ + { + name: "single statement per sid", + statements: []json.RawMessage{ + json.RawMessage(`{"Sid": "Test1"}`), + json.RawMessage(`{"Sid": "Test2"}`), + }, + wantGroups: map[string]int{"Test1": 1, "Test2": 1}, + }, + { + name: "multiple 
statements per sid", + statements: []json.RawMessage{ + json.RawMessage(`{"Sid": "Test1", "Effect": "Allow"}`), + json.RawMessage(`{"Sid": "Test1", "Effect": "Deny"}`), + json.RawMessage(`{"Sid": "Test2"}`), + }, + wantGroups: map[string]int{"Test1": 2, "Test2": 1}, + }, + { + name: "statements without sid", + statements: []json.RawMessage{ + json.RawMessage(`{"Effect": "Allow"}`), + json.RawMessage(`{"Effect": "Deny"}`), + }, + wantGroups: map[string]int{"_no_sid": 2}, + }, + { + name: "empty input", + statements: []json.RawMessage{}, + wantGroups: map[string]int{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := groupStatementsBySid(tt.statements) + for sid, count := range tt.wantGroups { + if len(result[sid]) != count { + t.Errorf("groupStatementsBySid() sid %s has %d statements, want %d", sid, len(result[sid]), count) + } + } + }) + } +} + +func TestGetBatchSystemPrompt(t *testing.T) { + tests := []struct { + name string + userPrompt string + wantContains []string + }{ + { + name: "empty user prompt", + userPrompt: "", + wantContains: []string{ + "AWS IAM security expert", + "CRITICAL: Output ONLY a valid JSON array", + }, + }, + { + name: "with user prompt", + userPrompt: "Only include read actions", + wantContains: []string{ + "AWS IAM security expert", + "USER REQUIREMENTS (PRIMARY", + "Only include read actions", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getBatchSystemPrompt(tt.userPrompt) + for _, want := range tt.wantContains { + if !strings.Contains(result, want) { + t.Errorf("getBatchSystemPrompt() missing expected content: %s", want) + } + } + }) + } +} + +func TestBuildBatchPolicyPrompt(t *testing.T) { + data := &ScrapedIAMData{ + ServiceName: "Amazon S3", + ServicePrefix: "s3", + ConditionKeys: []IAMConditionKey{ + {Name: "s3:prefix", Type: "String"}, + }, + } + + batch := []IAMAction{ + {Name: "GetObject", Description: "Get an object", AccessLevel: 
"Read"}, + {Name: "PutObject", Description: "Put an object", AccessLevel: "Write"}, + } + + result := buildBatchPolicyPrompt(data, batch, 1, 3, "test prompt") + + expectedContains := []string{ + "s3", + "Batch 1 of 3", + "GetObject", + "PutObject", + "Read Actions", + "Write Actions", + "s3:prefix", + } + + for _, want := range expectedContains { + if !strings.Contains(result, want) { + t.Errorf("buildBatchPolicyPrompt() missing expected content: %s", want) + } + } +} + +func TestBuildEnrichmentPrompt(t *testing.T) { + data := &ScrapedIAMData{ + ServicePrefix: "s3", + Actions: []IAMAction{ + {Name: "GetObject", AccessLevel: "Read", Description: "Get an object"}, + {Name: "DeleteObject", AccessLevel: "Write", Description: "Delete an object"}, + }, + } + + result := buildEnrichmentPrompt(data) + + expectedContains := []string{ + "s3", + "GetObject", + "DeleteObject", + "security risk level", + "JSON", + } + + for _, want := range expectedContains { + if !strings.Contains(result, want) { + t.Errorf("buildEnrichmentPrompt() missing expected content: %s", want) + } + } +} + +func TestExtractJSONFromResponse(t *testing.T) { + tests := []struct { + name string + response string + wantJSON bool + }{ + { + name: "plain JSON object", + response: `{"actions": []}`, + wantJSON: true, + }, + { + name: "JSON with markdown fence", + response: "```json\n{\"actions\": []}\n```", + wantJSON: true, + }, + { + name: "JSON with plain fence", + response: "```\n{\"actions\": []}\n```", + wantJSON: true, + }, + { + name: "JSON embedded in text", + response: "Here is the result: {\"actions\": []} end.", + wantJSON: true, + }, + { + name: "invalid JSON", + response: "not valid json", + wantJSON: false, + }, + { + name: "empty string", + response: "", + wantJSON: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractJSONFromResponse(tt.response) + if tt.wantJSON && result == "" { + t.Error("extractJSONFromResponse() returned empty string, 
expected JSON") + } + if !tt.wantJSON && result != "" { + t.Errorf("extractJSONFromResponse() = %v, expected empty string", result) + } + }) + } +} + +func TestParseEnrichmentResponse(t *testing.T) { + data := &ScrapedIAMData{ + ServicePrefix: "s3", + Actions: []IAMAction{ + {Name: "GetObject", AccessLevel: "Read"}, + {Name: "DeleteObject", AccessLevel: "Write"}, + }, + } + + response := `{"actions": [ + {"name": "GetObject", "risk": "Low", "security_note": "Read only"}, + {"name": "DeleteObject", "risk": "High", "security_note": "Destructive action"} + ]}` + + // Should not panic + parseEnrichmentResponse(response, data) + + // Test with invalid JSON + parseEnrichmentResponse("invalid", data) + + // Test with empty response + parseEnrichmentResponse("", data) +} + +func TestChatCompletion(t *testing.T) { + tests := []struct { + name string + serverResponse string + serverStatus int + wantErr bool + wantContent string + }{ + { + name: "successful response", + serverResponse: `{ + "choices": [{ + "message": {"role": "assistant", "content": "Hello world"} + }] + }`, + serverStatus: http.StatusOK, + wantErr: false, + wantContent: "Hello world", + }, + { + name: "server error", + serverResponse: `{"error": {"message": "Server error"}}`, + serverStatus: http.StatusInternalServerError, + wantErr: true, + }, + { + name: "api error in response", + serverResponse: `{ + "error": {"message": "Invalid API key", "type": "auth_error"} + }`, + serverStatus: http.StatusOK, + wantErr: true, + }, + { + name: "empty choices", + serverResponse: `{"choices": []}`, + serverStatus: http.StatusOK, + wantErr: true, + }, + { + name: "alternative response format (OpenWebUI)", + serverResponse: `{ + "message": {"content": "Alt response"} + }`, + serverStatus: http.StatusOK, + wantErr: false, + wantContent: "Alt response", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 
+ // Verify request + if r.Method != "POST" { + t.Errorf("expected POST request, got %s", r.Method) + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("expected Content-Type: application/json") + } + if r.Header.Get("Authorization") != "Bearer test-key" { + t.Errorf("expected Authorization header") + } + + w.WriteHeader(tt.serverStatus) + w.Write([]byte(tt.serverResponse)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + result, err := client.ChatCompletion([]ChatMessage{ + {Role: "user", Content: "Hello"}, + }) + + if tt.wantErr && err == nil { + t.Error("ChatCompletion() expected error, got nil") + } + if !tt.wantErr && err != nil { + t.Errorf("ChatCompletion() unexpected error: %v", err) + } + if !tt.wantErr && result != tt.wantContent { + t.Errorf("ChatCompletion() = %v, want %v", result, tt.wantContent) + } + }) + } +} + +func TestChatCompletionRetry(t *testing.T) { + attempts := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempts++ + if attempts < 3 { + // First two attempts fail with 500 + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error": "temporary error"}`)) + return + } + // Third attempt succeeds + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "success"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + result, err := client.ChatCompletion([]ChatMessage{ + {Role: "user", Content: "Hello"}, + }) + + if err != nil { + t.Errorf("ChatCompletion() unexpected error after retries: %v", err) + } + if result != "success" { + t.Errorf("ChatCompletion() = %v, want success", result) + } + if attempts != 3 { + t.Errorf("Expected 3 attempts, got %d", attempts) + } +} + +func TestDoRequest(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + 
w.Write([]byte(`{"choices": [{"message": {"content": "test"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + result, err := client.doRequest(server.URL+"/v1/chat/completions", []byte(`{}`)) + + if err != nil { + t.Errorf("doRequest() unexpected error: %v", err) + } + if result != "test" { + t.Errorf("doRequest() = %v, want test", result) + } +} + +func TestDoRequestNoAPIKey(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Should not have Authorization header + if r.Header.Get("Authorization") != "" { + t.Error("expected no Authorization header") + } + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "test"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "", "test-model") + _, err := client.doRequest(server.URL+"/v1/chat/completions", []byte(`{}`)) + + if err != nil { + t.Errorf("doRequest() unexpected error: %v", err) + } +} + +func TestGenerateSecurityPolicy(t *testing.T) { + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusOK) + // Return a simple policy statement for each batch + w.Write([]byte(`{"choices": [{"message": {"content": "[{\"Sid\": \"Test\", \"Effect\": \"Allow\", \"Action\": [\"s3:GetObject\"], \"Resource\": \"*\"}]"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + + data := &ScrapedIAMData{ + ServiceName: "Amazon S3", + ServicePrefix: "s3", + Actions: []IAMAction{ + {Name: "GetObject", AccessLevel: "Read", Description: "Get object"}, + }, + } + + result, err := client.GenerateSecurityPolicy(data, nil, "", 1) + + if err != nil { + t.Errorf("GenerateSecurityPolicy() unexpected error: %v", err) + } + if result == "" { + t.Error("GenerateSecurityPolicy() returned empty string") + } + + var policy map[string]any + if err 
:= json.Unmarshal([]byte(result), &policy); err != nil { + t.Errorf("GenerateSecurityPolicy() returned invalid JSON: %v", err) + } +} + +func TestConsolidatePolicy(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "{\"Sid\": \"Merged\", \"Effect\": \"Allow\"}"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + + statements := []json.RawMessage{ + json.RawMessage(`{"Sid": "Test", "Effect": "Allow", "Action": ["s3:GetObject"]}`), + json.RawMessage(`{"Sid": "Test", "Effect": "Allow", "Action": ["s3:PutObject"]}`), + } + + result, err := client.ConsolidatePolicy(statements, nil, 1) + + if err != nil { + t.Errorf("ConsolidatePolicy() unexpected error: %v", err) + } + if len(result) == 0 { + t.Error("ConsolidatePolicy() returned empty result") + } +} + +func TestConsolidatePolicyNoDuplicates(t *testing.T) { + client := NewLLMClient("http://unused", "", "test-model") + + statements := []json.RawMessage{ + json.RawMessage(`{"Sid": "Test1", "Effect": "Allow"}`), + json.RawMessage(`{"Sid": "Test2", "Effect": "Deny"}`), + } + + result, err := client.ConsolidatePolicy(statements, nil, 1) + + if err != nil { + t.Errorf("ConsolidatePolicy() unexpected error: %v", err) + } + if len(result) != 2 { + t.Errorf("ConsolidatePolicy() returned %d statements, want 2", len(result)) + } +} + +func TestGenerateSCP(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "{\"Version\": \"2012-10-17\", \"Statement\": []}"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + + data := &ScrapedIAMData{ + ServiceName: "Amazon S3", + ServicePrefix: "s3", + } + + result, err := client.GenerateSCP(data, `{"Version": 
"2012-10-17"}`, "test prompt") + + if err != nil { + t.Errorf("GenerateSCP() unexpected error: %v", err) + } + if result == "" { + t.Error("GenerateSCP() returned empty string") + } + + var policy map[string]any + if err := json.Unmarshal([]byte(result), &policy); err != nil { + t.Errorf("GenerateSCP() returned invalid JSON: %v", err) + } +} + +func TestGeneratePolicyDocumentation(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "# Policy Documentation\n\nThis is the documentation."}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + + data := &ScrapedIAMData{ + ServiceName: "Amazon S3", + ServicePrefix: "s3", + Actions: []IAMAction{{Name: "GetObject"}}, + } + + result, err := client.GeneratePolicyDocumentation(data, `{"Version": "2012-10-17"}`) + + if err != nil { + t.Errorf("GeneratePolicyDocumentation() unexpected error: %v", err) + } + if !strings.Contains(result, "Policy Documentation") { + t.Error("GeneratePolicyDocumentation() missing expected content") + } +} + +func TestGenerateCombinedDocumentation(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("{\"choices\": [{\"message\": {\"content\": \"# Combined Documentation\"}}]}")) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + + data := &ScrapedIAMData{ + ServiceName: "Amazon S3", + ServicePrefix: "s3", + Actions: []IAMAction{{Name: "GetObject"}}, + } + + result, err := client.GenerateCombinedDocumentation(data, `{"Version": "2012-10-17"}`, `{"Version": "2012-10-17"}`) + + if err != nil { + t.Errorf("GenerateCombinedDocumentation() unexpected error: %v", err) + } + if !strings.Contains(result, "Combined Documentation") { + 
t.Error("GenerateCombinedDocumentation() missing expected content") + } +} + +func TestEnrichActionDescriptions(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "{\"actions\": [{\"name\": \"GetObject\", \"risk\": \"Low\", \"security_note\": \"Read only\"}]}"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + + data := &ScrapedIAMData{ + ServicePrefix: "s3", + Actions: []IAMAction{ + {Name: "GetObject", AccessLevel: "Read"}, + }, + } + + err := client.EnrichActionDescriptions(data, nil) + + if err != nil { + t.Errorf("EnrichActionDescriptions() unexpected error: %v", err) + } +} + +func TestConsolidateStatementGroup(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "{\"Sid\": \"Merged\", \"Effect\": \"Allow\", \"Action\": [\"s3:GetObject\", \"s3:PutObject\"]}"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + + statements := []json.RawMessage{ + json.RawMessage(`{"Sid": "Test", "Effect": "Allow", "Action": ["s3:GetObject"]}`), + json.RawMessage(`{"Sid": "Test", "Effect": "Allow", "Action": ["s3:PutObject"]}`), + } + + result, err := client.ConsolidateStatementGroup("Test", statements) + + if err != nil { + t.Errorf("ConsolidateStatementGroup() unexpected error: %v", err) + } + if result == nil { + t.Error("ConsolidateStatementGroup() returned nil") + } +} + +func TestConsolidateStatementGroupInvalidJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "not valid json"}}]}`)) + })) + defer server.Close() + + client := 
NewLLMClient(server.URL, "test-key", "test-model") + + statements := []json.RawMessage{ + json.RawMessage(`{"Sid": "Test"}`), + } + + _, err := client.ConsolidateStatementGroup("Test", statements) + if err == nil { + t.Error("ConsolidateStatementGroup() expected error for invalid JSON, got nil") + } +} + +func TestGeneratePolicyDocumentationError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error": "server error"}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + + data := &ScrapedIAMData{ + ServiceName: "Amazon S3", + ServicePrefix: "s3", + Actions: []IAMAction{{Name: "GetObject"}}, + } + + _, err := client.GeneratePolicyDocumentation(data, `{"Version": "2012-10-17"}`) + if err == nil { + t.Error("GeneratePolicyDocumentation() expected error, got nil") + } +} + +func TestGeneratePolicyDocumentationCleanup(t *testing.T) { + tests := []struct { + name string + response string + want string + }{ + { + name: "with markdown fence", + response: "```markdown\n# Documentation\n```", + want: "# Documentation", + }, + { + name: "with md fence", + response: "```md\n# Documentation\n```", + want: "# Documentation", + }, + { + name: "with plain fence", + response: "```\n# Documentation\n```", + want: "# Documentation", + }, + { + name: "no fence", + response: "# Documentation", + want: "# Documentation", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + resp := map[string]any{ + "choices": []any{ + map[string]any{ + "message": map[string]any{ + "content": tt.response, + }, + }, + }, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + data := 
&ScrapedIAMData{ServiceName: "S3", ServicePrefix: "s3", Actions: []IAMAction{{Name: "Get"}}} + + result, err := client.GeneratePolicyDocumentation(data, "{}") + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if !strings.Contains(result, "Documentation") { + t.Errorf("result = %q, want to contain 'Documentation'", result) + } + }) + } +} + +func TestGenerateSCPCleanup(t *testing.T) { + tests := []struct { + name string + response string + }{ + { + name: "with json fence", + response: "```json\n{\"Version\": \"2012-10-17\", \"Statement\": []}\n```", + }, + { + name: "with plain fence", + response: "```\n{\"Version\": \"2012-10-17\", \"Statement\": []}\n```", + }, + { + name: "with surrounding text", + response: "Here is the SCP:\n{\"Version\": \"2012-10-17\", \"Statement\": []}\nEnd.", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + resp := map[string]any{ + "choices": []any{ + map[string]any{ + "message": map[string]any{ + "content": tt.response, + }, + }, + }, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + data := &ScrapedIAMData{ServiceName: "S3", ServicePrefix: "s3"} + + result, err := client.GenerateSCP(data, "{}", "") + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if result == "" { + t.Error("result should not be empty") + } + }) + } +} + +func TestGenerateSCPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + data := &ScrapedIAMData{ServiceName: "S3", ServicePrefix: "s3"} + + _, err := client.GenerateSCP(data, "{}", "") + if err == nil { + t.Error("GenerateSCP() expected error, got 
nil") + } +} + +func TestGenerateSCPInvalidJSON(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "not valid json at all"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + data := &ScrapedIAMData{ServiceName: "S3", ServicePrefix: "s3"} + + _, err := client.GenerateSCP(data, "{}", "") + if err == nil { + t.Error("GenerateSCP() expected error for invalid JSON, got nil") + } +} + +func TestGenerateCombinedDocumentationError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + data := &ScrapedIAMData{ServiceName: "S3", ServicePrefix: "s3", Actions: []IAMAction{{Name: "Get"}}} + + _, err := client.GenerateCombinedDocumentation(data, "{}", "{}") + if err == nil { + t.Error("GenerateCombinedDocumentation() expected error, got nil") + } +} + +func TestGenerateCombinedDocumentationCleanup(t *testing.T) { + tests := []struct { + name string + response string + }{ + {name: "markdown fence", response: "```markdown\n# Docs\n```"}, + {name: "md fence", response: "```md\n# Docs\n```"}, + {name: "plain fence", response: "```\n# Docs\n```"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + resp := map[string]any{ + "choices": []any{ + map[string]any{ + "message": map[string]any{ + "content": tt.response, + }, + }, + }, + } + json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + data := &ScrapedIAMData{ServiceName: "S3", ServicePrefix: "s3", Actions: 
[]IAMAction{{Name: "Get"}}} + + result, err := client.GenerateCombinedDocumentation(data, "{}", "{}") + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if !strings.Contains(result, "Docs") { + t.Errorf("result = %q, want to contain 'Docs'", result) + } + }) + } +} + +func TestEnrichActionDescriptionsError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + data := &ScrapedIAMData{ServicePrefix: "s3", Actions: []IAMAction{{Name: "Get"}}} + + // Should not return error - enrichment is non-fatal + err := client.EnrichActionDescriptions(data, nil) + if err != nil { + t.Errorf("EnrichActionDescriptions() should not return error, got: %v", err) + } +} + +func TestDedupeStatementsUnparseable(t *testing.T) { + statements := []json.RawMessage{ + json.RawMessage(`{"Sid": "Valid"}`), + json.RawMessage(`not valid json`), + json.RawMessage(`{"Sid": "Valid2"}`), + } + + result := dedupeStatements(statements) + // Should keep all statements including unparseable ones + if len(result) != 3 { + t.Errorf("dedupeStatements() = %d statements, want 3", len(result)) + } +} + +func TestGroupStatementsBySidUnparseable(t *testing.T) { + statements := []json.RawMessage{ + json.RawMessage(`{"Sid": "Valid"}`), + json.RawMessage(`not valid json`), + } + + result := groupStatementsBySid(statements) + if _, ok := result["_unparseable"]; !ok { + t.Error("groupStatementsBySid() should have _unparseable group") + } + if len(result["_unparseable"]) != 1 { + t.Errorf("_unparseable group should have 1 statement, got %d", len(result["_unparseable"])) + } +} + +func TestAssembleFinalPolicyEmpty(t *testing.T) { + result := assembleFinalPolicy(nil) + if result == "" { + t.Error("assembleFinalPolicy(nil) should not return empty string") + } + + var policy map[string]any + if err := 
json.Unmarshal([]byte(result), &policy); err != nil { + t.Errorf("assembleFinalPolicy() returned invalid JSON: %v", err) + } +} + +func TestBuildBatchPolicyPromptNoConditionKeys(t *testing.T) { + data := &ScrapedIAMData{ + ServiceName: "Amazon S3", + ServicePrefix: "s3", + ConditionKeys: nil, // No condition keys + } + + batch := []IAMAction{ + {Name: "GetObject", Description: "Get", AccessLevel: "Read"}, + } + + result := buildBatchPolicyPrompt(data, batch, 1, 1, "") + if !strings.Contains(result, "s3") { + t.Error("buildBatchPolicyPrompt() should contain service prefix") + } +} + +func TestBuildBatchPolicyPromptManyConditionKeys(t *testing.T) { + // Test with more than 10 condition keys (should be truncated) + keys := make([]IAMConditionKey, 15) + for i := range keys { + keys[i] = IAMConditionKey{Name: "key" + string(rune('a'+i)), Type: "String"} + } + + data := &ScrapedIAMData{ + ServiceName: "Amazon S3", + ServicePrefix: "s3", + ConditionKeys: keys, + } + + batch := []IAMAction{{Name: "GetObject", AccessLevel: "Read"}} + + result := buildBatchPolicyPrompt(data, batch, 1, 1, "") + if !strings.Contains(result, "and") && !strings.Contains(result, "more") { + // The prompt should indicate there are more keys + // (exact format may vary) + } +} + +func TestConsolidatePolicyWithProgress(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "{\"Sid\": \"Merged\"}"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + + // Create a mock progress reporter + progress := &mockProgress{} + + statements := []json.RawMessage{ + json.RawMessage(`{"Sid": "Test", "Effect": "Allow"}`), + json.RawMessage(`{"Sid": "Test", "Effect": "Deny"}`), + } + + _, err := client.ConsolidatePolicy(statements, progress, 1) + if err != nil { + t.Errorf("ConsolidatePolicy() unexpected error: %v", err) + } + 
+ if progress.statusCalls == 0 { + t.Error("ConsolidatePolicy() should call progress.SetStatus()") + } +} + +// mockProgress implements ProgressReporter for testing +type mockProgress struct { + statusCalls int + progressCalls int +} + +func (m *mockProgress) SetStatus(status string) { + m.statusCalls++ +} + +func (m *mockProgress) SetProgress(current, total int) { + m.progressCalls++ +} + +func TestConsolidateStatementGroupWithMarkdownFences(t *testing.T) { + tests := []struct { + name string + response string + }{ + { + name: "json code fence", + response: "```json\n{\"Sid\": \"Test\", \"Effect\": \"Allow\"}\n```", + }, + { + name: "plain code fence", + response: "```\n{\"Sid\": \"Test\", \"Effect\": \"Allow\"}\n```", + }, + { + name: "with surrounding text", + response: "Here is the merged statement:\n{\"Sid\": \"Test\", \"Effect\": \"Allow\"}\nDone.", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + respJSON, _ := json.Marshal(map[string]any{ + "choices": []map[string]any{ + {"message": map[string]any{"content": tt.response}}, + }, + }) + w.Write(respJSON) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + statements := []json.RawMessage{json.RawMessage(`{"Sid": "Test"}`)} + + result, err := client.ConsolidateStatementGroup("Test", statements) + if err != nil { + t.Errorf("ConsolidateStatementGroup() error = %v", err) + } + if result == nil { + t.Error("ConsolidateStatementGroup() returned nil") + } + }) + } +} + +func TestConsolidateStatementGroupAPIError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error": "server error"}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + statements 
:= []json.RawMessage{json.RawMessage(`{"Sid": "Test"}`)} + + _, err := client.ConsolidateStatementGroup("Test", statements) + if err == nil { + t.Error("ConsolidateStatementGroup() expected error for API failure, got nil") + } +} + +func TestGenerateSecurityPolicyWithProgress(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "[{\"Sid\": \"Test\", \"Effect\": \"Allow\", \"Action\": [\"s3:GetObject\"], \"Resource\": \"*\"}]"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + progress := &mockProgress{} + + data := &ScrapedIAMData{ + ServiceName: "Amazon S3", + ServicePrefix: "s3", + Actions: []IAMAction{ + {Name: "GetObject", AccessLevel: "Read", Description: "Get object"}, + }, + } + + result, err := client.GenerateSecurityPolicy(data, progress, "", 1) + + if err != nil { + t.Errorf("GenerateSecurityPolicy() unexpected error: %v", err) + } + if result == "" { + t.Error("GenerateSecurityPolicy() returned empty string") + } + if progress.statusCalls == 0 { + t.Error("GenerateSecurityPolicy() should call progress.SetStatus()") + } + if progress.progressCalls == 0 { + t.Error("GenerateSecurityPolicy() should call progress.SetProgress()") + } +} + +func TestGenerateSecurityPolicyDefaultConcurrency(t *testing.T) { + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "[{\"Sid\": \"Test\", \"Effect\": \"Allow\"}]"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + + data := &ScrapedIAMData{ + ServicePrefix: "s3", + Actions: []IAMAction{{Name: "GetObject", AccessLevel: "Read"}}, + } + + // Test with 0 concurrency (should default to 3) + _, err := 
client.GenerateSecurityPolicy(data, nil, "", 0) + if err != nil { + t.Errorf("GenerateSecurityPolicy() with 0 concurrency error: %v", err) + } + + // Test with negative concurrency (should default to 3) + _, err = client.GenerateSecurityPolicy(data, nil, "", -1) + if err != nil { + t.Errorf("GenerateSecurityPolicy() with -1 concurrency error: %v", err) + } +} + +func TestGenerateSecurityPolicyEmptyAccessLevel(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "[{\"Sid\": \"Test\", \"Effect\": \"Allow\"}]"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + + data := &ScrapedIAMData{ + ServicePrefix: "s3", + Actions: []IAMAction{ + {Name: "GetObject", AccessLevel: ""}, // Empty access level should become "Unknown" + }, + } + + _, err := client.GenerateSecurityPolicy(data, nil, "", 1) + if err != nil { + t.Errorf("GenerateSecurityPolicy() with empty access level error: %v", err) + } +} + +func TestGenerateSecurityPolicyBatchError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte(`{"error": "server error"}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + + data := &ScrapedIAMData{ + ServicePrefix: "s3", + Actions: []IAMAction{{Name: "GetObject", AccessLevel: "Read"}}, + } + + _, err := client.GenerateSecurityPolicy(data, nil, "", 1) + if err == nil { + t.Error("GenerateSecurityPolicy() expected error for batch failure, got nil") + } +} + +func TestGenerateSecurityPolicyMultipleBatches(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "[{\"Sid\": \"Test\", 
\"Effect\": \"Allow\"}]"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + + // Create 35 actions to trigger multiple batches (batch size is 30) + actions := make([]IAMAction, 35) + for i := range actions { + actions[i] = IAMAction{Name: "Action" + string(rune('A'+i%26)), AccessLevel: "Read"} + } + + data := &ScrapedIAMData{ + ServicePrefix: "s3", + Actions: actions, + } + + _, err := client.GenerateSecurityPolicy(data, nil, "", 2) + if err != nil { + t.Errorf("GenerateSecurityPolicy() with multiple batches error: %v", err) + } +} + +func TestChatCompletionHTTPError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error": {"message": "bad request"}}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + + _, err := client.ChatCompletion([]ChatMessage{{Role: "user", Content: "test"}}) + if err == nil { + t.Error("ChatCompletion() expected error for HTTP error, got nil") + } +} + +func TestChatCompletionInvalidResponse(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`not valid json`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + + _, err := client.ChatCompletion([]ChatMessage{{Role: "user", Content: "test"}}) + if err == nil { + t.Error("ChatCompletion() expected error for invalid JSON response, got nil") + } +} + +func TestChatCompletionEmptyChoices(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": []}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + + _, err := client.ChatCompletion([]ChatMessage{{Role: "user", Content: 
"test"}}) + if err == nil { + t.Error("ChatCompletion() expected error for empty choices, got nil") + } +} + +func TestDoRequestWithAPIKey(t *testing.T) { + var receivedAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "test"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-api-key", "test-model") + _, err := client.ChatCompletion([]ChatMessage{{Role: "user", Content: "test"}}) + + if err != nil { + t.Errorf("ChatCompletion() error: %v", err) + } + if receivedAuth != "Bearer test-api-key" { + t.Errorf("Authorization header = %q, want %q", receivedAuth, "Bearer test-api-key") + } +} + +func TestDoRequestWithoutAPIKey(t *testing.T) { + var receivedAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "test"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "", "test-model") + _, err := client.ChatCompletion([]ChatMessage{{Role: "user", Content: "test"}}) + + if err != nil { + t.Errorf("ChatCompletion() error: %v", err) + } + if receivedAuth != "" { + t.Errorf("Authorization header should be empty, got %q", receivedAuth) + } +} + +func TestEnrichActionDescriptionsWithProgress(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "{\"actions\": [{\"name\": \"GetObject\", \"risk\": \"Low\", \"security_note\": \"Read only\"}]}"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + progress := &mockProgress{} + + data := &ScrapedIAMData{ + ServicePrefix: 
"s3", + Actions: []IAMAction{ + {Name: "GetObject", AccessLevel: "Read"}, + }, + } + + err := client.EnrichActionDescriptions(data, progress) + if err != nil { + t.Errorf("EnrichActionDescriptions() unexpected error: %v", err) + } + if progress.statusCalls == 0 { + t.Error("EnrichActionDescriptions() should call progress.SetStatus()") + } +} + +func TestEnrichActionDescriptionsManyActions(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"choices": [{"message": {"content": "{\"actions\": []}"}}]}`)) + })) + defer server.Close() + + client := NewLLMClient(server.URL, "test-key", "test-model") + + // Create more than 20 actions to trigger multiple batches + actions := make([]IAMAction, 25) + for i := range actions { + actions[i] = IAMAction{Name: "Action" + string(rune('A'+i%26)), AccessLevel: "Read"} + } + + data := &ScrapedIAMData{ + ServicePrefix: "s3", + Actions: actions, + } + + err := client.EnrichActionDescriptions(data, nil) + if err != nil { + t.Errorf("EnrichActionDescriptions() with many actions error: %v", err) + } +} diff --git a/internal/policy.go b/internal/policy.go index 5c64165..ca60633 100644 --- a/internal/policy.go +++ b/internal/policy.go @@ -11,6 +11,28 @@ import ( "strings" ) +// validIAMTopLevelFields defines the valid top-level fields in an IAM policy document. +// Used by ValidateIAMFields and stripPolicyDocument. +var validIAMTopLevelFields = map[string]bool{ + "Version": true, + "Id": true, + "Statement": true, +} + +// validIAMStatementFields defines the valid fields in an IAM policy statement. +// Used by stripStatementFields and findInvalidStatementFields. 
+var validIAMStatementFields = map[string]bool{ + "Sid": true, + "Effect": true, + "Principal": true, + "NotPrincipal": true, + "Action": true, + "NotAction": true, + "Resource": true, + "NotResource": true, + "Condition": true, +} + // ExpandGlobsRelative expands glob patterns relative to a base directory func ExpandGlobsRelative(base string, patterns []string) []string { var files []string @@ -313,14 +335,8 @@ func ValidateIAMFields(policyJSON string) error { var violations []string // Check top-level fields - validTopLevel := map[string]bool{ - "Version": true, - "Id": true, - "Statement": true, - } - for field := range policy { - if !validTopLevel[field] { + if !validIAMTopLevelFields[field] { violations = append(violations, fmt.Sprintf(" Top-level: %s", field)) } } @@ -347,15 +363,8 @@ func ValidateIAMFields(policyJSON string) error { func stripPolicyDocument(policy map[string]any) map[string]any { result := make(map[string]any) - // Valid IAM policy top-level fields - validTopLevel := map[string]bool{ - "Version": true, // Required - "Id": true, // Optional - "Statement": true, // Required - } - for field, value := range policy { - if !validTopLevel[field] { + if !validIAMTopLevelFields[field] { continue // Skip non-IAM fields } @@ -390,22 +399,9 @@ func stripStatements(statementsRaw any) any { } func stripStatementFields(stmt map[string]any) map[string]any { - // Valid IAM statement fields per AWS documentation - validFields := map[string]bool{ - "Sid": true, // Optional statement ID - "Effect": true, // Required: Allow or Deny - "Principal": true, // For resource-based policies - "NotPrincipal": true, - "Action": true, - "NotAction": true, - "Resource": true, - "NotResource": true, - "Condition": true, // Optional conditions - } - result := make(map[string]any) for field, value := range stmt { - if validFields[field] { + if validIAMStatementFields[field] { result[field] = value } } @@ -414,21 +410,9 @@ func stripStatementFields(stmt map[string]any) 
map[string]any { } func findInvalidStatementFields(stmt map[string]any) []string { - validFields := map[string]bool{ - "Sid": true, - "Effect": true, - "Principal": true, - "NotPrincipal": true, - "Action": true, - "NotAction": true, - "Resource": true, - "NotResource": true, - "Condition": true, - } - var invalid []string for field := range stmt { - if !validFields[field] { + if !validIAMStatementFields[field] { invalid = append(invalid, field) } } diff --git a/internal/progress.go b/internal/progress.go new file mode 100644 index 0000000..139e1c1 --- /dev/null +++ b/internal/progress.go @@ -0,0 +1,113 @@ +package internal + +import ( + "fmt" + "io" + "os" + "time" + + "github.com/schollz/progressbar/v3" +) + +// ProgressReporter defines the interface for progress reporting +type ProgressReporter interface { + SetStatus(status string) + SetProgress(current, total int) +} + +// ConsoleProgress provides a progress bar and status display using progressbar/v3 +type ConsoleProgress struct { + writer io.Writer + bar *progressbar.ProgressBar + startTime time.Time + total int + status string +} + +// NewConsoleProgress creates a new console progress reporter +func NewConsoleProgress(writer io.Writer) *ConsoleProgress { + return &ConsoleProgress{ + writer: writer, + startTime: time.Now(), + } +} + +// Start begins the progress display +func (p *ConsoleProgress) Start() { + // Don't create a bar yet - wait for SetStatus or SetProgress +} + +// SetStatus updates the current status message +func (p *ConsoleProgress) SetStatus(status string) { + p.status = status + + // If no progress bar yet, just print the status + if p.bar == nil { + fmt.Fprintf(p.writer, "\r\033[K%s", status) + return + } + + p.bar.Describe(status) +} + +// SetProgress updates the current progress +func (p *ConsoleProgress) SetProgress(current, total int) { + if total <= 0 { + return + } + + // If total changed, create a new progress bar + if total != p.total { + p.total = total + p.bar = nil // Reset bar 
+ + desc := p.status + if desc == "" { + desc = "Processing" + } + + p.bar = progressbar.NewOptions(total, + progressbar.OptionSetWriter(os.Stderr), + progressbar.OptionSetWidth(30), + progressbar.OptionShowCount(), + progressbar.OptionSetDescription(desc), + progressbar.OptionSetRenderBlankState(true), + progressbar.OptionThrottle(50*time.Millisecond), + progressbar.OptionSetTheme(progressbar.Theme{ + Saucer: "█", + SaucerHead: ">", + SaucerPadding: "░", + BarStart: "[", + BarEnd: "]", + }), + ) + } + + if p.bar != nil { + _ = p.bar.Set(current) + } +} + +// Done marks the progress as complete with a final message +func (p *ConsoleProgress) Done(message string) { + if p.bar != nil { + _ = p.bar.Finish() + } + elapsed := time.Since(p.startTime).Round(time.Millisecond) + fmt.Fprintf(p.writer, "\r\033[K\033[32m✓\033[0m %s (took %s)\n", message, elapsed) +} + +// Error marks the progress as failed with an error message +func (p *ConsoleProgress) Error(message string) { + if p.bar != nil { + _ = p.bar.Finish() + } + fmt.Fprintf(p.writer, "\r\033[K\033[31m✗\033[0m %s\n", message) +} + +// Stop stops the progress display +func (p *ConsoleProgress) Stop() { + if p.bar != nil { + _ = p.bar.Finish() + } +} diff --git a/internal/progress_test.go b/internal/progress_test.go new file mode 100644 index 0000000..2c0f93e --- /dev/null +++ b/internal/progress_test.go @@ -0,0 +1,290 @@ +package internal + +import ( + "bytes" + "strings" + "testing" +) + +func TestNewConsoleProgress(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + if progress == nil { + t.Fatal("NewConsoleProgress() returned nil") + } + if progress.writer != &buf { + t.Error("writer not set correctly") + } + if progress.startTime.IsZero() { + t.Error("startTime should be set") + } +} + +func TestConsoleProgressSetStatus(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + progress.SetStatus("Testing status") + + if progress.status != "Testing status" { + 
t.Errorf("status = %q, want 'Testing status'", progress.status) + } + + output := buf.String() + if !strings.Contains(output, "Testing status") { + t.Errorf("output = %q, want to contain 'Testing status'", output) + } +} + +func TestConsoleProgressSetStatusWithBar(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + // First set progress to create bar + progress.SetProgress(1, 10) + + // Then set status + progress.SetStatus("New status") + + if progress.status != "New status" { + t.Errorf("status = %q, want 'New status'", progress.status) + } +} + +func TestConsoleProgressSetProgress(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + // Test with valid progress + progress.SetProgress(5, 10) + + if progress.total != 10 { + t.Errorf("total = %d, want 10", progress.total) + } + if progress.bar == nil { + t.Error("bar should be created after SetProgress") + } +} + +func TestConsoleProgressSetProgressZeroTotal(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + // Test with zero total - should be no-op + progress.SetProgress(0, 0) + + if progress.total != 0 { + t.Errorf("total = %d, want 0 (unchanged)", progress.total) + } + if progress.bar != nil { + t.Error("bar should not be created for zero total") + } +} + +func TestConsoleProgressSetProgressNegativeTotal(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + // Test with negative total - should be no-op + progress.SetProgress(0, -1) + + if progress.total != 0 { + t.Errorf("total = %d, want 0 (unchanged)", progress.total) + } + if progress.bar != nil { + t.Error("bar should not be created for negative total") + } +} + +func TestConsoleProgressSetProgressChangingTotal(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + // Set initial progress + progress.SetProgress(1, 10) + firstBar := progress.bar + + // Change total - should create new bar + progress.SetProgress(1, 20) 
+ + if progress.total != 20 { + t.Errorf("total = %d, want 20", progress.total) + } + // Note: progressbar.ProgressBar doesn't expose an easy way to compare + // but the bar should be recreated. We verify the total changed. + _ = firstBar +} + +func TestConsoleProgressDone(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + progress.Done("All done") + + output := buf.String() + if !strings.Contains(output, "All done") { + t.Errorf("output = %q, want to contain 'All done'", output) + } + // Should contain checkmark + if !strings.Contains(output, "✓") { + t.Errorf("output = %q, want to contain '✓'", output) + } +} + +func TestConsoleProgressDoneWithBar(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + // Create bar first + progress.SetProgress(5, 10) + + progress.Done("Completed") + + output := buf.String() + if !strings.Contains(output, "Completed") { + t.Errorf("output = %q, want to contain 'Completed'", output) + } +} + +func TestConsoleProgressError(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + progress.Error("Something failed") + + output := buf.String() + if !strings.Contains(output, "Something failed") { + t.Errorf("output = %q, want to contain 'Something failed'", output) + } + // Should contain X mark + if !strings.Contains(output, "✗") { + t.Errorf("output = %q, want to contain '✗'", output) + } +} + +func TestConsoleProgressErrorWithBar(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + // Create bar first + progress.SetProgress(5, 10) + + progress.Error("Error occurred") + + output := buf.String() + if !strings.Contains(output, "Error occurred") { + t.Errorf("output = %q, want to contain 'Error occurred'", output) + } +} + +func TestConsoleProgressStop(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + // Create bar first + progress.SetProgress(5, 10) + + // Should not panic + progress.Stop() +} + 
+func TestConsoleProgressStopNoBar(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + // Should not panic even without bar + progress.Stop() +} + +func TestConsoleProgressStart(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + // Start should not panic + progress.Start() + + // Bar should not be created yet + if progress.bar != nil { + t.Error("bar should not be created by Start()") + } +} + +func TestProgressReporterInterface(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + // Verify ConsoleProgress implements ProgressReporter + var reporter ProgressReporter = progress + reporter.SetStatus("test") + reporter.SetProgress(1, 2) +} + +func TestConsoleProgressStatusWithoutBar(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + // Set status multiple times without bar + progress.SetStatus("Status 1") + progress.SetStatus("Status 2") + progress.SetStatus("Status 3") + + output := buf.String() + // Last status should be visible + if !strings.Contains(output, "Status 3") { + t.Errorf("output = %q, want to contain 'Status 3'", output) + } +} + +func TestConsoleProgressElapsedTime(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + // Let some time pass + progress.Done("Done") + + output := buf.String() + // Should contain "took" with some duration + if !strings.Contains(output, "took") { + t.Errorf("output = %q, want to contain 'took'", output) + } +} + +func TestConsoleProgressSetProgressWithStatus(t *testing.T) { + var buf bytes.Buffer + progress := NewConsoleProgress(&buf) + + // Set status first + progress.SetStatus("Processing") + + // Then set progress - should use status as bar description + progress.SetProgress(1, 10) + + if progress.total != 10 { + t.Errorf("total = %d, want 10", progress.total) + } + if progress.status != "Processing" { + t.Errorf("status = %q, want 'Processing'", progress.status) + } +} + 
// cacheEntry is the JSON record stored on disk for one cached page.
type cacheEntry struct {
	URL       string    `json:"url"`
	Content   string    `json:"content"`
	CachedAt  time.Time `json:"cached_at"`
	ExpiresAt time.Time `json:"expires_at"`
}

// cacheDuration is how long a cached page is considered fresh.
const cacheDuration = 24 * time.Hour

// getCacheDir returns the cache directory (<home>/.cache/politest), or ""
// when the user's home directory cannot be determined.
func getCacheDir() string {
	home, err := os.UserHomeDir()
	if err != nil {
		return ""
	}
	return filepath.Join(home, ".cache", "politest")
}

// getCacheKey derives the cache file name for url: the first 16 hex digits
// of its SHA-256 digest followed by ".json".
func getCacheKey(url string) string {
	hash := sha256.Sum256([]byte(url))
	return hex.EncodeToString(hash[:])[:16] + ".json"
}

// loadFromCache returns the cached content for url and true when a fresh
// entry exists. Any failure (no cache dir, unreadable file, bad JSON, entry
// for another URL, expired entry) is reported as a miss.
func loadFromCache(url string) (string, bool) {
	cacheDir := getCacheDir()
	if cacheDir == "" {
		return "", false
	}

	data, err := os.ReadFile(filepath.Join(cacheDir, getCacheKey(url)))
	if err != nil {
		return "", false
	}

	var entry cacheEntry
	if err := json.Unmarshal(data, &entry); err != nil {
		return "", false
	}

	// The file name is only a truncated hash, so confirm the entry really
	// belongs to the requested URL before trusting its content.
	if entry.URL != url {
		return "", false
	}

	if time.Now().After(entry.ExpiresAt) {
		return "", false
	}

	return entry.Content, true
}

// saveToCache stores content for url. Caching is best-effort: every failure
// is silently ignored so that a broken cache never fails a scrape.
func saveToCache(url, content string) {
	cacheDir := getCacheDir()
	if cacheDir == "" {
		return
	}

	if err := os.MkdirAll(cacheDir, 0750); err != nil {
		return
	}

	// Take one timestamp so ExpiresAt is exactly CachedAt + cacheDuration.
	now := time.Now()
	data, err := json.Marshal(cacheEntry{
		URL:       url,
		Content:   content,
		CachedAt:  now,
		ExpiresAt: now.Add(cacheDuration),
	})
	if err != nil {
		return
	}

	// Best-effort write; a failed write just means a cache miss next time.
	_ = os.WriteFile(filepath.Join(cacheDir, getCacheKey(url)), data, 0600)
}
resources from an AWS documentation page +func ScrapeIAMDocumentation(url string, progress ProgressReporter) (*ScrapedIAMData, error) { + // Validate URL format + if !strings.Contains(url, "docs.aws.amazon.com/service-authorization") { + return nil, fmt.Errorf("invalid URL: must be an AWS service authorization reference page (e.g., https://docs.aws.amazon.com/service-authorization/latest/reference/list_amazonbedrock.html)") + } + + var htmlContent string + + // Check cache first + if cached, ok := loadFromCache(url); ok { + if progress != nil { + progress.SetStatus("Using cached documentation (valid for 24h)...") + } + htmlContent = cached + } else { + if progress != nil { + progress.SetStatus("Fetching AWS documentation page...") + } + + // Fetch the page + resp, err := http.Get(url) + if err != nil { + return nil, fmt.Errorf("failed to fetch URL: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("HTTP error: %s", resp.Status) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + htmlContent = string(body) + + // Save to cache + saveToCache(url, htmlContent) + } + + if progress != nil { + progress.SetStatus("Parsing HTML content...") + } + + return parseIAMDocumentation(htmlContent, url, progress) +} + +// parseIAMDocumentation parses the HTML content of an AWS IAM documentation page +func parseIAMDocumentation(htmlContent, sourceURL string, progress ProgressReporter) (*ScrapedIAMData, error) { + doc, err := html.Parse(strings.NewReader(htmlContent)) + if err != nil { + return nil, fmt.Errorf("failed to parse HTML: %w", err) + } + + data := &ScrapedIAMData{ + SourceURL: sourceURL, + } + + // Extract service name and prefix from the page + data.ServiceName, data.ServicePrefix = extractServiceInfo(doc) + if data.ServicePrefix == "" { + return nil, fmt.Errorf("could not extract service prefix from page - is this a valid AWS service 
authorization page?") + } + + if progress != nil { + progress.SetStatus("Extracting IAM actions...") + } + + // Find and parse the actions table + data.Actions = extractActions(doc, progress) + + if progress != nil { + progress.SetStatus("Extracting condition keys...") + } + + // Find and parse the condition keys table + data.ConditionKeys = extractConditionKeys(doc) + + if progress != nil { + progress.SetStatus("Extracting resource types...") + } + + // Find and parse the resource types table + data.ResourceTypes = extractResourceTypes(doc) + + if len(data.Actions) == 0 { + return nil, fmt.Errorf("no IAM actions found on page - verify the URL is correct") + } + + return data, nil +} + +// extractServiceInfo extracts the service name and prefix from the page +func extractServiceInfo(doc *html.Node) (string, string) { + var serviceName, servicePrefix string + + // Strategy 1: Look for pattern "(service prefix: xxx)" in the full text + // The prefix is often in a tag after "service prefix:" + fullText := getTextContent(doc) + prefixPatterns := []string{ + `service prefix:\s*` + "`" + `?(\w+)` + "`" + `?`, + `\(service prefix:\s*(\w+)\)`, + `service prefix:\s*(\w+)`, + } + for _, pattern := range prefixPatterns { + re := regexp.MustCompile(`(?i)` + pattern) + if matches := re.FindStringSubmatch(fullText); len(matches) > 1 { + servicePrefix = strings.TrimSpace(matches[1]) + break + } + } + + // Strategy 2: Look for code elements that follow "service prefix:" text + if servicePrefix == "" { + var foundServicePrefix bool + var walkForPrefix func(*html.Node) + walkForPrefix = func(n *html.Node) { + if foundServicePrefix { + return + } + if n.Type == html.TextNode && strings.Contains(strings.ToLower(n.Data), "service prefix:") { + // Look for the next code element sibling + for sib := n.NextSibling; sib != nil; sib = sib.NextSibling { + if sib.Type == html.ElementNode && sib.Data == "code" { + servicePrefix = cleanText(getTextContent(sib)) + foundServicePrefix = true + 
return + } + if sib.Type == html.ElementNode { + // Check inside the element for code + codes := findElements(sib, "code") + if len(codes) > 0 { + servicePrefix = cleanText(getTextContent(codes[0])) + foundServicePrefix = true + return + } + } + } + } + for c := n.FirstChild; c != nil; c = c.NextSibling { + walkForPrefix(c) + } + } + walkForPrefix(doc) + } + + // Try to extract from title or h1 + var extractTitle func(*html.Node) + extractTitle = func(n *html.Node) { + if n.Type == html.ElementNode && (n.Data == "title" || n.Data == "h1") { + serviceName = getTextContent(n) + // Clean up common patterns + serviceName = strings.TrimPrefix(serviceName, "Actions, resources, and condition keys for ") + serviceName = strings.TrimSuffix(serviceName, " - Service Authorization Reference") + } + for c := n.FirstChild; c != nil; c = c.NextSibling { + extractTitle(c) + } + } + extractTitle(doc) + + // If prefix not found, try to extract from action names + if servicePrefix == "" { + var findPrefix func(*html.Node) + findPrefix = func(n *html.Node) { + if n.Type == html.ElementNode && n.Data == "td" { + text := getTextContent(n) + if strings.Contains(text, ":") { + parts := strings.Split(text, ":") + if len(parts) == 2 && len(parts[0]) > 0 && len(parts[0]) < 30 { + servicePrefix = parts[0] + return + } + } + } + for c := n.FirstChild; c != nil; c = c.NextSibling { + findPrefix(c) + } + } + findPrefix(doc) + } + + return serviceName, servicePrefix +} + +// extractActions extracts all IAM actions from the actions table +func extractActions(doc *html.Node, progress ProgressReporter) []IAMAction { + var actions []IAMAction + + // Find all tables and look for the actions table + tables := findElements(doc, "table") + + for _, table := range tables { + // Check if this is the actions table by looking at headers + headers := extractTableHeaders(table) + if !isActionsTable(headers) { + continue + } + + rows := findElements(table, "tr") + total := len(rows) - 1 // Exclude header row + 
// isActionsTable reports whether headers belong to the actions table: some
// header must mention "action", some header "description", and some header
// either "access" or "level". Matching is case-insensitive.
func isActionsTable(headers []string) bool {
	var sawAction, sawDescription, sawAccessLevel bool

	for _, header := range headers {
		lower := strings.ToLower(header)
		sawAction = sawAction || strings.Contains(lower, "action")
		sawDescription = sawDescription || strings.Contains(lower, "description")
		sawAccessLevel = sawAccessLevel ||
			strings.Contains(lower, "access") ||
			strings.Contains(lower, "level")
	}

	if !sawAction || !sawDescription {
		return false
	}
	return sawAccessLevel
}
// isConditionKeysTable reports whether headers belong to the condition keys
// table: one header must contain both "condition" and "key", and some header
// must contain "description". Matching is case-insensitive.
func isConditionKeysTable(headers []string) bool {
	conditionSeen, descriptionSeen := false, false

	for i := range headers {
		lower := strings.ToLower(headers[i])

		// Both words have to appear in the same header cell.
		if strings.Contains(lower, "condition") && strings.Contains(lower, "key") {
			conditionSeen = true
		}
		if strings.Contains(lower, "description") {
			descriptionSeen = true
		}
	}

	return conditionSeen && descriptionSeen
}
// whitespaceRun matches one or more consecutive whitespace characters. It is
// compiled once at package scope because cleanText runs for every table cell.
var whitespaceRun = regexp.MustCompile(`\s+`)

// cleanText trims leading and trailing whitespace from s and collapses each
// internal run of whitespace (spaces, tabs, newlines) into a single space.
func cleanText(s string) string {
	return whitespaceRun.ReplaceAllString(strings.TrimSpace(s), " ")
}
b/internal/scraper_test.go @@ -0,0 +1,1306 @@ +package internal + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "golang.org/x/net/html" +) + +func TestGetCacheDir(t *testing.T) { + dir := getCacheDir() + if dir == "" { + t.Skip("Could not determine user home directory") + } + if !strings.Contains(dir, ".cache") || !strings.Contains(dir, "politest") { + t.Errorf("getCacheDir() = %v, expected path containing .cache/politest", dir) + } +} + +func TestGetCacheKey(t *testing.T) { + tests := []struct { + name string + url string + }{ + { + name: "basic url", + url: "https://docs.aws.amazon.com/service-authorization/latest/reference/list_amazonbedrock.html", + }, + { + name: "different url", + url: "https://docs.aws.amazon.com/service-authorization/latest/reference/list_amazons3.html", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + key := getCacheKey(tt.url) + if key == "" { + t.Error("getCacheKey() returned empty string") + } + if !strings.HasSuffix(key, ".json") { + t.Errorf("getCacheKey() = %v, expected .json suffix", key) + } + // Should be deterministic + key2 := getCacheKey(tt.url) + if key != key2 { + t.Errorf("getCacheKey() not deterministic: %v != %v", key, key2) + } + }) + } + + // Different URLs should produce different keys + key1 := getCacheKey("https://example.com/a") + key2 := getCacheKey("https://example.com/b") + if key1 == key2 { + t.Error("getCacheKey() produced same key for different URLs") + } +} + +func TestLoadFromCacheAndSaveToCache(t *testing.T) { + // Test the cache key generation is consistent + testURL := "https://test.example.com/test" + key1 := getCacheKey(testURL) + key2 := getCacheKey(testURL) + if key1 != key2 { + t.Errorf("getCacheKey() not consistent: %s != %s", key1, key2) + } + + // Test that different URLs produce different keys + key3 := getCacheKey("https://different.example.com") + if key1 == key3 { + t.Error("getCacheKey() should produce different keys 
for different URLs") + } +} + +func TestSaveAndLoadCache(t *testing.T) { + // Create a temporary cache directory + tempDir := t.TempDir() + testURL := "https://test.example.com/cache-test" + testContent := "Test content" + + // Manually create cache entry to test loading + cacheKey := getCacheKey(testURL) + cachePath := filepath.Join(tempDir, cacheKey) + + entry := cacheEntry{ + URL: testURL, + Content: testContent, + CachedAt: time.Now(), + ExpiresAt: time.Now().Add(24 * time.Hour), + } + + data, err := json.Marshal(entry) + if err != nil { + t.Fatalf("Failed to marshal cache entry: %v", err) + } + + if err := os.WriteFile(cachePath, data, 0600); err != nil { + t.Fatalf("Failed to write cache file: %v", err) + } + + // Verify file exists + if _, err := os.Stat(cachePath); os.IsNotExist(err) { + t.Fatal("Cache file was not created") + } +} + +func TestLoadFromCacheExpired(t *testing.T) { + // Test that expired cache entries are not loaded + tempDir := t.TempDir() + testURL := "https://test.example.com/expired" + + cacheKey := getCacheKey(testURL) + cachePath := filepath.Join(tempDir, cacheKey) + + // Create an expired cache entry + entry := cacheEntry{ + URL: testURL, + Content: "expired content", + CachedAt: time.Now().Add(-48 * time.Hour), + ExpiresAt: time.Now().Add(-24 * time.Hour), // Expired 24 hours ago + } + + data, _ := json.Marshal(entry) + os.WriteFile(cachePath, data, 0600) + + // loadFromCache should return false for expired entries + // Note: This tests the logic but not the actual function since getCacheDir() uses home dir +} + +func TestLoadFromCacheInvalidJSON(t *testing.T) { + // Test that invalid JSON in cache is handled gracefully + tempDir := t.TempDir() + testURL := "https://test.example.com/invalid" + + cacheKey := getCacheKey(testURL) + cachePath := filepath.Join(tempDir, cacheKey) + + // Write invalid JSON + os.WriteFile(cachePath, []byte("not valid json"), 0600) + + // Verify file exists + if _, err := os.Stat(cachePath); 
os.IsNotExist(err) { + t.Fatal("Cache file was not created") + } +} + +func TestLoadFromCacheMissingFile(t *testing.T) { + // Test that missing cache file returns false + content, ok := loadFromCache("https://nonexistent.example.com/missing") + if ok { + t.Error("loadFromCache() should return false for missing file") + } + if content != "" { + t.Errorf("loadFromCache() should return empty content for missing file, got: %s", content) + } +} + +func TestSaveToCacheCreatesDirectory(t *testing.T) { + // saveToCache should create the cache directory if it doesn't exist + // This is a basic test - the actual directory creation depends on getCacheDir() + saveToCache("https://test.example.com/save-test", "test content") + // Should not panic +} + +func TestCleanText(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "basic whitespace", + input: " hello world ", + want: "hello world", + }, + { + name: "newlines", + input: "hello\n\nworld", + want: "hello world", + }, + { + name: "tabs", + input: "hello\t\tworld", + want: "hello world", + }, + { + name: "mixed whitespace", + input: " hello \n\t world ", + want: "hello world", + }, + { + name: "already clean", + input: "hello world", + want: "hello world", + }, + { + name: "empty string", + input: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := cleanText(tt.input) + if result != tt.want { + t.Errorf("cleanText(%q) = %q, want %q", tt.input, result, tt.want) + } + }) + } +} + +func TestIsActionsTable(t *testing.T) { + tests := []struct { + name string + headers []string + want bool + }{ + { + name: "valid actions table", + headers: []string{"Actions", "Description", "Access level", "Resource types"}, + want: true, + }, + { + name: "lowercase headers", + headers: []string{"actions", "description", "access level"}, + want: true, + }, + { + name: "mixed case", + headers: []string{"ACTIONS", "Description", "ACCESS LEVEL"}, + 
want: true, + }, + { + name: "missing action", + headers: []string{"Description", "Access level"}, + want: false, + }, + { + name: "missing description", + headers: []string{"Actions", "Access level"}, + want: false, + }, + { + name: "missing access level", + headers: []string{"Actions", "Description"}, + want: false, + }, + { + name: "resource types table (not actions)", + headers: []string{"Resource types", "ARN", "Condition keys"}, + want: false, + }, + { + name: "empty headers", + headers: []string{}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isActionsTable(tt.headers) + if result != tt.want { + t.Errorf("isActionsTable(%v) = %v, want %v", tt.headers, result, tt.want) + } + }) + } +} + +func TestIsConditionKeysTable(t *testing.T) { + tests := []struct { + name string + headers []string + want bool + }{ + { + name: "valid condition keys table", + headers: []string{"Condition keys", "Description", "Type"}, + want: true, + }, + { + name: "missing condition key", + headers: []string{"Description", "Type"}, + want: false, + }, + { + name: "missing description", + headers: []string{"Condition keys", "Type"}, + want: false, + }, + { + name: "empty headers", + headers: []string{}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isConditionKeysTable(tt.headers) + if result != tt.want { + t.Errorf("isConditionKeysTable(%v) = %v, want %v", tt.headers, result, tt.want) + } + }) + } +} + +func TestIsResourceTypesTable(t *testing.T) { + tests := []struct { + name string + headers []string + want bool + }{ + { + name: "valid resource types table", + headers: []string{"Resource types", "ARN", "Condition keys"}, + want: true, + }, + { + name: "missing resource type", + headers: []string{"ARN", "Condition keys"}, + want: false, + }, + { + name: "missing ARN", + headers: []string{"Resource types", "Condition keys"}, + want: false, + }, + { + name: "empty 
headers", + headers: []string{}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isResourceTypesTable(tt.headers) + if result != tt.want { + t.Errorf("isResourceTypesTable(%v) = %v, want %v", tt.headers, result, tt.want) + } + }) + } +} + +func TestGetTextContent(t *testing.T) { + tests := []struct { + name string + html string + want string + }{ + { + name: "simple text", + html: "
Hello
", + want: "Hello", + }, + { + name: "nested elements", + html: "
Hello World
", + want: "Hello World", + }, + { + name: "empty element", + html: "
", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + doc, err := html.Parse(strings.NewReader(tt.html)) + if err != nil { + t.Fatalf("Failed to parse HTML: %v", err) + } + result := getTextContent(doc) + // Clean result for comparison (since html.Parse wraps in html/body) + result = cleanText(result) + if result != tt.want { + t.Errorf("getTextContent() = %q, want %q", result, tt.want) + } + }) + } + + // Test nil node + if result := getTextContent(nil); result != "" { + t.Errorf("getTextContent(nil) = %q, want empty string", result) + } +} + +func TestFindElements(t *testing.T) { + htmlContent := ` + + + + + + +
Header
Cell 1
Cell 2
+ + +
Another
+ + + ` + + doc, err := html.Parse(strings.NewReader(htmlContent)) + if err != nil { + t.Fatalf("Failed to parse HTML: %v", err) + } + + tables := findElements(doc, "table") + if len(tables) != 2 { + t.Errorf("findElements(table) = %d elements, want 2", len(tables)) + } + + trs := findElements(doc, "tr") + if len(trs) != 4 { + t.Errorf("findElements(tr) = %d elements, want 4", len(trs)) + } + + tds := findElements(doc, "td") + if len(tds) != 3 { + t.Errorf("findElements(td) = %d elements, want 3", len(tds)) + } +} + +func TestExtractTableHeaders(t *testing.T) { + htmlContent := ` + + + + + + + + + + + +
ActionsDescriptionAccess level
GetObjectGets an objectRead
+ ` + + doc, err := html.Parse(strings.NewReader(htmlContent)) + if err != nil { + t.Fatalf("Failed to parse HTML: %v", err) + } + + tables := findElements(doc, "table") + if len(tables) == 0 { + t.Fatal("No tables found") + } + + headers := extractTableHeaders(tables[0]) + if len(headers) != 3 { + t.Errorf("extractTableHeaders() = %d headers, want 3", len(headers)) + } + + expectedHeaders := []string{"Actions", "Description", "Access level"} + for i, want := range expectedHeaders { + if i >= len(headers) { + t.Errorf("Missing header at index %d", i) + continue + } + if headers[i] != want { + t.Errorf("header[%d] = %q, want %q", i, headers[i], want) + } + } +} + +func TestParseMultiValueCell(t *testing.T) { + tests := []struct { + name string + html string + wantLen int + wantVals []string + }{ + { + name: "single value", + html: "
value1
", + wantLen: 1, + wantVals: []string{"value1"}, + }, + { + name: "comma separated", + html: "
value1, value2, value3
", + wantLen: 3, + wantVals: []string{"value1", "value2", "value3"}, + }, + { + name: "newline separated", + html: "
value1\nvalue2\nvalue3
", + wantLen: 3, + wantVals: []string{"value1", "value2", "value3"}, + }, + { + name: "with asterisks (required markers)", + html: "
value1*\nvalue2
", + wantLen: 2, + wantVals: []string{"value1", "value2"}, + }, + { + name: "empty cell", + html: "
", + wantLen: 0, + wantVals: []string{}, + }, + { + name: "whitespace only", + html: "
\n
", + wantLen: 0, + wantVals: []string{}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + doc, err := html.Parse(strings.NewReader(tt.html)) + if err != nil { + t.Fatalf("Failed to parse HTML: %v", err) + } + + tds := findElements(doc, "td") + if len(tds) == 0 { + t.Fatal("No td elements found") + } + + result := parseMultiValueCell(tds[0]) + if len(result) != tt.wantLen { + t.Errorf("parseMultiValueCell() = %d values, want %d (got: %v)", len(result), tt.wantLen, result) + } + + for i, want := range tt.wantVals { + if i >= len(result) { + break + } + if result[i] != want { + t.Errorf("value[%d] = %q, want %q", i, result[i], want) + } + } + }) + } +} + +func TestParseActionRow(t *testing.T) { + tests := []struct { + name string + html string + wantName string + wantDesc string + wantLevel string + }{ + { + name: "full row", + html: ` + + + + + + + + +
GetObjectRetrieves objects from Amazon S3Readobject*s3:authType
`, + wantName: "GetObject", + wantDesc: "Retrieves objects from Amazon S3", + wantLevel: "Read", + }, + { + name: "minimal row", + html: ` + + + + + +
PutObjectPuts objectsWrite
`, + wantName: "PutObject", + wantDesc: "Puts objects", + wantLevel: "Write", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + doc, err := html.Parse(strings.NewReader(tt.html)) + if err != nil { + t.Fatalf("Failed to parse HTML: %v", err) + } + + cells := findElements(doc, "td") + action := parseActionRow(cells) + + if action.Name != tt.wantName { + t.Errorf("Name = %q, want %q", action.Name, tt.wantName) + } + if action.Description != tt.wantDesc { + t.Errorf("Description = %q, want %q", action.Description, tt.wantDesc) + } + if action.AccessLevel != tt.wantLevel { + t.Errorf("AccessLevel = %q, want %q", action.AccessLevel, tt.wantLevel) + } + }) + } +} + +func TestExtractServiceInfo(t *testing.T) { + tests := []struct { + name string + html string + wantName string + wantPrefix string + }{ + { + name: "standard format with code tag", + html: ` + + Actions, resources, and condition keys for Amazon S3 - Service Authorization Reference + +

Amazon S3 (service prefix: s3)

+ + + `, + wantName: "Amazon S3", + wantPrefix: "s3", + }, + { + name: "prefix in parentheses", + html: ` + + +

Actions, resources, and condition keys for Amazon Bedrock

+

(service prefix: bedrock)

+ + + `, + wantName: "Amazon Bedrock", + wantPrefix: "bedrock", + }, + { + name: "extract from action table", + html: ` + + + + + +
ActionsDescriptionAccess level
ec2:DescribeInstancesDescribes instancesList
+ + + `, + wantName: "", + wantPrefix: "ec2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + doc, err := html.Parse(strings.NewReader(tt.html)) + if err != nil { + t.Fatalf("Failed to parse HTML: %v", err) + } + + name, prefix := extractServiceInfo(doc) + if prefix != tt.wantPrefix { + t.Errorf("prefix = %q, want %q", prefix, tt.wantPrefix) + } + if tt.wantName != "" && !strings.Contains(name, tt.wantName) { + t.Errorf("name = %q, want to contain %q", name, tt.wantName) + } + }) + } +} + +func TestExtractActions(t *testing.T) { + htmlContent := ` + + + + + + + + + + + + + + + + + + + + + +
ActionsDescriptionAccess levelResource types
GetObjectGets an objectReadobject*
PutObjectPuts an objectWriteobject*
+ + + ` + + doc, err := html.Parse(strings.NewReader(htmlContent)) + if err != nil { + t.Fatalf("Failed to parse HTML: %v", err) + } + + actions := extractActions(doc, nil) + + if len(actions) != 2 { + t.Errorf("extractActions() = %d actions, want 2", len(actions)) + } + + if len(actions) > 0 { + if actions[0].Name != "GetObject" { + t.Errorf("actions[0].Name = %q, want GetObject", actions[0].Name) + } + if actions[0].AccessLevel != "Read" { + t.Errorf("actions[0].AccessLevel = %q, want Read", actions[0].AccessLevel) + } + } + + if len(actions) > 1 { + if actions[1].Name != "PutObject" { + t.Errorf("actions[1].Name = %q, want PutObject", actions[1].Name) + } + if actions[1].AccessLevel != "Write" { + t.Errorf("actions[1].AccessLevel = %q, want Write", actions[1].AccessLevel) + } + } +} + +func TestExtractConditionKeys(t *testing.T) { + htmlContent := ` + + + + + + + + + + + + + + + + + + +
Condition keysDescriptionType
s3:authTypeFilters access by authentication methodString
s3:prefixFilters access by key name prefixString
+ + + ` + + doc, err := html.Parse(strings.NewReader(htmlContent)) + if err != nil { + t.Fatalf("Failed to parse HTML: %v", err) + } + + keys := extractConditionKeys(doc) + + if len(keys) != 2 { + t.Errorf("extractConditionKeys() = %d keys, want 2", len(keys)) + } + + if len(keys) > 0 { + if keys[0].Name != "s3:authType" { + t.Errorf("keys[0].Name = %q, want s3:authType", keys[0].Name) + } + if keys[0].Type != "String" { + t.Errorf("keys[0].Type = %q, want String", keys[0].Type) + } + } +} + +func TestExtractResourceTypes(t *testing.T) { + htmlContent := ` + + + + + + + + + + + + + + + + + + +
Resource typesARNCondition keys
bucketarn:aws:s3:::${BucketName}s3:authType
objectarn:aws:s3:::${BucketName}/${ObjectName}s3:authType, s3:prefix
+ + + ` + + doc, err := html.Parse(strings.NewReader(htmlContent)) + if err != nil { + t.Fatalf("Failed to parse HTML: %v", err) + } + + resources := extractResourceTypes(doc) + + if len(resources) != 2 { + t.Errorf("extractResourceTypes() = %d resources, want 2", len(resources)) + } + + if len(resources) > 0 { + if resources[0].Name != "bucket" { + t.Errorf("resources[0].Name = %q, want bucket", resources[0].Name) + } + if !strings.Contains(resources[0].ARN, "s3:::") { + t.Errorf("resources[0].ARN = %q, want to contain s3:::", resources[0].ARN) + } + } +} + +func TestParseIAMDocumentation(t *testing.T) { + htmlContent := ` + + Actions, resources, and condition keys for Amazon S3 - Service Authorization Reference + +

Actions, resources, and condition keys for Amazon S3

+

Amazon S3 (service prefix: s3) provides the following service-specific resources.

+ +

Actions defined by Amazon S3

+ + + + + + + + + + + + + +
ActionsDescriptionAccess levelResource types
GetObjectGrants permission to retrieve objects from Amazon S3Readobject*
+ +

Condition keys for Amazon S3

+ + + + + + + + + + + +
Condition keysDescriptionType
s3:authTypeFilters access by authentication methodString
+ +

Resource types defined by Amazon S3

+ + + + + + + + + + + +
Resource typesARNCondition keys
objectarn:aws:s3:::${BucketName}/${ObjectName}
+ + + ` + + data, err := parseIAMDocumentation(htmlContent, "https://test.example.com", nil) + if err != nil { + t.Fatalf("parseIAMDocumentation() error: %v", err) + } + + if data.ServicePrefix != "s3" { + t.Errorf("ServicePrefix = %q, want s3", data.ServicePrefix) + } + if len(data.Actions) != 1 { + t.Errorf("Actions = %d, want 1", len(data.Actions)) + } + if len(data.ConditionKeys) != 1 { + t.Errorf("ConditionKeys = %d, want 1", len(data.ConditionKeys)) + } + if len(data.ResourceTypes) != 1 { + t.Errorf("ResourceTypes = %d, want 1", len(data.ResourceTypes)) + } + if data.SourceURL != "https://test.example.com" { + t.Errorf("SourceURL = %q, want https://test.example.com", data.SourceURL) + } +} + +func TestParseIAMDocumentationNoActions(t *testing.T) { + htmlContent := ` + + +

Amazon S3 (service prefix: s3)

+ + +
Not an actions table
+ + + ` + + _, err := parseIAMDocumentation(htmlContent, "https://test.example.com", nil) + if err == nil { + t.Error("parseIAMDocumentation() expected error for no actions, got nil") + } + if err != nil && !strings.Contains(err.Error(), "no IAM actions found") { + t.Errorf("parseIAMDocumentation() error = %v, want error about no actions", err) + } +} + +func TestParseIAMDocumentationNoPrefix(t *testing.T) { + htmlContent := ` + + +

No service prefix here

+ + + ` + + _, err := parseIAMDocumentation(htmlContent, "https://test.example.com", nil) + if err == nil { + t.Error("parseIAMDocumentation() expected error for no prefix, got nil") + } +} + +func TestScrapeIAMDocumentationInvalidURL(t *testing.T) { + _, err := ScrapeIAMDocumentation("https://example.com/not-aws", nil) + if err == nil { + t.Error("ScrapeIAMDocumentation() expected error for invalid URL, got nil") + } + if err != nil && !strings.Contains(err.Error(), "invalid URL") { + t.Errorf("ScrapeIAMDocumentation() error = %v, want error about invalid URL", err) + } +} + +func TestLoadFromCacheValidEntry(t *testing.T) { + // Create a valid cache entry in the actual cache directory + cacheDir := getCacheDir() + if cacheDir == "" { + t.Skip("Could not determine cache directory") + } + + testURL := "https://test.politest.example.com/cache-valid-test" + cacheKey := getCacheKey(testURL) + cachePath := filepath.Join(cacheDir, cacheKey) + + // Ensure cache directory exists + if err := os.MkdirAll(cacheDir, 0750); err != nil { + t.Fatalf("Failed to create cache directory: %v", err) + } + + // Create a valid cache entry + entry := cacheEntry{ + URL: testURL, + Content: "Test cached content", + CachedAt: time.Now(), + ExpiresAt: time.Now().Add(24 * time.Hour), + } + + data, err := json.Marshal(entry) + if err != nil { + t.Fatalf("Failed to marshal cache entry: %v", err) + } + + if err := os.WriteFile(cachePath, data, 0600); err != nil { + t.Fatalf("Failed to write cache file: %v", err) + } + defer os.Remove(cachePath) + + // Test loading from cache + content, ok := loadFromCache(testURL) + if !ok { + t.Error("loadFromCache() should return true for valid cache entry") + } + if content != entry.Content { + t.Errorf("loadFromCache() content = %q, want %q", content, entry.Content) + } +} + +func TestLoadFromCacheExpiredEntry(t *testing.T) { + cacheDir := getCacheDir() + if cacheDir == "" { + t.Skip("Could not determine cache directory") + } + + testURL := 
"https://test.politest.example.com/cache-expired-test" + cacheKey := getCacheKey(testURL) + cachePath := filepath.Join(cacheDir, cacheKey) + + // Ensure cache directory exists + if err := os.MkdirAll(cacheDir, 0750); err != nil { + t.Fatalf("Failed to create cache directory: %v", err) + } + + // Create an expired cache entry + entry := cacheEntry{ + URL: testURL, + Content: "expired content", + CachedAt: time.Now().Add(-48 * time.Hour), + ExpiresAt: time.Now().Add(-24 * time.Hour), // Expired + } + + data, _ := json.Marshal(entry) + os.WriteFile(cachePath, data, 0600) + defer os.Remove(cachePath) + + // Test loading expired cache + content, ok := loadFromCache(testURL) + if ok { + t.Error("loadFromCache() should return false for expired cache entry") + } + if content != "" { + t.Errorf("loadFromCache() should return empty content for expired entry, got: %s", content) + } +} + +func TestLoadFromCacheInvalidJSONEntry(t *testing.T) { + cacheDir := getCacheDir() + if cacheDir == "" { + t.Skip("Could not determine cache directory") + } + + testURL := "https://test.politest.example.com/cache-invalid-json-test" + cacheKey := getCacheKey(testURL) + cachePath := filepath.Join(cacheDir, cacheKey) + + // Ensure cache directory exists + if err := os.MkdirAll(cacheDir, 0750); err != nil { + t.Fatalf("Failed to create cache directory: %v", err) + } + + // Write invalid JSON + os.WriteFile(cachePath, []byte("not valid json {{{"), 0600) + defer os.Remove(cachePath) + + // Test loading invalid JSON + content, ok := loadFromCache(testURL) + if ok { + t.Error("loadFromCache() should return false for invalid JSON") + } + if content != "" { + t.Errorf("loadFromCache() should return empty content for invalid JSON, got: %s", content) + } +} + +func TestSaveToCacheAndLoad(t *testing.T) { + testURL := "https://test.politest.example.com/save-load-test" + testContent := "Test save and load" + + // Save to cache + saveToCache(testURL, testContent) + + // Clean up after test + cacheDir := 
getCacheDir() + if cacheDir != "" { + cachePath := filepath.Join(cacheDir, getCacheKey(testURL)) + defer os.Remove(cachePath) + } + + // Load from cache + content, ok := loadFromCache(testURL) + if !ok { + t.Error("loadFromCache() should return true after saveToCache()") + } + if content != testContent { + t.Errorf("loadFromCache() content = %q, want %q", content, testContent) + } +} + +func TestExtractServiceInfoFromTitle(t *testing.T) { + htmlContent := ` + + Actions, resources, and condition keys for AWS Lambda - Service Authorization Reference + +

Actions, resources, and condition keys for AWS Lambda

+

AWS Lambda (service prefix: lambda)

+ + + ` + + doc, err := html.Parse(strings.NewReader(htmlContent)) + if err != nil { + t.Fatalf("Failed to parse HTML: %v", err) + } + + name, prefix := extractServiceInfo(doc) + if prefix != "lambda" { + t.Errorf("prefix = %q, want lambda", prefix) + } + if !strings.Contains(name, "Lambda") { + t.Errorf("name = %q, want to contain Lambda", name) + } +} + +func TestParseIAMDocumentationWithProgress(t *testing.T) { + htmlContent := ` + + +

Test Service (service prefix: test)

+ + + +
ActionsDescriptionAccess level
TestActionTest descriptionRead
+ + + ` + + progress := &testProgress{} + data, err := parseIAMDocumentation(htmlContent, "https://test.example.com", progress) + if err != nil { + t.Fatalf("parseIAMDocumentation() error: %v", err) + } + + if data.ServicePrefix != "test" { + t.Errorf("ServicePrefix = %q, want test", data.ServicePrefix) + } + if progress.statusCalls == 0 { + t.Error("parseIAMDocumentation() should call progress.SetStatus()") + } +} + +type testProgress struct { + statusCalls int +} + +func (p *testProgress) SetStatus(status string) { + p.statusCalls++ +} + +func (p *testProgress) SetProgress(current, total int) {} + +func TestExtractActionsWithDependentActions(t *testing.T) { + htmlContent := ` + + + + + + + + + + + + + + + + + + + +
ActionsDescriptionAccess levelResource typesCondition keysDependent actions
GetObjectGets an objectReadobject*s3:authTypes3:ListBucket
+ + + ` + + doc, err := html.Parse(strings.NewReader(htmlContent)) + if err != nil { + t.Fatalf("Failed to parse HTML: %v", err) + } + + actions := extractActions(doc, nil) + + if len(actions) != 1 { + t.Fatalf("extractActions() = %d actions, want 1", len(actions)) + } + + if actions[0].Name != "GetObject" { + t.Errorf("actions[0].Name = %q, want GetObject", actions[0].Name) + } +} + +func TestExtractActionsEmptyTable(t *testing.T) { + htmlContent := ` + + + + +
ActionsDescriptionAccess level
+ + + ` + + doc, err := html.Parse(strings.NewReader(htmlContent)) + if err != nil { + t.Fatalf("Failed to parse HTML: %v", err) + } + + actions := extractActions(doc, nil) + if len(actions) != 0 { + t.Errorf("extractActions() = %d actions, want 0 for empty table", len(actions)) + } +} + +func TestParseActionRowMinimalCells(t *testing.T) { + // Test with less than 3 cells + htmlContent := `
OnlyOne
` + + doc, err := html.Parse(strings.NewReader(htmlContent)) + if err != nil { + t.Fatalf("Failed to parse HTML: %v", err) + } + + cells := findElements(doc, "td") + action := parseActionRow(cells) + + // Should handle gracefully + if action.Name != "OnlyOne" { + t.Errorf("Name = %q, want OnlyOne", action.Name) + } +} + +func TestExtractConditionKeysEmptyTable(t *testing.T) { + htmlContent := ` + + + + +
Condition keysDescriptionType
+ + + ` + + doc, err := html.Parse(strings.NewReader(htmlContent)) + if err != nil { + t.Fatalf("Failed to parse HTML: %v", err) + } + + keys := extractConditionKeys(doc) + if len(keys) != 0 { + t.Errorf("extractConditionKeys() = %d keys, want 0 for empty table", len(keys)) + } +} + +func TestExtractResourceTypesEmptyTable(t *testing.T) { + htmlContent := ` + + + + +
Resource typesARNCondition keys
+ + + ` + + doc, err := html.Parse(strings.NewReader(htmlContent)) + if err != nil { + t.Fatalf("Failed to parse HTML: %v", err) + } + + resources := extractResourceTypes(doc) + if len(resources) != 0 { + t.Errorf("extractResourceTypes() = %d resources, want 0 for empty table", len(resources)) + } +} + +func TestFindElementsNoMatches(t *testing.T) { + htmlContent := `

No tables here

` + + doc, err := html.Parse(strings.NewReader(htmlContent)) + if err != nil { + t.Fatalf("Failed to parse HTML: %v", err) + } + + elements := findElements(doc, "table") + if len(elements) != 0 { + t.Errorf("findElements() = %d elements, want 0", len(elements)) + } +} + +func TestExtractTableHeadersNoTH(t *testing.T) { + htmlContent := `
Not a header
` + + doc, err := html.Parse(strings.NewReader(htmlContent)) + if err != nil { + t.Fatalf("Failed to parse HTML: %v", err) + } + + tables := findElements(doc, "table") + if len(tables) == 0 { + t.Fatal("No tables found") + } + + headers := extractTableHeaders(tables[0]) + if len(headers) != 0 { + t.Errorf("extractTableHeaders() = %d headers, want 0 for table without th", len(headers)) + } +} diff --git a/main.go b/main.go index 3d5ddec..cff129e 100644 --- a/main.go +++ b/main.go @@ -1,4 +1,5 @@ // cmd: go run . --scenario scenarios/athena_primary.yml --save /tmp/resp.json +// cmd: go run . generate --url https://docs.aws.amazon.com/service-authorization/latest/reference/list_amazonbedrock.html --base-url http://localhost:3000 --model gpt-4 --api-key xxx package main import ( @@ -349,9 +350,127 @@ func validateArgs(args []string) error { return nil } +// generateFlags holds the parsed flags for the generate command +type generateFlags struct { + url string + baseURL string + apiKey string + model string + outputDir string + noEnrich bool + quiet bool + prompt string + concurrency int + generateSCP bool +} + +// parseGenerateFlags parses command-line arguments for the generate command +func parseGenerateFlags(args []string) (*generateFlags, error) { + fs := flag.NewFlagSet("generate", flag.ContinueOnError) + + flags := &generateFlags{} + + fs.StringVar(&flags.url, "url", "", "AWS IAM documentation URL (required)") + fs.StringVar(&flags.baseURL, "base-url", "", "OpenAI-compatible API base URL (required)") + fs.StringVar(&flags.apiKey, "api-key", "", "API key for LLM service") + fs.StringVar(&flags.model, "model", "", "LLM model name (required)") + fs.StringVar(&flags.outputDir, "output", ".", "Output directory for generated files") + fs.BoolVar(&flags.noEnrich, "no-enrich", false, "Skip action description enrichment") + fs.BoolVar(&flags.quiet, "quiet", false, "Suppress progress output") + fs.StringVar(&flags.prompt, "prompt", "", "Custom requirements/constraints to 
include in LLM prompt") + fs.IntVar(&flags.concurrency, "concurrency", 3, "Number of parallel batch requests") + fs.BoolVar(&flags.generateSCP, "scp", false, "Generate companion SCP for org-wide guardrails") + + fs.Usage = func() { + fmt.Fprintf(os.Stderr, "Usage: politest generate [options]\n\n") + fmt.Fprintf(os.Stderr, "Generate security-focused IAM policies from AWS documentation.\n\n") + fmt.Fprintf(os.Stderr, "This command scrapes AWS IAM documentation pages to extract action definitions,\n") + fmt.Fprintf(os.Stderr, "condition keys, and resource types, then uses an LLM to generate a comprehensive\n") + fmt.Fprintf(os.Stderr, "security-focused policy suitable for regulated environments.\n\n") + fmt.Fprintf(os.Stderr, "Options:\n") + fs.PrintDefaults() + fmt.Fprintf(os.Stderr, "\nExample:\n") + fmt.Fprintf(os.Stderr, " politest generate \\\n") + fmt.Fprintf(os.Stderr, " --url https://docs.aws.amazon.com/service-authorization/latest/reference/list_amazonbedrock.html \\\n") + fmt.Fprintf(os.Stderr, " --base-url http://localhost:3000 \\\n") + fmt.Fprintf(os.Stderr, " --model gpt-4 \\\n") + fmt.Fprintf(os.Stderr, " --api-key $OPENAI_API_KEY\n") + } + + if err := fs.Parse(args); err != nil { + return nil, err + } + + return flags, nil +} + +// runGenerate runs the generate command +func runGenerate(flags *generateFlags) error { + cfg := internal.GenerateConfig{ + URL: flags.url, + BaseURL: flags.baseURL, + APIKey: flags.apiKey, + Model: flags.model, + OutputDir: flags.outputDir, + NoEnrich: flags.noEnrich, + Quiet: flags.quiet, + UserPrompt: flags.prompt, + Concurrency: flags.concurrency, + GenerateSCP: flags.generateSCP, + } + + if err := internal.ValidateGenerateConfig(cfg); err != nil { + return err + } + + _, err := internal.RunGenerate(cfg, os.Stdout) + return err +} + +// printUsage prints the main usage information +func printUsage() { + fmt.Fprintf(os.Stderr, "Usage: politest [options]\n\n") + fmt.Fprintf(os.Stderr, "Commands:\n") + fmt.Fprintf(os.Stderr, " 
(default) Run IAM policy simulations from scenario files\n") + fmt.Fprintf(os.Stderr, " generate Generate security-focused IAM policies from AWS documentation\n") + fmt.Fprintf(os.Stderr, "\n") + fmt.Fprintf(os.Stderr, "Run 'politest -h' for more information on a command.\n\n") + fmt.Fprintf(os.Stderr, "Simulation Options (default command):\n") + fmt.Fprintf(os.Stderr, " -scenario string Path to scenario YAML\n") + fmt.Fprintf(os.Stderr, " -save string Path to save raw JSON response\n") + fmt.Fprintf(os.Stderr, " -no-assert Do not fail on expectation mismatches\n") + fmt.Fprintf(os.Stderr, " -no-warn Suppress SCP/RCP simulation approximation warning\n") + fmt.Fprintf(os.Stderr, " -debug Show debug output\n") + fmt.Fprintf(os.Stderr, " -version Show version information\n") +} + // realMain contains the full main logic and returns an exit code // This allows testing without calling os.Exit func realMain(args []string) int { + // Check for subcommands + if len(args) > 0 { + switch args[0] { + case "generate": + genFlags, err := parseGenerateFlags(args[1:]) + if err != nil { + if err == flag.ErrHelp { + return 0 + } + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + return 1 + } + if err := runGenerate(genFlags); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + return 1 + } + return 0 + case "help", "-h", "--help": + printUsage() + return 0 + } + } + + // Default: run simulation command flags, remainingArgs, err := parseFlags(args) if err != nil { if err == flag.ErrHelp {