Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions cmd/yargen-util/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package main

import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)

func TestStartGeneration_SendsFlagFields(t *testing.T) {
type requestBody struct {
JobID string `json:"job_id"`
Author string `json:"author"`
Reference string `json:"reference"`
ShowScores bool `json:"show_scores"`
ExcludeOpcodes bool `json:"exclude_opcodes"`
}

var received requestBody
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/generate" {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
if r.Method != http.MethodPost {
t.Fatalf("unexpected method: %s", r.Method)
}
if err := json.NewDecoder(r.Body).Decode(&received); err != nil {
t.Fatalf("failed to decode request body: %v", err)
}
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

if err := startGeneration(srv.URL, "job-123", "neo", "ref-1", true, true); err != nil {
t.Fatalf("startGeneration returned error: %v", err)
}

if received.JobID != "job-123" {
t.Fatalf("unexpected job id: %q", received.JobID)
}
if received.Author != "neo" {
t.Fatalf("unexpected author: %q", received.Author)
}
if received.Reference != "ref-1" {
t.Fatalf("unexpected reference: %q", received.Reference)
}
if !received.ShowScores {
t.Fatalf("expected show_scores=true")
}
if !received.ExcludeOpcodes {
t.Fatalf("expected exclude_opcodes=true")
}
}

func TestWaitForRules_ReturnsRulesOnCompletedStatus(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/jobs/job-123" {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
_, _ = w.Write([]byte(`{"status":"completed","rules":"rule x { condition: true }"}`))
}))
defer srv.Close()

rules, err := waitForRules(srv.URL, "job-123", 10, false)
if err != nil {
t.Fatalf("waitForRules returned error: %v", err)
}
if !strings.Contains(rules, "rule x") {
t.Fatalf("unexpected rules output: %q", rules)
}
}

func TestWaitForRules_ReturnsGenerationError(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(`{"status":"failed","error":"backend failed"}`))
}))
defer srv.Close()

_, err := waitForRules(srv.URL, "job-123", 10, false)
if err == nil {
t.Fatal("expected failure error")
}
if !strings.Contains(err.Error(), "backend failed") {
t.Fatalf("unexpected error: %v", err)
}
}

func TestWaitForRules_TimeoutRespectsWaitFlag(t *testing.T) {
// maxWait=0 should time out immediately without polling.
rules, err := waitForRules("http://127.0.0.1:1", "job-123", 0, false)
if err == nil {
t.Fatalf("expected timeout error, got rules: %q", rules)
}
if !strings.Contains(err.Error(), "timeout after 0 seconds") {
t.Fatalf("unexpected error: %v", err)
}
}
121 changes: 61 additions & 60 deletions cmd/yargen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,67 +69,14 @@ func runCLI() {
os.Exit(0)
}

// Handle single file mode (-f flag)
var tempDir string
var usingSingleFile bool
if *malwareFile != "" {
if *malwareDir != "" {
printBanner()
fmt.Println("\n[E] Cannot use both -f (single file) and -m (directory) flags")
os.Exit(1)
}

// Check if file exists
if _, err := os.Stat(*malwareFile); os.IsNotExist(err) {
printBanner()
fmt.Printf("\n[E] File not found: %s\n", *malwareFile)
os.Exit(1)
}

// Create temp directory and copy file
var err error
tempDir, err = os.MkdirTemp("", "yargen-single-")
if err != nil {
printBanner()
fmt.Printf("\n[E] Failed to create temp directory: %v\n", err)
os.Exit(1)
}

// Copy file to temp directory
srcFile, err := os.Open(*malwareFile)
if err != nil {
os.RemoveAll(tempDir)
printBanner()
fmt.Printf("\n[E] Failed to open file: %v\n", err)
os.Exit(1)
}
defer srcFile.Close()

dstPath := filepath.Join(tempDir, filepath.Base(*malwareFile))
dstFile, err := os.Create(dstPath)
if err != nil {
os.RemoveAll(tempDir)
printBanner()
fmt.Printf("\n[E] Failed to create temp file: %v\n", err)
os.Exit(1)
}
defer dstFile.Close()

if _, err := io.Copy(dstFile, srcFile); err != nil {
os.RemoveAll(tempDir)
printBanner()
fmt.Printf("\n[E] Failed to copy file: %v\n", err)
os.Exit(1)
}

*malwareDir = tempDir
usingSingleFile = true
}

// Ensure cleanup of temp directory
if tempDir != "" {
defer os.RemoveAll(tempDir)
resolvedDir, usingSingleFile, cleanup, err := resolveMalwareInput(*malwareDir, *malwareFile)
if err != nil {
printBanner()
fmt.Printf("\n[E] %v\n", err)
os.Exit(1)
}
defer cleanup()
*malwareDir = resolvedDir

if *malwareDir == "" {
printBanner()
Expand Down Expand Up @@ -330,3 +277,57 @@ func printSingleFileRecommendation() {
fmt.Println(" This avoids re-loading databases for each sample.")
fmt.Println("------------------------------------------------------------------------")
}

func resolveMalwareInput(malwareDir, malwareFile string) (resolvedDir string, usingSingleFile bool, cleanup func(), err error) {
cleanup = func() {}

if malwareFile == "" {
return malwareDir, false, cleanup, nil
}

if malwareDir != "" {
return "", false, cleanup, fmt.Errorf("cannot use both -f (single file) and -m (directory) flags")
}

info, err := os.Stat(malwareFile)
if err != nil {
if os.IsNotExist(err) {
return "", false, cleanup, fmt.Errorf("file not found: %s", malwareFile)
}
return "", false, cleanup, fmt.Errorf("failed to access file %s: %w", malwareFile, err)
}
if info.IsDir() {
return "", false, cleanup, fmt.Errorf("path is a directory, expected file: %s", malwareFile)
}

tempDir, err := os.MkdirTemp("", "yargen-single-")
if err != nil {
return "", false, cleanup, fmt.Errorf("failed to create temp directory: %w", err)
}

cleanup = func() {
_ = os.RemoveAll(tempDir)
}

srcFile, err := os.Open(malwareFile)
if err != nil {
cleanup()
return "", false, func() {}, fmt.Errorf("failed to open file: %w", err)
}
defer srcFile.Close()

dstPath := filepath.Join(tempDir, filepath.Base(malwareFile))
dstFile, err := os.Create(dstPath)
if err != nil {
cleanup()
return "", false, func() {}, fmt.Errorf("failed to create temp file: %w", err)
}
defer dstFile.Close()

if _, err := io.Copy(dstFile, srcFile); err != nil {
cleanup()
return "", false, func() {}, fmt.Errorf("failed to copy file: %w", err)
}

return tempDir, true, cleanup, nil
}
99 changes: 99 additions & 0 deletions cmd/yargen/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package main

import (
"os"
"path/filepath"
"strings"
"testing"
)

func TestResolveMalwareInput_DirectoryMode(t *testing.T) {
dir := t.TempDir()

resolvedDir, usingSingleFile, cleanup, err := resolveMalwareInput(dir, "")
if err != nil {
t.Fatalf("resolveMalwareInput returned unexpected error: %v", err)
}
if usingSingleFile {
t.Fatalf("expected directory mode, got single-file mode")
}
if resolvedDir != dir {
t.Fatalf("unexpected resolved directory: got %q want %q", resolvedDir, dir)
}

cleanup()
}

func TestResolveMalwareInput_RejectsBothFileAndDirectory(t *testing.T) {
dir := t.TempDir()
file := filepath.Join(dir, "sample.bin")
if err := os.WriteFile(file, []byte("abc"), 0o644); err != nil {
t.Fatalf("failed to create sample file: %v", err)
}

_, _, cleanup, err := resolveMalwareInput(dir, file)
if err == nil {
cleanup()
t.Fatal("expected error when both -m and -f are provided")
}
if !strings.Contains(err.Error(), "cannot use both -f") {
t.Fatalf("unexpected error: %v", err)
}
}

func TestResolveMalwareInput_MissingFile(t *testing.T) {
_, _, cleanup, err := resolveMalwareInput("", filepath.Join(t.TempDir(), "missing.bin"))
if err == nil {
cleanup()
t.Fatal("expected missing-file error")
}
if !strings.Contains(err.Error(), "file not found") {
t.Fatalf("unexpected error: %v", err)
}
}

func TestResolveMalwareInput_RejectsDirectoryInSingleFileMode(t *testing.T) {
dir := t.TempDir()

_, _, cleanup, err := resolveMalwareInput("", dir)
if err == nil {
cleanup()
t.Fatal("expected error for directory input in single-file mode")
}
if !strings.Contains(err.Error(), "expected file") {
t.Fatalf("unexpected error: %v", err)
}
}

func TestResolveMalwareInput_SingleFileCopiesToTempDirectory(t *testing.T) {
tempRoot := t.TempDir()
sourceFile := filepath.Join(tempRoot, "sample.bin")
content := []byte("malware sample")
if err := os.WriteFile(sourceFile, content, 0o644); err != nil {
t.Fatalf("failed to write source file: %v", err)
}

resolvedDir, usingSingleFile, cleanup, err := resolveMalwareInput("", sourceFile)
if err != nil {
t.Fatalf("resolveMalwareInput returned unexpected error: %v", err)
}
if !usingSingleFile {
t.Fatal("expected single-file mode")
}

copiedPath := filepath.Join(resolvedDir, filepath.Base(sourceFile))
copiedContent, err := os.ReadFile(copiedPath)
if err != nil {
cleanup()
t.Fatalf("failed to read copied file: %v", err)
}
if string(copiedContent) != string(content) {
cleanup()
t.Fatalf("copied file content mismatch: got %q want %q", string(copiedContent), string(content))
}

cleanup()
if _, err := os.Stat(resolvedDir); !os.IsNotExist(err) {
t.Fatalf("expected temp directory to be removed, got err=%v", err)
}
}
Loading