diff --git a/.github/workflows/package-ci.yml b/.github/workflows/package-ci.yml index 7dfea08..15e3ece 100644 --- a/.github/workflows/package-ci.yml +++ b/.github/workflows/package-ci.yml @@ -8,6 +8,101 @@ on: workflow_dispatch: jobs: + lint: + name: Lint + runs-on: ubuntu-24.04 + permissions: + contents: read + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: go.mod + + - name: Install BPF build dependencies + run: | + sudo apt-get update + sudo apt-get install -y clang llvm libbpf-dev build-essential pkg-config zlib1g-dev + has_working_bpftool() { + local output + [[ -x "${1:-}" ]] || return 1 + output="$("${1}" version 2>/dev/null)" || return 1 + [[ "${output}" == *"libbpf"* ]] + } + first_working_bpftool() { + for candidate in "$@"; do + if has_working_bpftool "${candidate}"; then + echo "${candidate}" + return 0 + fi + done + return 1 + } + BPFTOOL_CMD="" + if candidate="$(command -v bpftool 2>/dev/null)" && has_working_bpftool "${candidate}"; then + BPFTOOL_CMD="${candidate}" + fi + for pkg in \ + bpftool \ + "linux-tools-$(uname -r)" \ + "linux-cloud-tools-$(uname -r)" \ + linux-tools-generic \ + linux-cloud-tools-generic \ + linux-tools-azure \ + linux-cloud-tools-azure \ + linux-tools-common + do + if [[ -n "${BPFTOOL_CMD}" ]]; then + break + fi + sudo apt-get install -y "${pkg}" || true + mapfile -t BPFTOOL_CANDIDATES < <(find /usr/bin /usr/sbin /usr/local/bin /usr/local/sbin /usr/lib -type f -name 'bpftool*' 2>/dev/null | sort -u) + if candidate="$(first_working_bpftool "${BPFTOOL_CANDIDATES[@]}")"; then + BPFTOOL_CMD="${candidate}" + fi + done + if [[ -z "${BPFTOOL_CMD}" ]]; then + case "$(uname -m)" in + x86_64|amd64) BPFTOOL_ARCH="amd64" ;; + aarch64|arm64) BPFTOOL_ARCH="arm64" ;; + *) BPFTOOL_ARCH="" ;; + esac + if [[ -n "${BPFTOOL_ARCH}" ]]; then + BPFTOOL_VERSION="v7.6.0" + BPFTOOL_URL="https://github.com/libbpf/bpftool/releases/download/${BPFTOOL_VERSION}/bpftool-${BPFTOOL_VERSION}-${BPFTOOL_ARCH}.tar.gz" + tmpdir="$(mktemp -d)" + if curl -fsSL "${BPFTOOL_URL}" -o "${tmpdir}/bpftool.tgz" && tar -xzf "${tmpdir}/bpftool.tgz" -C "${tmpdir}"; then + mapfile -t BPFTOOL_CANDIDATES < <(find "${tmpdir}" -type f -perm -111 2>/dev/null | sort -u) + if candidate="$(first_working_bpftool "${BPFTOOL_CANDIDATES[@]}")"; then + sudo install -m 0755 "${candidate}" /usr/local/bin/bpftool-ci + BPFTOOL_CMD="/usr/local/bin/bpftool-ci" + fi + fi + rm -rf "${tmpdir}" + fi + fi + if [[ -z "${BPFTOOL_CMD}" ]]; then + echo "Unable to locate a working bpftool binary" + exit 1 + fi + echo "BPFTOOL_CMD=${BPFTOOL_CMD}" >> "${GITHUB_ENV}" + "${BPFTOOL_CMD}" version + + - name: Generate eBPF bindings + run: | + mkdir -p lib/provider/ebpf/bpf/headers + "${BPFTOOL_CMD:-bpftool}" btf dump file /sys/kernel/btf/vmlinux format c > lib/provider/ebpf/bpf/vmlinux.h + go generate ./lib/provider/ebpf + + - name: Run golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: v1.64 + args: --timeout=5m + test: name: Test (${{ matrix.runner }}) runs-on: ${{ matrix.runner }} diff --git a/.gitignore b/.gitignore index e9c0504..6f491f8 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,12 @@ go.work.sum # .idea/ # .vscode/ .claude/settings.local.json + +# Compiled binaries +/aurora +/aurora-util + +# Generated eBPF files (produced by go generate) +lib/provider/ebpf/bpf/headers/ +lib/provider/ebpf/*_bpfel.go +lib/provider/ebpf/*_bpfel.o diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..0a1da47 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,112 @@ +# golangci-lint configuration for Aurora Linux +# https://golangci-lint.run/usage/configuration/ + +run: + # Timeout for analysis + timeout: 5m + # Include test files + tests: true + +linters: + # Disable all linters by default and enable explicitly + disable-all: true + enable: + # Default/core linters + - errcheck # Check for unchecked errors + - govet # Go vet examines Go source code + - staticcheck # Static analysis checks + - gosimple # Simplify code + - unused # Find unused code + - ineffassign # Detect ineffectual assignments + + # Additional quality linters + - misspell # Find misspelled words + - gofumpt # Stricter gofmt + - revive # Fast, configurable linter + - gocritic # Highly extensible Go linter + - gocyclo # Cyclomatic complexity checker + +linters-settings: + gocyclo: + # Aurora has some inherently complex functions (eBPF event handling, + # config parsing, validation logic) that are best kept as single functions + min-complexity: 50 + + gofumpt: + # Only check new code, don't enforce reformatting existing code yet + extra-rules: false + + errcheck: + # Exclude common patterns where errors are intentionally ignored + exclude-functions: + # Close errors are almost never actionable + - (io.Closer).Close + - (*os.File).Close + - (*compress/gzip.Reader).Close + - (*archive/zip.ReadCloser).Close + # Cleanup in defer is best-effort + - os.RemoveAll + # SetDeadline errors are usually fine to ignore in tests + - (net.PacketConn).SetReadDeadline + - (net.Conn).SetReadDeadline + # Print functions rarely fail meaningfully + - fmt.Fprintln + - fmt.Fprintf + # Syscall close in cleanup + - syscall.Close + + gocritic: + # Enable only diagnostic checks for now; style/perf can be added later + enabled-tags: + - diagnostic + disabled-checks: + # These are too noisy for initial adoption + - appendAssign + - commentFormatting + + revive: + # Start with a minimal set of rules + rules: + - name: blank-imports + - name: context-as-argument + - name: context-keys-type + - name: error-return + - name: error-strings + - name: error-naming + - name: increment-decrement + - name: var-naming + - name: range + - name: receiver-naming + - name: time-naming + - name: indent-error-flow + - name: errorf + +issues: + # Exclusion rules + exclude-rules: + # Exclude errcheck in test files - tests often ignore errors intentionally + - path: _test\.go + linters: + - errcheck + # Exclude generated eBPF files from most checks + - path: "lib/provider/ebpf/.*_bpfel\\.go$" + linters: + - unused + - govet + # Existing code has different formatting - exclude gofumpt for now + # TODO: Run gofumpt on entire codebase in a separate PR + - path: \.go$ + linters: + - gofumpt + # Exclude gocritic style suggestions that would require larger refactors + - linters: + - gocritic + text: "(paramTypeCombine|unnamedResult|httpNoBody|octalLiteral|unnecessaryDefer|filepathJoin|stringConcatSimplify)" + # Revive warnings for existing code patterns + - linters: + - revive + text: "(stutters|exported:|should have comment|var-naming|increment-decrement)" + + # Don't limit the number of issues per linter + max-issues-per-linter: 0 + max-same-issues: 0 diff --git a/cmd/aurora-util/helpers_test.go b/cmd/aurora-util/helpers_test.go index 185867f..85240f9 100644 --- a/cmd/aurora-util/helpers_test.go +++ b/cmd/aurora-util/helpers_test.go @@ -1,9 +1,12 @@ package main import ( + "archive/tar" "archive/zip" + "compress/gzip" "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "os" @@ -1024,3 +1027,395 @@ func writeTestZip(t *testing.T, archivePath string, files map[string]string) { t.Fatalf("zip.Close() error = %v", err) } } + +// --------------------------------------------------------------------------- +// extractSubdirFromArchive — dispatch tests +// --------------------------------------------------------------------------- + +func TestExtractSubdirFromArchiveUnsupportedFormat(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "unknown.bin") + os.WriteFile(archivePath, []byte("random content"), 0644) + + _, err := extractSubdirFromArchive(archivePath, "rules/linux", filepath.Join(tmpDir, "out")) + if err == nil { + t.Fatal("expected error for unsupported format") + } +} + +func TestExtractSubdirFromArchiveMissingFile(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + _, err := extractSubdirFromArchive(filepath.Join(tmpDir, "missing.tar.gz"), "rules/linux", filepath.Join(tmpDir, "out")) + if err == nil { + t.Fatal("expected error for missing file") + } +} + +// --------------------------------------------------------------------------- +// extractSubdirFromTarGz — comprehensive tests +// --------------------------------------------------------------------------- + +func TestExtractSubdirFromTarGz(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "test.tar.gz") + destDir := filepath.Join(tmpDir, "out") + + writeTestTarGzHelper(t, archivePath, map[string]string{ + "repo/rules/linux/proc/test1.yml": "rule1", + "repo/rules/linux/file/test2.yml": "rule2", + "repo/rules/windows/test3.yml": "rule3", + "repo/README.md": "readme", + }) + + written, err := extractSubdirFromTarGz(archivePath, "rules/linux", destDir) + if err != nil { + t.Fatalf("error = %v", err) + } + if written != 2 { + t.Fatalf("written = %d, want 2", written) + } + + content, err := os.ReadFile(filepath.Join(destDir, "proc", "test1.yml")) + if err != nil || string(content) != "rule1" { + t.Fatalf("rule1 content = %q, err = %v", content, err) + } +} + +func TestExtractSubdirFromTarGzEmptyResult(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "test.tar.gz") + destDir := filepath.Join(tmpDir, "out") + + writeTestTarGzHelper(t, archivePath, map[string]string{ + "repo/rules/windows/test.yml": "rule", + }) + + _, err := extractSubdirFromTarGz(archivePath, "rules/linux", destDir) + if err == nil || !strings.Contains(err.Error(), "no files found") { + t.Fatalf("expected 'no files found' error, got: %v", err) + } +} + +func TestExtractSubdirFromTarGzDirectoryEntries(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "test.tar.gz") + destDir := filepath.Join(tmpDir, "out") + + writeTestTarGzWithDirs(t, archivePath, map[string]string{ + "repo/rules/linux/proc/test1.yml": "rule1", + }, []string{ + "repo/rules/linux/proc/", + }) + + written, err := extractSubdirFromTarGz(archivePath, "rules/linux", destDir) + if err != nil { + t.Fatalf("error = %v", err) + } + if written != 1 { + t.Fatalf("written = %d, want 1", written) + } + + // Directory should exist + info, err := os.Stat(filepath.Join(destDir, "proc")) + if err != nil || !info.IsDir() { + t.Fatal("expected proc directory to exist") + } +} + +// --------------------------------------------------------------------------- +// extractBestBinaryFromArchive — dispatch tests +// --------------------------------------------------------------------------- + +func TestExtractBestBinaryFromArchiveUnsupported(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "unknown.bin") + os.WriteFile(archivePath, []byte("random content"), 0644) + + _, err := extractBestBinaryFromArchive(archivePath, filepath.Join(tmpDir, "out"), []string{"aurora"}) + if err == nil { + t.Fatal("expected error for unsupported format") + } +} + +func TestExtractBestBinaryFromTarGz(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "release.tar.gz") + outputPath := filepath.Join(tmpDir, "aurora") + + writeTestTarGzHelper(t, archivePath, map[string]string{ + "release/aurora": "binary-content", + "release/aurora-linux": "alt-binary-content", + "release/README.md": "readme", + }) + + entryName, err := extractBestBinaryFromTarGz(archivePath, outputPath, []string{"aurora", "aurora-linux"}) + if err != nil { + t.Fatalf("error = %v", err) + } + if !strings.Contains(entryName, "aurora") { + t.Fatalf("entryName = %q, expected aurora", entryName) + } + + content, _ := os.ReadFile(outputPath) + if string(content) != "binary-content" { + t.Fatalf("content = %q, want binary-content", content) + } +} + +func TestExtractBestBinaryFromTarGzNotFound(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "release.tar.gz") + outputPath := filepath.Join(tmpDir, "aurora") + + writeTestTarGzHelper(t, archivePath, map[string]string{ + "release/README.md": "readme", + }) + + _, err := extractBestBinaryFromTarGz(archivePath, outputPath, []string{"aurora"}) + if err == nil { + t.Fatal("expected error when binary not found") + } +} + +func TestExtractBestBinaryFromZip(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "release.zip") + outputPath := filepath.Join(tmpDir, "aurora") + + writeTestZip(t, archivePath, map[string]string{ + "release/aurora": "binary-content", + "release/aurora-linux": "alt-binary-content", + "release/README.md": "readme", + }) + + entryName, err := extractBestBinaryFromZip(archivePath, outputPath, []string{"aurora", "aurora-linux"}) + if err != nil { + t.Fatalf("error = %v", err) + } + if !strings.Contains(entryName, "aurora") { + t.Fatalf("entryName = %q, expected aurora", entryName) + } + + content, _ := os.ReadFile(outputPath) + if string(content) != "binary-content" { + t.Fatalf("content = %q, want binary-content", content) + } +} + +func TestExtractBestBinaryFromZipNotFound(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "release.zip") + outputPath := filepath.Join(tmpDir, "aurora") + + writeTestZip(t, archivePath, map[string]string{ + "release/README.md": "readme", + }) + + _, err := extractBestBinaryFromZip(archivePath, outputPath, []string{"aurora"}) + if err == nil { + t.Fatal("expected error when binary not found") + } +} + +func TestExtractBestBinaryFromZipFallbackToSecondPreferred(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + archivePath := filepath.Join(tmpDir, "release.zip") + outputPath := filepath.Join(tmpDir, "binary") + + writeTestZip(t, archivePath, map[string]string{ + "release/aurora-linux": "second-binary", + }) + + entryName, err := extractBestBinaryFromZip(archivePath, outputPath, []string{"aurora", "aurora-linux"}) + if err != nil { + t.Fatalf("error = %v", err) + } + if !strings.Contains(entryName, "aurora-linux") { + t.Fatalf("entryName = %q, expected aurora-linux", entryName) + } +} + +// --------------------------------------------------------------------------- +// writeReaderToFile +// --------------------------------------------------------------------------- + +func TestWriteReaderToFileSuccess(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + outputPath := filepath.Join(tmpDir, "nested", "dir", "file.txt") + + err := writeReaderToFile(outputPath, strings.NewReader("hello world"), 0o600) + if err != nil { + t.Fatalf("error = %v", err) + } + + content, _ := os.ReadFile(outputPath) + if string(content) != "hello world" { + t.Fatalf("content = %q, want hello world", content) + } + + info, _ := os.Stat(outputPath) + if info.Mode().Perm() != 0o600 { + t.Fatalf("mode = %v, want 0600", info.Mode().Perm()) + } +} + +// --------------------------------------------------------------------------- +// copyFile additional tests +// --------------------------------------------------------------------------- + +func TestCopyFileMissingSource(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + err := copyFile(filepath.Join(tmpDir, "missing.txt"), filepath.Join(tmpDir, "dst.txt"), 0644) + if err == nil { + t.Fatal("expected error for missing source") + } +} + +func TestCopyFileCreatesParentDirs(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + srcPath := filepath.Join(tmpDir, "src.txt") + dstPath := filepath.Join(tmpDir, "deep", "nested", "dst.txt") + + os.WriteFile(srcPath, []byte("content"), 0644) + + if err := copyFile(srcPath, dstPath, 0644); err != nil { + t.Fatalf("error = %v", err) + } + + content, _ := os.ReadFile(dstPath) + if string(content) != "content" { + t.Fatalf("content = %q", content) + } +} + +// --------------------------------------------------------------------------- +// copyDir additional tests +// --------------------------------------------------------------------------- + +func TestCopyDirMissingSource(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + err := copyDir(filepath.Join(tmpDir, "missing"), filepath.Join(tmpDir, "dst")) + if err == nil { + t.Fatal("expected error for missing source") + } +} + +func TestCopyDirSymlinkError(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + srcDir := filepath.Join(tmpDir, "src") + os.MkdirAll(srcDir, 0755) + + realFile := filepath.Join(srcDir, "real.txt") + os.WriteFile(realFile, []byte("content"), 0644) + + symlink := filepath.Join(srcDir, "link.txt") + os.Symlink(realFile, symlink) + + err := copyDir(srcDir, filepath.Join(tmpDir, "dst")) + if err == nil { + t.Fatal("expected error for symlink in source") + } + if !strings.Contains(err.Error(), "symlink") { + t.Fatalf("error should mention symlink: %v", err) + } +} + +// --------------------------------------------------------------------------- +// relFromArchiveSubdir +// --------------------------------------------------------------------------- + +func TestRelFromArchiveSubdirBasic(t *testing.T) { + t.Parallel() + tests := []struct { + entry string + subdir string + wantRel string + wantOK bool + }{ + {"repo/rules/linux/test.yml", "rules/linux", "test.yml", true}, + {"repo/rules/linux/", "rules/linux", "", true}, + {"repo/rules/windows/test.yml", "rules/linux", "", false}, + {"a/b/c/d.txt", "b/c", "d.txt", true}, + {"just.txt", "", "just.txt", true}, + {"", "", "", false}, + } + + for _, tc := range tests { + name := fmt.Sprintf("%s_%s", tc.entry, tc.subdir) + t.Run(name, func(t *testing.T) { + rel, ok := relFromArchiveSubdir(tc.entry, tc.subdir) + if ok != tc.wantOK || rel != tc.wantRel { + t.Fatalf("relFromArchiveSubdir(%q, %q) = (%q, %v), want (%q, %v)", + tc.entry, tc.subdir, rel, ok, tc.wantRel, tc.wantOK) + } + }) + } +} + +// --------------------------------------------------------------------------- +// helpers for tar.gz tests +// --------------------------------------------------------------------------- + +// writeTestTarGzHelper wraps writeTestTarGz (from main_test.go) for use in test helper +func writeTestTarGzHelper(t *testing.T, archivePath string, files map[string]string) { + t.Helper() + if err := writeTestTarGz(archivePath, files); err != nil { + t.Fatalf("writeTestTarGz() error = %v", err) + } +} + +func writeTestTarGzWithDirs(t *testing.T, archivePath string, files map[string]string, dirs []string) { + t.Helper() + f, err := os.Create(archivePath) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + defer f.Close() + + gw := gzip.NewWriter(f) + defer gw.Close() + + tw := tar.NewWriter(gw) + defer tw.Close() + + for _, dir := range dirs { + hdr := &tar.Header{ + Name: dir, + Mode: 0755, + Typeflag: tar.TypeDir, + } + if err := tw.WriteHeader(hdr); err != nil { + t.Fatalf("tar.WriteHeader() error = %v", err) + } + } + + for name, content := range files { + hdr := &tar.Header{ + Name: name, + Mode: 0644, + Size: int64(len(content)), + } + if err := tw.WriteHeader(hdr); err != nil { + t.Fatalf("tar.WriteHeader() error = %v", err) + } + if _, err := tw.Write([]byte(content)); err != nil { + t.Fatalf("tar.Write() error = %v", err) + } + } +} diff --git a/cmd/aurora-util/main.go b/cmd/aurora-util/main.go index 282f288..17d1c03 100644 --- a/cmd/aurora-util/main.go +++ b/cmd/aurora-util/main.go @@ -334,9 +334,7 @@ func normalizePprofBaseURL(raw string) (string, error) { } cleanPath := strings.TrimSuffix(u.Path, "/") - if strings.HasSuffix(cleanPath, "/debug/pprof") { - cleanPath = strings.TrimSuffix(cleanPath, "/debug/pprof") - } + cleanPath = strings.TrimSuffix(cleanPath, "/debug/pprof") if cleanPath == "/debug/pprof" { cleanPath = "" } @@ -610,7 +608,7 @@ func extractSubdirFromTarGz(archivePath, sourceSubdir, destinationDir string) (i if err := os.MkdirAll(targetPath, 0o755); err != nil { return 0, fmt.Errorf("creating directory %q: %w", targetPath, err) } - case tar.TypeReg, tar.TypeRegA: + case tar.TypeReg: if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil { return 0, fmt.Errorf("creating parent directory for %q: %w", targetPath, err) } @@ -725,7 +723,7 @@ func extractBestBinaryFromTarGz(archivePath, outputPath string, preferredNames [ if err != nil { return "", fmt.Errorf("reading tar archive: %w", err) } - if hdr.Typeflag != tar.TypeReg && hdr.Typeflag != tar.TypeRegA { + if hdr.Typeflag != tar.TypeReg { continue } diff --git a/cmd/aurora/agent/agent.go b/cmd/aurora/agent/agent.go index 0b51fa0..7ad1cb1 100644 --- a/cmd/aurora/agent/agent.go +++ b/cmd/aurora/agent/agent.go @@ -143,10 +143,10 @@ func (a *Agent) Run() error { // Create and initialize eBPF listener a.listener = ebpfprovider.NewListener(a.correlator) - // Enable all sources - a.listener.AddSource(ebpfprovider.SourceProcessExec) - a.listener.AddSource(ebpfprovider.SourceFileCreate) - a.listener.AddSource(ebpfprovider.SourceNetConnect) + // Enable all sources (errors are checked during Initialize) + _ = a.listener.AddSource(ebpfprovider.SourceProcessExec) + _ = a.listener.AddSource(ebpfprovider.SourceFileCreate) + _ = a.listener.AddSource(ebpfprovider.SourceNetConnect) if err := a.listener.Initialize(); err != nil { return fmt.Errorf("initializing eBPF listener: %w", err) diff --git a/cmd/aurora/agent/agent_test.go b/cmd/aurora/agent/agent_test.go index 5ee5221..dee5319 100644 --- a/cmd/aurora/agent/agent_test.go +++ b/cmd/aurora/agent/agent_test.go @@ -258,3 +258,276 @@ func (s *stubEvent) Value(fieldname string) enrichment.DataValue { return s.fields.Value(fieldname) } func (s *stubEvent) ForEach(fn func(key string, value string)) { s.fields.ForEach(fn) } + +func TestTraceEventLogsAllFields(t *testing.T) { + // Capture log output + var buf strings.Builder + logger := log.New() + logger.SetOutput(&buf) + logger.SetLevel(log.DebugLevel) + log.SetOutput(&buf) + log.SetLevel(log.DebugLevel) + defer func() { + log.SetOutput(os.Stderr) + log.SetLevel(log.InfoLevel) + }() + + a := New(Parameters{Trace: true}) + + evt := &stubEvent{ + id: provider.EventIdentifier{ + ProviderName: "LinuxEBPF", + EventID: 1, + }, + source: "LinuxEBPF:ProcessExec", + fields: enrichment.DataFieldsMap{ + "Image": enrichment.NewStringValue("/usr/bin/bash"), + "CommandLine": enrichment.NewStringValue("bash -c echo"), + "ProcessId": enrichment.NewStringValue("1234"), + }, + } + + a.traceEvent(evt) + + output := buf.String() + if !strings.Contains(output, "Trace event") { + t.Fatalf("expected 'Trace event' in output, got: %s", output) + } + if !strings.Contains(output, "LinuxEBPF") { + t.Fatalf("expected 'LinuxEBPF' in output, got: %s", output) + } +} + +func TestShouldExcludeEventWithEmptyFilter(t *testing.T) { + a := New(Parameters{ProcessExclude: ""}) + + evt := &stubEvent{ + id: provider.EventIdentifier{ProviderName: "LinuxEBPF", EventID: 1}, + source: "LinuxEBPF:ProcessExec", + fields: enrichment.DataFieldsMap{ + "Image": enrichment.NewStringValue("/usr/bin/bash"), + }, + } + + if a.shouldExcludeEvent(evt) { + t.Fatal("shouldExcludeEvent() with empty filter should return false") + } +} + +func TestShouldExcludeEventWithWhitespaceFilter(t *testing.T) { + a := New(Parameters{ProcessExclude: " "}) + + evt := &stubEvent{ + id: provider.EventIdentifier{ProviderName: "LinuxEBPF", EventID: 1}, + source: "LinuxEBPF:ProcessExec", + fields: enrichment.DataFieldsMap{ + "Image": enrichment.NewStringValue("/usr/bin/bash"), + }, + } + + if a.shouldExcludeEvent(evt) { + t.Fatal("shouldExcludeEvent() with whitespace filter should return false") + } +} + +func TestShouldExcludeEventMatchesParentFields(t *testing.T) { + a := New(Parameters{ProcessExclude: "systemd"}) + + evt := &stubEvent{ + id: provider.EventIdentifier{ProviderName: "LinuxEBPF", EventID: 1}, + source: "LinuxEBPF:ProcessExec", + fields: enrichment.DataFieldsMap{ + "Image": enrichment.NewStringValue("/usr/bin/bash"), + "CommandLine": enrichment.NewStringValue("bash"), + "ParentImage": enrichment.NewStringValue("/usr/lib/systemd/systemd"), + "ParentCommandLine": enrichment.NewStringValue("/usr/lib/systemd/systemd"), + }, + } + + if !a.shouldExcludeEvent(evt) { + t.Fatal("shouldExcludeEvent() should match ParentImage") + } +} + +func TestShouldExcludeEventCaseInsensitive(t *testing.T) { + a := New(Parameters{ProcessExclude: "BASH"}) + + evt := &stubEvent{ + id: provider.EventIdentifier{ProviderName: "LinuxEBPF", EventID: 1}, + source: "LinuxEBPF:ProcessExec", + fields: enrichment.DataFieldsMap{ + "Image": enrichment.NewStringValue("/usr/bin/bash"), + }, + } + + if !a.shouldExcludeEvent(evt) { + t.Fatal("shouldExcludeEvent() should be case insensitive") + } +} + +func TestStartPprofEndpointWithEmptyAddress(t *testing.T) { + a := New(Parameters{PprofListen: ""}) + + err := a.startPprofEndpoint() + if err != nil { + t.Fatalf("startPprofEndpoint() with empty address should return nil, got: %v", err) + } + if a.pprofSrv != nil { + t.Fatal("pprofSrv should be nil when no address is configured") + } +} + +func TestStartPprofEndpointWithWhitespaceAddress(t *testing.T) { + a := New(Parameters{PprofListen: " "}) + + err := a.startPprofEndpoint() + if err != nil { + t.Fatalf("startPprofEndpoint() with whitespace address should return nil, got: %v", err) + } + if a.pprofSrv != nil { + t.Fatal("pprofSrv should be nil when no address is configured") + } +} + +func TestStartPprofEndpointWithValidAddress(t *testing.T) { + a := New(Parameters{PprofListen: "127.0.0.1:0"}) // port 0 = auto-assign + + err := a.startPprofEndpoint() + if err != nil { + t.Fatalf("startPprofEndpoint() error = %v", err) + } + if a.pprofSrv == nil { + t.Fatal("pprofSrv should not be nil") + } + if a.pprofAddr == "" { + t.Fatal("pprofAddr should not be empty") + } + + // Clean up + a.stopPprofEndpoint() + if a.pprofSrv != nil { + t.Fatal("pprofSrv should be nil after stop") + } + if a.pprofAddr != "" { + t.Fatal("pprofAddr should be empty after stop") + } +} + +func TestStartPprofEndpointWithInvalidAddress(t *testing.T) { + a := New(Parameters{PprofListen: "invalid:address:format:99999999"}) + + err := a.startPprofEndpoint() + if err == nil { + a.stopPprofEndpoint() + t.Fatal("startPprofEndpoint() expected error for invalid address") + } +} + +func TestStopPprofEndpointIsIdempotent(t *testing.T) { + a := New(Parameters{}) + a.pprofSrv = nil + + // Should not panic when called on nil server + a.stopPprofEndpoint() + + // Call again + a.stopPprofEndpoint() +} + +func TestShutdownWithNilComponents(t *testing.T) { + a := &Agent{} + + // Should not panic with all nil components + a.shutdown() +} + +func TestCloseOutputsWithCloserErrors(t *testing.T) { + closeCount := 0 + a := &Agent{ + closers: []func() error{ + func() error { closeCount++; return nil }, + func() error { closeCount++; return nil }, + }, + } + + a.closeOutputs() + + if closeCount != 2 { + t.Fatalf("closeCount = %d, want 2", closeCount) + } + if a.closers != nil { + t.Fatal("closers should be nil after closeOutputs") + } +} + +func TestPrintWelcomeBannerInJSONMode(t *testing.T) { + // In JSON mode, banner should be skipped + a := New(Parameters{JSONOutput: true}) + + // This should not panic and should do nothing + a.printWelcomeBanner() +} + +func TestPrintWelcomeBannerInTextMode(t *testing.T) { + // Capture output + origOut := log.StandardLogger().Out + defer log.StandardLogger().SetOutput(origOut) + + var buf strings.Builder + log.StandardLogger().SetOutput(&buf) + + a := New(Parameters{JSONOutput: false, Version: "1.2.3"}) + a.printWelcomeBanner() + + output := buf.String() + // The ASCII art contains "AURORA" split across multiple lines with backslashes + // Check for parts of it or the descriptive text + if !strings.Contains(output, "Real-Time Sigma Matching") { + t.Fatalf("expected 'Real-Time Sigma Matching' in banner, got: %s", output) + } + if !strings.Contains(output, "v1.2.3") { + t.Fatalf("expected v1.2.3 in banner, got: %s", output) + } +} + +func TestPrintWelcomeBannerVersionNormalization(t *testing.T) { + origOut := log.StandardLogger().Out + defer log.StandardLogger().SetOutput(origOut) + + var buf strings.Builder + log.StandardLogger().SetOutput(&buf) + + // Test without "v" prefix + a := New(Parameters{JSONOutput: false, Version: "2.0.0"}) + a.printWelcomeBanner() + + if !strings.Contains(buf.String(), "v2.0.0") { + t.Fatalf("expected v2.0.0 in banner (auto-prefixed), got: %s", buf.String()) + } + + buf.Reset() + + // Test with "v" prefix already + a2 := New(Parameters{JSONOutput: false, Version: "v3.0.0"}) + a2.printWelcomeBanner() + + if !strings.Contains(buf.String(), "v3.0.0") { + t.Fatalf("expected v3.0.0 in banner, got: %s", buf.String()) + } +} + +func TestPrintWelcomeBannerDefaultVersion(t *testing.T) { + origOut := log.StandardLogger().Out + defer log.StandardLogger().SetOutput(origOut) + + var buf strings.Builder + log.StandardLogger().SetOutput(&buf) + + a := New(Parameters{JSONOutput: false, Version: ""}) + a.printWelcomeBanner() + + // Default version should be 0.1.4 + if !strings.Contains(buf.String(), "v0.1.4") { + t.Fatalf("expected default version in banner, got: %s", buf.String()) + } +} diff --git a/cmd/aurora/agent/validate_test.go b/cmd/aurora/agent/validate_test.go index 18d91dc..b0e7271 100644 --- a/cmd/aurora/agent/validate_test.go +++ b/cmd/aurora/agent/validate_test.go @@ -232,3 +232,200 @@ func TestValidateParametersAcceptsLoopbackPprofListen(t *testing.T) { t.Fatalf("ValidateParameters() unexpected error: %v", err) } } + +func TestValidateHostPortWithPort0(t *testing.T) { + err := validateHostPort("--tcp-target", "127.0.0.1:0") + if err == nil { + t.Fatal("validateHostPort() expected error for port 0") + } + if !strings.Contains(err.Error(), "1-65535") { + t.Fatalf("expected port range error, got %v", err) + } +} + +func TestValidateHostPortWithPort65536(t *testing.T) { + err := validateHostPort("--tcp-target", "127.0.0.1:65536") + if err == nil { + t.Fatal("validateHostPort() expected error for port 65536") + } + if !strings.Contains(err.Error(), "1-65535") { + t.Fatalf("expected port range error, got %v", err) + } +} + +func TestValidateHostPortWithEmptyHost(t *testing.T) { + err := validateHostPort("--tcp-target", ":8080") + if err == nil { + t.Fatal("validateHostPort() expected error for empty host") + } + if !strings.Contains(err.Error(), "must include a host") { + t.Fatalf("expected host error, got %v", err) + } +} + +func TestValidateHostPortWithNonNumericPort(t *testing.T) { + err := validateHostPort("--tcp-target", "localhost:abc") + if err == nil { + t.Fatal("validateHostPort() expected error for non-numeric port") + } + if !strings.Contains(err.Error(), "numeric port") { + t.Fatalf("expected numeric port error, got %v", err) + } +} + +func TestIsLoopbackHostWithBracketedIPv6(t *testing.T) { + // Bracketed IPv6 notation like [::1] shouldn't parse as IP directly + if isLoopbackHost("[::1]") { + t.Fatal("isLoopbackHost([::1]) should return false (brackets not stripped)") + } + + // Without brackets should work + if !isLoopbackHost("::1") { + t.Fatal("isLoopbackHost(::1) should return true") + } +} + +func TestIsLoopbackHostWithLocalhost(t *testing.T) { + if !isLoopbackHost("localhost") { + t.Fatal("isLoopbackHost(localhost) should return true") + } + if !isLoopbackHost("LOCALHOST") { + t.Fatal("isLoopbackHost(LOCALHOST) should return true (case insensitive)") + } +} + +func TestIsLoopbackHostWithIPv4Loopback(t *testing.T) { + if !isLoopbackHost("127.0.0.1") { + t.Fatal("isLoopbackHost(127.0.0.1) should return true") + } + if !isLoopbackHost("127.0.0.255") { + t.Fatal("isLoopbackHost(127.0.0.255) should return true") + } +} + +func TestIsLoopbackHostWithNonLoopback(t *testing.T) { + if isLoopbackHost("192.168.1.1") { + t.Fatal("isLoopbackHost(192.168.1.1) should return false") + } + if isLoopbackHost("example.com") { + t.Fatal("isLoopbackHost(example.com) should return false") + } + if isLoopbackHost("0.0.0.0") { + t.Fatal("isLoopbackHost(0.0.0.0) should return false") + } +} + +func TestValidateLoopbackHostPortWithBracketedIPv6(t *testing.T) { + // [::1] is valid IPv6 notation for net.SplitHostPort + // After splitting, host = "::1" which is loopback + // However, isLoopbackHost receives "::1" (without brackets) + err := validateLoopbackHostPort("--pprof-listen", "[::1]:6060") + if err != nil { + t.Fatalf("validateLoopbackHostPort([::1]:6060) unexpected error: %v", err) + } +} + +func TestValidateLoopbackHostPortWithUnbracketedIPv6(t *testing.T) { + // Unbracketed IPv6 with port is ambiguous and should fail parsing + err := validateLoopbackHostPort("--pprof-listen", "::1:6060") + // This will fail because ::1:6060 is ambiguous (colons in IPv6) + if err == nil { + t.Fatal("validateLoopbackHostPort() expected error for ambiguous IPv6") + } +} + +func TestIsPowerOfTwo(t *testing.T) { + tests := []struct { + value int + want bool + }{ + {0, false}, + {1, true}, + {2, true}, + {3, false}, + {4, true}, + {5, false}, + {1024, true}, + {2048, true}, + {3000, false}, + {-1, false}, + {-2, false}, + } + + for _, tc := range tests { + if got := isPowerOfTwo(tc.value); got != tc.want { + t.Errorf("isPowerOfTwo(%d) = %v, want %v", tc.value, got, tc.want) + } + } +} + +func TestValidateParametersRejectsEmptyRulesPath(t *testing.T) { + params := DefaultParameters() + params.RuleDirs = []string{" "} + + err := ValidateParameters(params) + if err == nil { + t.Fatal("ValidateParameters() expected error for empty rules path") + } +} + +func TestValidateParametersRejectsZeroCorrelationCache(t *testing.T) { + tmpDir := t.TempDir() + params := DefaultParameters() + params.RuleDirs = []string{tmpDir} + params.CorrelationCacheSize = 0 + + err := ValidateParameters(params) + if err == nil { + t.Fatal("ValidateParameters() expected error for zero correlation cache") + } + if !strings.Contains(err.Error(), "--correlation-cache") { + t.Fatalf("expected --correlation-cache context, got %v", err) + } +} + +func TestValidateParametersRejectsNegativeThrottleRate(t *testing.T) { + tmpDir := t.TempDir() + params := DefaultParameters() + params.RuleDirs = []string{tmpDir} + params.ThrottleRate = -1 + + err := ValidateParameters(params) + if err == nil { + t.Fatal("ValidateParameters() expected error for negative throttle rate") + } + if !strings.Contains(err.Error(), "--throttle-rate") { + t.Fatalf("expected --throttle-rate context, got %v", err) + } +} + +func TestValidateParametersRejectsThrottleRateWithoutBurst(t *testing.T) { + tmpDir := t.TempDir() + params := DefaultParameters() + params.RuleDirs = []string{tmpDir} + params.ThrottleRate = 1.0 + params.ThrottleBurst = 0 + + err := ValidateParameters(params) + if err == nil { + t.Fatal("ValidateParameters() expected error for throttle rate without burst") + } + if !strings.Contains(err.Error(), "--throttle-burst") { + t.Fatalf("expected --throttle-burst context, got %v", err) + } +} + +func TestValidateParametersRejectsNegativeStatsInterval(t *testing.T) { + tmpDir := t.TempDir() + params := DefaultParameters() + params.RuleDirs = []string{tmpDir} + params.StatsInterval = -1 + + err := ValidateParameters(params) + if err == nil { + t.Fatal("ValidateParameters() expected error for negative stats interval") + } + if !strings.Contains(err.Error(), "--stats-interval") { + t.Fatalf("expected --stats-interval context, got %v", err) + } +} diff --git a/lib/consumer/sigma/loadrules.go b/lib/consumer/sigma/loadrules.go deleted file mode 100644 index b067779..0000000 --- a/lib/consumer/sigma/loadrules.go +++ /dev/null @@ -1,12 +0,0 @@ -package sigma - -import ( - "path/filepath" - "strings" -) - -// isYAMLFile returns true if the file has a .yml or .yaml extension. -func isYAMLFile(path string) bool { - ext := strings.ToLower(filepath.Ext(path)) - return ext == ".yml" || ext == ".yaml" -} diff --git a/lib/consumer/sigma/matchdetails_test.go b/lib/consumer/sigma/matchdetails_test.go index 7574a83..ef1ed07 100644 --- a/lib/consumer/sigma/matchdetails_test.go +++ b/lib/consumer/sigma/matchdetails_test.go @@ -481,3 +481,207 @@ func TestStringifyRuleMetadataValue(t *testing.T) { }) } } + +func TestDescribeStringMatcherPatterns(t *testing.T) { + // Test ContentPattern (TextPatternNone = exact match) + contentMatcher, _ := sigmaengine.NewStringMatcher(sigmaengine.TextPatternNone, false, false, false, "/usr/bin/bash") + got := describeStringMatcherPatterns(contentMatcher, "/usr/bin/bash") + if len(got) != 1 || got[0] != "/usr/bin/bash" { + t.Fatalf("ContentPattern match: got %v, want [/usr/bin/bash]", got) + } + + // Test ContentPattern non-match + got = describeStringMatcherPatterns(contentMatcher, "/usr/bin/zsh") + if len(got) != 0 { + t.Fatalf("ContentPattern non-match: got %v, want []", got) + } + + // Test PrefixPattern + prefixMatcher, _ := sigmaengine.NewStringMatcher(sigmaengine.TextPatternPrefix, false, false, false, "/usr/") + got = describeStringMatcherPatterns(prefixMatcher, "/usr/bin/bash") + if len(got) != 1 || got[0] != "/usr/*" { + t.Fatalf("PrefixPattern: got %v, want [/usr/*]", got) + } + + // Test PrefixPattern non-match + got = describeStringMatcherPatterns(prefixMatcher, "/bin/bash") + if len(got) != 0 { + t.Fatalf("PrefixPattern non-match: got %v, want []", got) + } + + // Test SuffixPattern + suffixMatcher, _ := sigmaengine.NewStringMatcher(sigmaengine.TextPatternSuffix, false, false, false, "/bash") + got = describeStringMatcherPatterns(suffixMatcher, "/usr/bin/bash") + if len(got) != 1 || got[0] != "*/bash" { + t.Fatalf("SuffixPattern: got %v, want [*/bash]", got) + } + + // Test SuffixPattern non-match + got = describeStringMatcherPatterns(suffixMatcher, "/usr/bin/zsh") + if len(got) != 0 { + t.Fatalf("SuffixPattern non-match: got %v, want []", got) + } + + // Test RegexPattern + regexMatcher, _ := sigmaengine.NewStringMatcher(sigmaengine.TextPatternRegex, false, false, false, ".*bash$") + got = describeStringMatcherPatterns(regexMatcher, "/usr/bin/bash") + if len(got) != 1 || got[0] != "/.*bash$/" { + t.Fatalf("RegexPattern: got %v, want [/.*bash$/]", got) + } + + // Test RegexPattern non-match + got = describeStringMatcherPatterns(regexMatcher, "/usr/bin/zsh") + if len(got) != 0 { + t.Fatalf("RegexPattern non-match: got %v, want []", got) + } + + // Test StringMatchers (multiple patterns OR'd together) with exact match patterns + matcher1, _ := sigmaengine.NewStringMatcher(sigmaengine.TextPatternNone, false, false, false, "test1") + matcher2, _ := sigmaengine.NewStringMatcher(sigmaengine.TextPatternNone, false, false, false, "test2") + combined := sigmaengine.StringMatchers{matcher1, matcher2} + // Only one should match for exact patterns + got = describeStringMatcherPatterns(combined, "test1") + if len(got) != 1 || got[0] != "test1" { + t.Fatalf("StringMatchers: got %v, want [test1]", got) + } + + // Test StringMatchersConj (multiple patterns AND'd together) with prefix patterns + prefix1, _ := sigmaengine.NewStringMatcher(sigmaengine.TextPatternPrefix, false, false, false, "/usr/") + prefix2, _ := sigmaengine.NewStringMatcher(sigmaengine.TextPatternSuffix, false, false, false, "/bash") + conjMatcher := sigmaengine.StringMatchersConj{prefix1, prefix2} + got = describeStringMatcherPatterns(conjMatcher, "/usr/bin/bash") + if len(got) != 2 { + t.Fatalf("StringMatchersConj: got %v, want 2 patterns", got) + } + + // Test GlobPattern via NewStringMatcher (contains modifier creates glob) + globMatcher, _ := sigmaengine.NewStringMatcher(sigmaengine.TextPatternContains, false, false, false, "evil") + got = describeStringMatcherPatterns(globMatcher, "path/to/evil/binary") + // GlobPattern returns "" as the description + if len(got) != 1 || got[0] != "" { + t.Fatalf("GlobPattern: got %v, want []", got) + } + + // Test GlobPattern non-match + got = describeStringMatcherPatterns(globMatcher, "path/to/good/binary") + if len(got) != 0 { + t.Fatalf("GlobPattern non-match: got %v, want []", got) + } +} + +func TestDescribeNumMatcherPatterns(t *testing.T) { + // Test NumPattern match + numMatcher := sigmaengine.NumPattern{Val: 42} + got := describeNumMatcherPatterns(numMatcher, 42) + if len(got) != 1 || got[0] != "42" { + t.Fatalf("NumPattern match: got %v, want [42]", got) + } + + // Test NumPattern non-match + got = describeNumMatcherPatterns(numMatcher, 99) + if len(got) != 0 { + t.Fatalf("NumPattern non-match: got %v, want []", got) + } + + // Test NumMatchers (multiple patterns OR'd) + numMatcher1 := sigmaengine.NumPattern{Val: 10} + numMatcher2 := sigmaengine.NumPattern{Val: 20} + combined := sigmaengine.NumMatchers{numMatcher1, numMatcher2} + got = describeNumMatcherPatterns(combined, 10) + if len(got) != 1 || got[0] != "10" { + t.Fatalf("NumMatchers match: got %v, want [10]", got) + } + + // Test NumMatchers with both matching + got = describeNumMatcherPatterns(combined, 20) + if len(got) != 1 || got[0] != "20" { + t.Fatalf("NumMatchers match second: got %v, want [20]", got) + } +} + +func TestBuildRuleMetadataNilTree(t *testing.T) { + meta := buildRuleMetadata(nil) + if meta.ID != "" || meta.Title != "" { + t.Fatalf("expected empty metadata for nil tree, got ID=%q Title=%q", meta.ID, meta.Title) + } + + meta = buildRuleMetadata(&sigmaengine.Tree{Rule: nil}) + if meta.ID != "" || meta.Title != "" { + t.Fatalf("expected empty metadata for nil rule, got ID=%q Title=%q", meta.ID, meta.Title) + } +} + +func TestRuleFieldPatternMatchesWithNilMatcher(t *testing.T) { + p := ruleFieldPattern{ + Modifiers: []string{"contains"}, + Pattern: "test", + Matcher: nil, + } + + if p.matches("any value") { + t.Fatal("matches() should return false for nil Matcher") + } +} + +func TestCollectFieldPatternEntryWithNumericTypes(t *testing.T) { + dst := make(map[string][]ruleFieldPattern) + + // Test int + collectFieldPatternEntry(dst, "EventID", 4688, false) + if len(dst["eventid"]) != 1 || dst["eventid"][0].Pattern != "4688" { + t.Fatalf("int pattern: got %v", dst["eventid"]) + } + + // Test float32 + dst = make(map[string][]ruleFieldPattern) + collectFieldPatternEntry(dst, "Score", float32(3.14), false) + if len(dst["score"]) != 1 { + t.Fatalf("float32 pattern: got %v", dst["score"]) + } + + // Test bool + dst = make(map[string][]ruleFieldPattern) + collectFieldPatternEntry(dst, "Enabled", true, false) + if len(dst["enabled"]) != 1 || dst["enabled"][0].Pattern != "true" { + t.Fatalf("bool pattern: got %v", dst["enabled"]) + } +} + +func TestCollectFieldPatternsFromSelectionValueMapInterfaceInterface(t *testing.T) { + dst := make(map[string][]ruleFieldPattern) + + // Use map[interface{}]interface{} as YAML parsing sometimes produces this + value := map[interface{}]interface{}{ + "Image|endswith": "/bash", + } + + collectFieldPatternsFromSelectionValue(dst, value, false) + + if len(dst["image"]) != 1 || dst["image"][0].Pattern != "/bash" { + t.Fatalf("map[interface{}]interface{}: got %v", dst["image"]) + } +} + +func TestExtractDetectionFieldPatternsWithNestedMaps(t *testing.T) { + detection := sigmaengine.Detection{ + "selection": []interface{}{ + map[interface{}]interface{}{ + "Image|endswith": "/curl", + }, + map[interface{}]interface{}{ + "Image|endswith": "/wget", + }, + }, + "condition": "selection", + } + + result := extractDetectionFieldPatterns(detection, false) + if result == nil { + t.Fatal("expected non-nil result") + } + + patterns := result["image"] + if len(patterns) != 2 { + t.Fatalf("expected 2 patterns, got %d", len(patterns)) + } +} diff --git a/lib/consumer/sigma/sigmaconsumer_test.go b/lib/consumer/sigma/sigmaconsumer_test.go index 413bae4..39dd202 100644 --- a/lib/consumer/sigma/sigmaconsumer_test.go +++ b/lib/consumer/sigma/sigmaconsumer_test.go @@ -390,6 +390,253 @@ func TestSigmaRuleLevelToLogLevel(t *testing.T) { } } +func TestIsValidMinLevel(t *testing.T) { + tests := []struct { + level string + want bool + }{ + {level: "info", want: true}, + {level: "INFO", want: true}, + {level: " info ", want: true}, + {level: "informational", want: true}, + {level: "low", want: true}, + {level: "medium", want: true}, + {level: "high", want: true}, + {level: "critical", want: true}, + {level: "CRITICAL", want: true}, + {level: "invalid", want: false}, + {level: "", want: false}, + {level: "debug", want: false}, + {level: "warning", want: false}, + } + + for _, tc := range tests { + t.Run(tc.level, func(t *testing.T) { + got := IsValidMinLevel(tc.level) + if got != tc.want { + t.Fatalf("IsValidMinLevel(%q) = %v, want %v", tc.level, got, tc.want) + } + }) + } +} + +func TestSigmaConsumerName(t *testing.T) { + consumer := New(Config{}) + got := consumer.Name() + if got != "SigmaConsumer" { + t.Fatalf("Name() = %q, want SigmaConsumer", got) + } +} + +func TestSigmaConsumerInitialize(t *testing.T) { + consumer := New(Config{}) + err := consumer.Initialize() + if err != nil { + t.Fatalf("Initialize() error = %v", err) + } +} + +func TestSigmaConsumerClose(t *testing.T) { + consumer := New(Config{}) + err := consumer.Close() + if err != nil { + t.Fatalf("Close() error = %v", err) + } +} + +func TestSigmaConsumerMatches(t *testing.T) { + consumer := New(Config{}) + + // Initially should be 0 + if got := consumer.Matches(); got != 0 { + t.Fatalf("Matches() = %d, want 0", got) + } + + // After adding via atomic + consumer.matches.Add(5) + if got := consumer.Matches(); got != 5 { + t.Fatalf("Matches() = %d, want 5", got) + } + + consumer.matches.Add(3) + if got := consumer.Matches(); got != 8 { + t.Fatalf("Matches() = %d, want 8", got) + } +} + +func TestSigmaEventWrapperKeywords(t *testing.T) { + event := &testEvent{ + fields: enrichment.DataFieldsMap{ + "Image": enrichment.NewStringValue("/bin/bash"), + "CommandLine": enrichment.NewStringValue("bash -c echo"), + }, + } + wrapper := &sigmaEventWrapper{event: event} + + keywords, ok := wrapper.Keywords() + if !ok { + t.Fatal("Keywords() returned false") + } + if len(keywords) != 2 { + t.Fatalf("Keywords() len = %d, want 2", len(keywords)) + } + + // Check that both field values are included + keywordSet := make(map[string]bool) + for _, kw := range keywords { + keywordSet[kw] = true + } + if !keywordSet["/bin/bash"] || !keywordSet["bash -c echo"] { + t.Fatalf("Keywords() = %v, expected [/bin/bash, bash -c echo]", keywords) + } +} + +func TestSigmaEventWrapperKeywordsEmpty(t *testing.T) { + event := &testEvent{fields: enrichment.DataFieldsMap{}} + wrapper := &sigmaEventWrapper{event: event} + + keywords, ok := wrapper.Keywords() + if ok { + t.Fatal("Keywords() should return false for empty fields") + } + if keywords != nil { + t.Fatalf("Keywords() = %v, want nil", keywords) + } +} + +func TestSigmaEventWrapperForReplayKeywords(t *testing.T) { + wrapper := &sigmaEventWrapperForReplay{ + fields: map[string]string{ + "Image": "/bin/bash", + "CommandLine": "bash -c echo", + }, + } + + keywords, ok := wrapper.Keywords() + if !ok { + t.Fatal("Keywords() returned false") + } + if len(keywords) != 2 { + t.Fatalf("Keywords() len = %d, want 2", len(keywords)) + } +} + +func TestSigmaEventWrapperForReplayKeywordsEmpty(t *testing.T) { + wrapper := &sigmaEventWrapperForReplay{fields: map[string]string{}} + + keywords, ok := wrapper.Keywords() + if ok { + t.Fatal("Keywords() should return false for empty fields") + } + if keywords != nil { + t.Fatalf("Keywords() = %v, want nil", keywords) + } +} + +func TestSigmaEventWrapperForReplaySelect(t *testing.T) { + wrapper := &sigmaEventWrapperForReplay{ + fields: map[string]string{ + "Image": "/bin/bash", + }, + } + + // Existing key + val, ok := wrapper.Select("Image") + if !ok || val != "/bin/bash" { + t.Fatalf("Select(Image) = (%v, %v), want (/bin/bash, true)", val, ok) + } + + // Missing key + val, ok = wrapper.Select("NonExistent") + if ok { + t.Fatalf("Select(NonExistent) = (%v, %v), want (_, false)", val, ok) + } +} + +func TestFormatMatchMessage(t *testing.T) { + event := &testEvent{ + fields: enrichment.DataFieldsMap{ + "Image": enrichment.NewStringValue("/usr/bin/whoami"), + "CommandLine": enrichment.NewStringValue("whoami /all"), + "ProcessId": enrichment.NewStringValue("1234"), + }, + } + result := sigmaengine.Result{ + ID: "test-rule-id", + Title: "Test Rule", + } + + got := FormatMatchMessage(event, result, "medium") + if !strings.Contains(got, "[medium]") { + t.Fatalf("FormatMatchMessage() missing level, got %q", got) + } + if !strings.Contains(got, "test-rule-id") { + t.Fatalf("FormatMatchMessage() missing rule ID, got %q", got) + } + if !strings.Contains(got, "PID=1234") { + t.Fatalf("FormatMatchMessage() missing PID, got %q", got) + } + if !strings.Contains(got, "Image=/usr/bin/whoami") { + t.Fatalf("FormatMatchMessage() missing Image, got %q", got) + } + if !strings.Contains(got, "CommandLine=whoami /all") { + t.Fatalf("FormatMatchMessage() missing CommandLine, got %q", got) + } +} + +func TestFormatMatchMessageEmptyFields(t *testing.T) { + event := &testEvent{fields: enrichment.DataFieldsMap{}} + result := sigmaengine.Result{ID: "rule-123"} + + got := FormatMatchMessage(event, result, "high") + if !strings.Contains(got, "[high]") { + t.Fatalf("FormatMatchMessage() missing level, got %q", got) + } + if !strings.Contains(got, "rule-123") { + t.Fatalf("FormatMatchMessage() missing rule ID, got %q", got) + } +} + +func TestHandleEventNoRuleset(t *testing.T) { + consumer := New(Config{}) + // Don't call InitializeWithRules, so ruleset is nil + + event := &testEvent{ + fields: enrichment.DataFieldsMap{ + "Image": enrichment.NewStringValue("/bin/bash"), + }, + } + + err := consumer.HandleEvent(event) + if err != nil { + t.Fatalf("HandleEvent() error = %v", err) + } + // Should return early without error when ruleset is nil +} + +func TestEvalFieldsMapNoRuleset(t *testing.T) { + consumer := New(Config{}) + // Don't call InitializeWithRules + + results := consumer.EvalFieldsMap(map[string]string{ + "Image": "/bin/bash", + }) + + if results != nil { + t.Fatalf("EvalFieldsMap() = %v, want nil when ruleset is nil", results) + } +} + +func TestLookupRuleLevelNilMap(t *testing.T) { + consumer := New(Config{}) + consumer.ruleLevels = nil + + got := consumer.lookupRuleLevel("any-id") + if got != "" { + t.Fatalf("lookupRuleLevel() = %q, want empty string for nil map", got) + } +} + func BenchmarkLookupRuleLevel(b *testing.B) { consumer := New(Config{}) for i := 0; i < 2000; i++ { diff --git a/lib/enrichment/correlator_test.go b/lib/enrichment/correlator_test.go index 457fa4b..f3d3089 100644 --- a/lib/enrichment/correlator_test.go +++ b/lib/enrichment/correlator_test.go @@ -63,3 +63,56 @@ func TestCorrelatorEviction(t *testing.T) { t.Error("PID 3 should be present") } } + +func TestCorrelatorWithZeroSize(t *testing.T) { + _, err := NewCorrelator(0) + if err == nil { + t.Fatal("NewCorrelator(0) expected error") + } +} + +func TestCorrelatorWithNegativeSize(t *testing.T) { + _, err := NewCorrelator(-1) + if err == nil { + t.Fatal("NewCorrelator(-1) expected error") + } +} + +func TestCorrelatorLen(t *testing.T) { + c, err := NewCorrelator(100) + if err != nil { + t.Fatal(err) + } + + if c.Len() != 0 { + t.Errorf("Len() = %d, want 0 for empty cache", c.Len()) + } + + c.Store(1, &ProcessInfo{PID: 1, Image: "a"}) + c.Store(2, &ProcessInfo{PID: 2, Image: "b"}) + + if c.Len() != 2 { + t.Errorf("Len() = %d, want 2", c.Len()) + } +} + +func TestCorrelatorUpdate(t *testing.T) { + c, err := NewCorrelator(100) + if err != nil { + t.Fatal(err) + } + + // Store initial value + c.Store(1234, &ProcessInfo{PID: 1234, Image: "/usr/bin/bash"}) + + // Update with new value + c.Store(1234, &ProcessInfo{PID: 1234, Image: "/usr/bin/zsh"}) + + got := c.Lookup(1234) + if got == nil { + t.Fatal("Lookup returned nil") + } + if got.Image != "/usr/bin/zsh" { + t.Errorf("Image = %q, want /usr/bin/zsh (updated)", got.Image) + } +} diff --git a/lib/enrichment/enricher_test.go b/lib/enrichment/enricher_test.go index d17b24d..3b8e93b 100644 --- a/lib/enrichment/enricher_test.go +++ b/lib/enrichment/enricher_test.go @@ -102,3 +102,94 @@ func TestEventEnricherAllowsRegisterDuringEnrich(t *testing.T) { t.Fatal("Enrich() appears deadlocked while registering a manipulator") } } + +func TestEventEnricherMultipleManipulatorsForSameKey(t *testing.T) { + enricher := NewEventEnricher() + + callOrder := make([]int, 0, 3) + + enricher.Register("TestProvider:1", func(fields DataFieldsMap) { + callOrder = append(callOrder, 1) + fields.AddField("First", "yes") + }) + + enricher.Register("TestProvider:1", func(fields DataFieldsMap) { + callOrder = append(callOrder, 2) + fields.AddField("Second", "yes") + }) + + enricher.Register("TestProvider:1", func(fields DataFieldsMap) { + callOrder = append(callOrder, 3) + // Can modify fields set by previous manipulators + if fields.Value("First").Valid { + fields.AddField("Third", "saw first") + } + }) + + fields := make(DataFieldsMap) + enricher.Enrich("TestProvider:1", fields) + + // All three manipulators should have been called + if len(callOrder) != 3 { + t.Fatalf("callOrder = %v, want 3 calls", callOrder) + } + + // Order should be preserved (FIFO) + if callOrder[0] != 1 || callOrder[1] != 2 || callOrder[2] != 3 { + t.Fatalf("callOrder = %v, want [1 2 3]", callOrder) + } + + // All fields should be set + if !fields.Value("First").Valid { + t.Error("First not set") + } + if !fields.Value("Second").Valid { + t.Error("Second not set") + } + if !fields.Value("Third").Valid || fields.Value("Third").String != "saw first" { + t.Errorf("Third = %v", fields.Value("Third")) + } +} + +func TestEventEnricherNoManipulatorsRegistered(t *testing.T) { + enricher := NewEventEnricher() + + fields := make(DataFieldsMap) + fields.AddField("Original", "value") + + // Should not panic when no manipulators registered + enricher.Enrich("UnregisteredKey", fields) + + // Original field should be unchanged + if !fields.Value("Original").Valid || fields.Value("Original").String != "value" { + t.Errorf("Original = %v", fields.Value("Original")) + } +} + +func TestDataFieldsMapWithNilValue(t *testing.T) { + m := make(DataFieldsMap) + m["NilEntry"] = nil + + v := m.Value("NilEntry") + if v.Valid { + t.Error("Value for nil entry should not be valid") + } +} + +func TestDataFieldsMapForEachSkipsNilValues(t *testing.T) { + m := make(DataFieldsMap) + m["Good"] = NewStringValue("value") + m["Nil"] = nil + + count := 0 + m.ForEach(func(key, value string) { + count++ + if key == "Nil" { + t.Error("ForEach should skip nil values") + } + }) + + if count != 1 { + t.Errorf("ForEach count = %d, want 1", count) + } +} diff --git a/lib/provider/ebpf/event_test.go b/lib/provider/ebpf/event_test.go new file mode 100644 index 0000000..294ed2e --- /dev/null +++ b/lib/provider/ebpf/event_test.go @@ -0,0 +1,159 @@ +package ebpf + +import ( + "testing" + "time" + + "github.com/Nextron-Labs/aurora-linux/lib/enrichment" + "github.com/Nextron-Labs/aurora-linux/lib/provider" +) + +func TestEbpfEventID(t *testing.T) { + event := &ebpfEvent{ + id: provider.EventIdentifier{ + ProviderName: "LinuxEBPF", + EventID: 1, + }, + } + + got := event.ID() + if got.ProviderName != "LinuxEBPF" { + t.Fatalf("ID().ProviderName = %q, want LinuxEBPF", got.ProviderName) + } + if got.EventID != 1 { + t.Fatalf("ID().EventID = %d, want 1", got.EventID) + } +} + +func TestEbpfEventProcess(t *testing.T) { + event := &ebpfEvent{pid: 12345} + + if got := event.Process(); got != 12345 { + t.Fatalf("Process() = %d, want 12345", got) + } +} + +func TestEbpfEventSource(t *testing.T) { + event := &ebpfEvent{source: "LinuxEBPF:ProcessExec"} + + if got := event.Source(); got != "LinuxEBPF:ProcessExec" { + t.Fatalf("Source() = %q, want LinuxEBPF:ProcessExec", got) + } +} + +func TestEbpfEventTime(t *testing.T) { + ts := time.Date(2026, 3, 17, 12, 0, 0, 0, time.UTC) + event := &ebpfEvent{ts: ts} + + if got := event.Time(); !got.Equal(ts) { + t.Fatalf("Time() = %v, want %v", got, ts) + } +} + +func TestEbpfEventValue(t *testing.T) { + event := &ebpfEvent{ + fields: enrichment.DataFieldsMap{ + "Image": enrichment.NewStringValue("/usr/bin/bash"), + "CommandLine": enrichment.NewStringValue("bash -c echo test"), + }, + } + + // Test existing field + got := event.Value("Image") + if !got.Valid { + t.Fatal("Value(Image).Valid = false, want true") + } + if got.String != "/usr/bin/bash" { + t.Fatalf("Value(Image).String = %q, want /usr/bin/bash", got.String) + } + + // Test non-existing field + got = event.Value("NonExistent") + if got.Valid { + t.Fatal("Value(NonExistent).Valid = true, want false") + } +} + +func TestEbpfEventForEach(t *testing.T) { + event := &ebpfEvent{ + fields: enrichment.DataFieldsMap{ + "Image": enrichment.NewStringValue("/bin/bash"), + "ProcessId": enrichment.NewStringValue("1234"), + }, + } + + collected := make(map[string]string) + event.ForEach(func(key, value string) { + collected[key] = value + }) + + if len(collected) != 2 { + t.Fatalf("ForEach collected %d fields, want 2", len(collected)) + } + if collected["Image"] != "/bin/bash" { + t.Fatalf("collected[Image] = %q, want /bin/bash", collected["Image"]) + } + if collected["ProcessId"] != "1234" { + t.Fatalf("collected[ProcessId] = %q, want 1234", collected["ProcessId"]) + } +} + +func TestEbpfEventFields(t *testing.T) { + fields := enrichment.DataFieldsMap{ + "Image": enrichment.NewStringValue("/bin/bash"), + } + event := &ebpfEvent{fields: fields} + + got := event.Fields() + if got["Image"].String() != "/bin/bash" { + t.Fatalf("Fields()[Image] = %q, want /bin/bash", got["Image"].String()) + } + + // Verify it returns the underlying map (not a copy) + got.AddField("NewField", "new value") + if event.fields["NewField"] == nil { + t.Fatal("Fields() should return the underlying map, not a copy") + } +} + +func TestEventConstants(t *testing.T) { + if ProviderName != "LinuxEBPF" { + t.Fatalf("ProviderName = %q, want LinuxEBPF", ProviderName) + } + if EventIDProcessCreation != 1 { + t.Fatalf("EventIDProcessCreation = %d, want 1", EventIDProcessCreation) + } + if EventIDNetworkConnection != 3 { + t.Fatalf("EventIDNetworkConnection = %d, want 3", EventIDNetworkConnection) + } + if EventIDFileEvent != 11 { + t.Fatalf("EventIDFileEvent = %d, want 11", EventIDFileEvent) + } +} + +func TestEbpfEventEmptyFields(t *testing.T) { + event := &ebpfEvent{ + fields: enrichment.DataFieldsMap{}, + } + + // ForEach with empty fields should not panic + count := 0 + event.ForEach(func(key, value string) { + count++ + }) + if count != 0 { + t.Fatalf("ForEach iterated %d times, want 0", count) + } +} + +func TestEbpfEventNilFieldsMap(t *testing.T) { + event := &ebpfEvent{ + fields: nil, + } + + // Value with nil fields should return invalid value + got := event.Value("Image") + if got.Valid { + t.Fatal("Value on nil fields should return Valid=false") + } +} diff --git a/lib/provider/ebpf/listener.go b/lib/provider/ebpf/listener.go index 6d13559..88bb93f 100644 --- a/lib/provider/ebpf/listener.go +++ b/lib/provider/ebpf/listener.go @@ -29,8 +29,6 @@ const ( // Listener implements the EventProvider interface using eBPF tracepoints. type Listener struct { - mu sync.Mutex - // Which sources are enabled enableExec bool enableFile bool diff --git a/lib/provider/ebpf/procfs_test.go b/lib/provider/ebpf/procfs_test.go new file mode 100644 index 0000000..e8eed33 --- /dev/null +++ b/lib/provider/ebpf/procfs_test.go @@ -0,0 +1,182 @@ +package ebpf + +import ( + "os" + "path/filepath" + "strconv" + "testing" +) + +func TestReadExeLinkCurrentProcess(t *testing.T) { + // Read our own process exe link + pid := uint32(os.Getpid()) + got, err := readExeLink(pid) + if err != nil { + t.Fatalf("readExeLink(%d) error = %v", pid, err) + } + if got == "" { + t.Fatal("readExeLink() returned empty string") + } + // Should be a go test binary + if !filepath.IsAbs(got) { + t.Fatalf("readExeLink() = %q, expected absolute path", got) + } +} + +func TestReadExeLinkInvalidPid(t *testing.T) { + _, err := readExeLink(999999999) // unlikely to exist + if err == nil { + t.Fatal("readExeLink(invalid PID) expected error") + } +} + +func TestReadCmdlineCurrentProcess(t *testing.T) { + pid := uint32(os.Getpid()) + got, err := readCmdline(pid) + if err != nil { + t.Fatalf("readCmdline(%d) error = %v", pid, err) + } + // cmdline contains NUL-separated args, should at least have the test binary + if len(got) == 0 { + t.Fatal("readCmdline() returned empty bytes") + } +} + +func TestReadCmdlineInvalidPid(t *testing.T) { + _, err := readCmdline(999999999) + if err == nil { + t.Fatal("readCmdline(invalid PID) expected error") + } +} + +func TestReadCwdCurrentProcess(t *testing.T) { + pid := uint32(os.Getpid()) + got, err := readCwd(pid) + if err != nil { + t.Fatalf("readCwd(%d) error = %v", pid, err) + } + if got == "" { + t.Fatal("readCwd() returned empty string") + } + if !filepath.IsAbs(got) { + t.Fatalf("readCwd() = %q, expected absolute path", got) + } +} + +func TestReadCwdInvalidPid(t *testing.T) { + _, err := readCwd(999999999) + if err == nil { + t.Fatal("readCwd(invalid PID) expected error") + } +} + +func TestReadLoginUIDCurrentProcess(t *testing.T) { + pid := uint32(os.Getpid()) + got := readLoginUID(pid) + // loginuid may be unset (returns "") or a numeric string + // We just verify it doesn't panic and returns something sensible + if got != "" { + // If set, should be parseable as int + _, err := strconv.Atoi(got) + if err != nil { + t.Fatalf("readLoginUID() = %q, not a valid UID", got) + } + } +} + +func TestReadLoginUIDInvalidPid(t *testing.T) { + got := readLoginUID(999999999) + if got != "" { + t.Fatalf("readLoginUID(invalid PID) = %q, want empty string", got) + } +} + +func TestReadLoginUIDUnsetValue(t *testing.T) { + // Test that loginUIDUnset constant is handled correctly + if loginUIDUnset != "4294967295" { + t.Fatalf("loginUIDUnset = %q, want 4294967295", loginUIDUnset) + } +} + +func TestReadFdLinkCurrentProcessStdout(t *testing.T) { + pid := uint32(os.Getpid()) + // fd 1 is stdout, usually valid + got, err := readFdLink(pid, 1) + if err != nil { + t.Fatalf("readFdLink(%d, 1) error = %v", pid, err) + } + // Should be something like /dev/pts/N or pipe: or similar + if got == "" { + t.Fatal("readFdLink() returned empty string") + } +} + +func TestReadFdLinkInvalidFd(t *testing.T) { + pid := uint32(os.Getpid()) + _, err := readFdLink(pid, 99999) // unlikely to exist + if err == nil { + t.Fatal("readFdLink(invalid fd) expected error") + } +} + +func TestResolveFilenameAbsolutePath(t *testing.T) { + // Absolute path should be returned (possibly with symlinks resolved) + got := resolveFilename(uint32(os.Getpid()), "/etc/passwd", -100) + if got == "" { + t.Fatal("resolveFilename() returned empty for absolute path") + } + if !filepath.IsAbs(got) { + t.Fatalf("resolveFilename(/etc/passwd) = %q, expected absolute", got) + } +} + +func TestResolveFilenameRelativeWithATFDCWD(t *testing.T) { + pid := uint32(os.Getpid()) + cwd, _ := os.Getwd() + + // AT_FDCWD = -100 means resolve relative to cwd + got := resolveFilename(pid, "testfile.txt", -100) + + // Should be cwd + testfile.txt (even if file doesn't exist) + expected := filepath.Join(cwd, "testfile.txt") + if got != expected { + t.Fatalf("resolveFilename(relative, AT_FDCWD) = %q, want %q", got, expected) + } +} + +func TestResolveFilenameInvalidPid(t *testing.T) { + // Invalid PID for relative path resolution + got := resolveFilename(999999999, "relative.txt", -100) + // Should return the original filename when cwd can't be read + if got != "relative.txt" { + t.Fatalf("resolveFilename(invalid PID) = %q, want relative.txt", got) + } +} + +func TestResolveFilenameWithSymlinks(t *testing.T) { + tmpDir := t.TempDir() + + // Create a file + realFile := filepath.Join(tmpDir, "real.txt") + if err := os.WriteFile(realFile, []byte("test"), 0644); err != nil { + t.Fatalf("WriteFile error: %v", err) + } + + // Create a symlink to it + symlink := filepath.Join(tmpDir, "link.txt") + if err := os.Symlink(realFile, symlink); err != nil { + t.Fatalf("Symlink error: %v", err) + } + + // resolveFilename should resolve the symlink + got := resolveFilename(uint32(os.Getpid()), symlink, -100) + if got != realFile { + t.Fatalf("resolveFilename(symlink) = %q, want %q", got, realFile) + } +} + +func TestMaxCmdlineBytesConstant(t *testing.T) { + if maxCmdlineBytes != 32768 { + t.Fatalf("maxCmdlineBytes = %d, want 32768", maxCmdlineBytes) + } +} diff --git a/lib/provider/ebpf/usercache_test.go b/lib/provider/ebpf/usercache_test.go new file mode 100644 index 0000000..ff8993a --- /dev/null +++ b/lib/provider/ebpf/usercache_test.go @@ -0,0 +1,174 @@ +package ebpf + +import ( + "os/user" + "strconv" + "testing" +) + +func TestNewUserCacheSuccess(t *testing.T) { + cache, err := NewUserCache(100) + if err != nil { + t.Fatalf("NewUserCache(100) error = %v", err) + } + if cache == nil { + t.Fatal("NewUserCache(100) returned nil") + } + if cache.cache == nil { + t.Fatal("NewUserCache(100).cache is nil") + } +} + +func TestNewUserCacheInvalidSize(t *testing.T) { + // LRU cache with size <= 0 should fail + _, err := NewUserCache(0) + if err == nil { + t.Fatal("NewUserCache(0) expected error") + } + + _, err = NewUserCache(-1) + if err == nil { + t.Fatal("NewUserCache(-1) expected error") + } +} + +func TestUserCacheLookupCurrentUser(t *testing.T) { + cache, err := NewUserCache(10) + if err != nil { + t.Fatalf("NewUserCache error = %v", err) + } + + // Look up current user + currentUser, err := user.Current() + if err != nil { + t.Skipf("Cannot determine current user: %v", err) + } + + uid64, err := strconv.ParseUint(currentUser.Uid, 10, 32) + if err != nil { + t.Skipf("Cannot parse UID: %v", err) + } + uid := uint32(uid64) + + got := cache.Lookup(uid) + if got != currentUser.Username { + t.Fatalf("Lookup(%d) = %q, want %q", uid, got, currentUser.Username) + } +} + +func TestUserCacheLookupRoot(t *testing.T) { + cache, err := NewUserCache(10) + if err != nil { + t.Fatalf("NewUserCache error = %v", err) + } + + // UID 0 is root on Unix systems + got := cache.Lookup(0) + if got != "root" { + t.Fatalf("Lookup(0) = %q, want root", got) + } +} + +func TestUserCacheLookupNonExistentUser(t *testing.T) { + cache, err := NewUserCache(10) + if err != nil { + t.Fatalf("NewUserCache error = %v", err) + } + + // Very high UID unlikely to exist + nonExistentUID := uint32(4294967290) + got := cache.Lookup(nonExistentUID) + + // Should return the numeric UID as string + expected := strconv.FormatUint(uint64(nonExistentUID), 10) + if got != expected { + t.Fatalf("Lookup(%d) = %q, want %q", nonExistentUID, got, expected) + } +} + +func TestUserCacheLookupCaching(t *testing.T) { + cache, err := NewUserCache(10) + if err != nil { + t.Fatalf("NewUserCache error = %v", err) + } + + // First lookup + first := cache.Lookup(0) + + // Second lookup should return cached value + second := cache.Lookup(0) + + if first != second { + t.Fatalf("Cached lookup mismatch: first=%q second=%q", first, second) + } +} + +func TestUserCacheLookupMultipleUsers(t *testing.T) { + cache, err := NewUserCache(10) + if err != nil { + t.Fatalf("NewUserCache error = %v", err) + } + + // Lookup root + root := cache.Lookup(0) + if root != "root" { + t.Fatalf("Lookup(0) = %q, want root", root) + } + + // Lookup current user + currentUser, err := user.Current() + if err != nil { + t.Skipf("Cannot determine current user: %v", err) + } + + uid64, err := strconv.ParseUint(currentUser.Uid, 10, 32) + if err != nil { + t.Skipf("Cannot parse UID: %v", err) + } + + current := cache.Lookup(uint32(uid64)) + if current != currentUser.Username { + t.Fatalf("Lookup(current UID) = %q, want %q", current, currentUser.Username) + } + + // Verify both are still in cache + if cache.Lookup(0) != root { + t.Fatal("Root lookup changed after adding current user") + } +} + +func TestUserCacheLRUEviction(t *testing.T) { + // Create a very small cache to test LRU eviction + cache, err := NewUserCache(2) + if err != nil { + t.Fatalf("NewUserCache error = %v", err) + } + + // Add more entries than capacity + cache.Lookup(0) // root + cache.Lookup(65534) // nobody (or numeric) + cache.Lookup(4294967290) // non-existent (numeric) + + // The LRU cache should have evicted the oldest entry + // We're just verifying it doesn't panic and still returns sensible values + got := cache.Lookup(4294967290) + expected := "4294967290" + if got != expected { + t.Fatalf("Lookup after eviction = %q, want %q", got, expected) + } +} + +func TestUserCacheLookupRepeatedCalls(t *testing.T) { + cache, err := NewUserCache(10) + if err != nil { + t.Fatalf("NewUserCache error = %v", err) + } + + // Multiple rapid lookups should be consistent + for i := 0; i < 100; i++ { + got := cache.Lookup(0) + if got != "root" { + t.Fatalf("Iteration %d: Lookup(0) = %q, want root", i, got) + } + } +} diff --git a/lib/provider/replay/replay_test.go b/lib/provider/replay/replay_test.go index 13ba6f8..9d61038 100644 --- a/lib/provider/replay/replay_test.go +++ b/lib/provider/replay/replay_test.go @@ -8,6 +8,7 @@ import ( "sync" "testing" + "github.com/Nextron-Labs/aurora-linux/lib/enrichment" "github.com/Nextron-Labs/aurora-linux/lib/provider" ) @@ -158,3 +159,306 @@ func TestReplayProviderConcurrentAddSourceAndSendEvents(t *testing.T) { wg.Wait() } + +func TestReplayProviderNameAndDescription(t *testing.T) { + r := New() + if got := r.Name(); got != "Replay" { + t.Fatalf("Name() = %q, want Replay", got) + } + if got := r.Description(); got != "Replay provider for pre-recorded events" { + t.Fatalf("Description() = %q", got) + } +} + +func TestReplayProviderLostEventsAlwaysZero(t *testing.T) { + r := New() + if got := r.LostEvents(); got != 0 { + t.Fatalf("LostEvents() = %d, want 0", got) + } +} + +func TestReplayProviderInitialize(t *testing.T) { + r := New() + if err := r.Initialize(); err != nil { + t.Fatalf("Initialize() error = %v", err) + } +} + +func TestReplayProviderHandlesMissingFile(t *testing.T) { + r := New("/nonexistent/path/events.jsonl") + count := 0 + // Should not panic, just log warning and continue + r.SendEvents(func(event provider.Event) { + count++ + }) + + if count != 0 { + t.Fatalf("expected 0 events from missing file, got %d", count) + } +} + +func TestReplayProviderSkipsEmptyLines(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "events.jsonl") + content := "" + + `{"_provider":"LinuxEBPF","_eventID":1,"_source":"LinuxEBPF:ProcessExec","ProcessId":"1"}` + "\n" + + "\n" + + `{"_provider":"LinuxEBPF","_eventID":1,"_source":"LinuxEBPF:ProcessExec","ProcessId":"2"}` + "\n" + + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + r := New(path) + count := 0 + r.SendEvents(func(event provider.Event) { + count++ + }) + + if count != 2 { + t.Fatalf("expected 2 events (skipping empty line), got %d", count) + } +} + +func TestReplayProviderSkipsMalformedJSON(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "events.jsonl") + content := "" + + `{"_provider":"LinuxEBPF","_eventID":1,"_source":"LinuxEBPF:ProcessExec","ProcessId":"1"}` + "\n" + + `not valid json` + "\n" + + `{"_provider":"LinuxEBPF","_eventID":1,"_source":"LinuxEBPF:ProcessExec","ProcessId":"2"}` + "\n" + + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + r := New(path) + count := 0 + r.SendEvents(func(event provider.Event) { + count++ + }) + + if count != 2 { + t.Fatalf("expected 2 events (skipping malformed JSON), got %d", count) + } +} + +func TestReplayProviderWithNoSourceFilters(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "events.jsonl") + content := "" + + `{"_provider":"LinuxEBPF","_eventID":1,"_source":"LinuxEBPF:ProcessExec","ProcessId":"1"}` + "\n" + + `{"_provider":"LinuxEBPF","_eventID":3,"_source":"LinuxEBPF:NetConnect","ProcessId":"2"}` + "\n" + + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + r := New(path) + // Don't add any source filters - should emit all events + count := 0 + r.SendEvents(func(event provider.Event) { + count++ + }) + + if count != 2 { + t.Fatalf("expected 2 events with no source filters, got %d", count) + } +} + +func TestReplayProviderWithMultipleFiles(t *testing.T) { + tmpDir := t.TempDir() + path1 := filepath.Join(tmpDir, "events1.jsonl") + path2 := filepath.Join(tmpDir, "events2.jsonl") + + content1 := `{"_provider":"LinuxEBPF","_eventID":1,"_source":"LinuxEBPF:ProcessExec","ProcessId":"1"}` + "\n" + content2 := `{"_provider":"LinuxEBPF","_eventID":1,"_source":"LinuxEBPF:ProcessExec","ProcessId":"2"}` + "\n" + + if err := os.WriteFile(path1, []byte(content1), 0644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + if err := os.WriteFile(path2, []byte(content2), 0644); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + r := New(path1, path2) + count := 0 + r.SendEvents(func(event provider.Event) { + count++ + }) + + if count != 2 { + t.Fatalf("expected 2 events from 2 files, got %d", count) + } +} + +func TestRecordToEventWithTimestamp(t *testing.T) { + record := map[string]interface{}{ + "_provider": "LinuxEBPF", + "_eventID": float64(1), + "_timestamp": "2026-03-17T12:00:00Z", + "ProcessId": "100", + } + + evt, err := recordToEvent(record) + if err != nil { + t.Fatalf("recordToEvent() error = %v", err) + } + + if evt.Time().IsZero() { + t.Fatal("Time() should not be zero") + } + if evt.Time().Year() != 2026 { + t.Fatalf("Time().Year() = %d, want 2026", evt.Time().Year()) + } +} + +func TestRecordToEventWithExplicitSource(t *testing.T) { + record := map[string]interface{}{ + "_provider": "LinuxEBPF", + "_eventID": float64(1), + "_source": "CustomSource", + "ProcessId": "100", + } + + evt, err := recordToEvent(record) + if err != nil { + t.Fatalf("recordToEvent() error = %v", err) + } + + if evt.Source() != "CustomSource" { + t.Fatalf("Source() = %q, want CustomSource", evt.Source()) + } +} + +func TestRecordToEventWithDefaultProvider(t *testing.T) { + record := map[string]interface{}{ + "_eventID": float64(1), + "ProcessId": "100", + } + + evt, err := recordToEvent(record) + if err != nil { + t.Fatalf("recordToEvent() error = %v", err) + } + + if evt.ID().ProviderName != "LinuxEBPF" { + t.Fatalf("ProviderName = %q, want LinuxEBPF (default)", evt.ID().ProviderName) + } +} + +func TestReplayEventMethods(t *testing.T) { + evt := &replayEvent{ + id: provider.EventIdentifier{ + ProviderName: "TestProvider", + EventID: 99, + }, + pid: 12345, + source: "TestSource", + fields: make(enrichment.DataFieldsMap), + } + evt.fields.AddField("Image", "/bin/test") + + if evt.ID().ProviderName != "TestProvider" { + t.Fatalf("ID().ProviderName = %q", evt.ID().ProviderName) + } + if evt.Process() != 12345 { + t.Fatalf("Process() = %d", evt.Process()) + } + if evt.Source() != "TestSource" { + t.Fatalf("Source() = %q", evt.Source()) + } + if !evt.Value("Image").Valid || evt.Value("Image").String != "/bin/test" { + t.Fatalf("Value(Image) = %v", evt.Value("Image")) + } + + count := 0 + evt.ForEach(func(k, v string) { + count++ + }) + if count != 1 { + t.Fatalf("ForEach count = %d, want 1", count) + } + + if evt.Fields() == nil { + t.Fatal("Fields() should not be nil") + } +} + +func TestParseEventIDInvalidString(t *testing.T) { + // Invalid string + got := parseEventID("invalid") + if got != 0 { + t.Fatalf("parseEventID(invalid) = %d, want 0", got) + } + + // Negative float + got = parseEventID(float64(-1)) + if got != 0 { + t.Fatalf("parseEventID(-1) = %d, want 0", got) + } + + // Float > MaxUint16 + got = parseEventID(float64(70000)) + if got != 0 { + t.Fatalf("parseEventID(70000) = %d, want 0", got) + } + + // Non-integer float + got = parseEventID(float64(1.5)) + if got != 0 { + t.Fatalf("parseEventID(1.5) = %d, want 0", got) + } + + // Unknown type + got = parseEventID([]int{1, 2, 3}) + if got != 0 { + t.Fatalf("parseEventID(slice) = %d, want 0", got) + } +} + +func TestParseUint32Invalid(t *testing.T) { + // Invalid string + got, ok := parseUint32("invalid") + if ok || got != 0 { + t.Fatalf("parseUint32(invalid) = (%d, %v), want (0, false)", got, ok) + } + + // Negative float + got, ok = parseUint32(float64(-1)) + if ok || got != 0 { + t.Fatalf("parseUint32(-1) = (%d, %v), want (0, false)", got, ok) + } + + // Float > MaxUint32 + got, ok = parseUint32(float64(5000000000)) + if ok || got != 0 { + t.Fatalf("parseUint32(>MaxUint32) = (%d, %v), want (0, false)", got, ok) + } + + // Non-integer float + got, ok = parseUint32(float64(1.5)) + if ok || got != 0 { + t.Fatalf("parseUint32(1.5) = (%d, %v), want (0, false)", got, ok) + } + + // Unknown type + got, ok = parseUint32([]int{1}) + if ok || got != 0 { + t.Fatalf("parseUint32(slice) = (%d, %v), want (0, false)", got, ok) + } +} + +func TestDefaultSourceForEventUnknown(t *testing.T) { + // Non-LinuxEBPF provider + got := defaultSourceForEvent("OtherProvider", 1) + if got != "" { + t.Fatalf("defaultSourceForEvent(OtherProvider, 1) = %q, want empty", got) + } + + // Unknown event ID + got = defaultSourceForEvent("LinuxEBPF", 99) + if got != "" { + t.Fatalf("defaultSourceForEvent(LinuxEBPF, 99) = %q, want empty", got) + } +}