Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions cmd/openrelik/internal/cli/files_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,10 @@ func TestFileListCmd(t *testing.T) {
name: "list files in folder",
args: []string{"file", "list", "123"},
expectedOutput: []string{
"ID : 1",
"DisplayName : file1.txt",
"Filesize : 1024",
"ID : 2",
"DisplayName : file2.txt",
"Filesize : 2048",
"file1.txt",
"1.0KB",
"file2.txt",
"2.0KB",
},
},
{
Expand Down Expand Up @@ -286,7 +284,7 @@ func TestFileUploadCmd(t *testing.T) {
if !strings.Contains(output, "chunks") {
t.Errorf("expected output to contain chunk info, but it was %q", output)
}
if !strings.Contains(output, "ID : 101") {
if !strings.Contains(output, "ID 101") {
t.Errorf("expected output to contain uploaded file ID, but it was %q", output)
}
}
17 changes: 7 additions & 10 deletions cmd/openrelik/internal/cli/folders_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,15 @@ func TestFolderListCmd(t *testing.T) {
name: "list root folders",
args: []string{"folder", "list"},
expectedOutput: []string{
"ID : 1",
"DisplayName : Root 1",
"ID : 2",
"DisplayName : Root 2",
"Root 1",
"Root 2",
},
},
{
name: "list subfolders",
args: []string{"folder", "list", "1"},
expectedOutput: []string{
"ID : 3",
"DisplayName : Sub 1",
"Sub 1",
},
},
}
Expand Down Expand Up @@ -114,16 +111,16 @@ func TestFolderCreateCmd(t *testing.T) {
name: "create root folder",
args: []string{"folder", "create", "--name", "New Root"},
expectedOutput: []string{
"ID : 100",
"DisplayName : New Root",
"ID 100",
"Display Name New Root",
},
},
{
name: "create subfolder",
args: []string{"folder", "create", "--name", "New Sub", "--parent", "1"},
expectedOutput: []string{
"ID : 200",
"DisplayName : New Sub",
"ID 200",
"Display Name New Sub",
},
},
}
Expand Down
21 changes: 17 additions & 4 deletions cmd/openrelik/internal/cli/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package cli

import (
"context"
"encoding/json"
"fmt"
"os"
Expand Down Expand Up @@ -55,11 +56,17 @@ openrelik run strings --and grep 123`,
// dry-run applies to all worker subcommands
runCmd.PersistentFlags().Bool("dry-run", false, "Generate and display workflow spec without executing")

// Load workers from cache to build dynamic subcommands
workers, err := config.LoadWorkersCache()
// Load workers from cache (auto-refreshing if missing or stale) to build dynamic subcommands.
workers, err := config.LoadOrRefreshWorkersCache(context.Background(), func(ctx context.Context) ([]openrelik.Worker, error) {
client, err := newClient()
if err != nil {
return nil, err
}
ws, _, err := client.Workers().Registered(ctx)
return ws, err
})
if err != nil {
// If cache load fails, we don't add dynamic subcommands.
// User can run 'openrelik worker list --refresh' to populate cache.
// If workers cannot be loaded or fetched, skip dynamic subcommands.
return runCmd
}

Expand Down Expand Up @@ -156,6 +163,12 @@ func createWorkerCmd(worker openrelik.Worker, allWorkers []openrelik.Worker) *co
return nil
}

if !dryRun && downloadPolicy != "none" {
if info, err := os.Stat(outputDir); err != nil || !info.IsDir() {
return fmt.Errorf("output directory %q does not exist", outputDir)
}
}

client, err := newClient()
if err != nil {
return err
Expand Down
36 changes: 36 additions & 0 deletions cmd/openrelik/internal/cli/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
package cli

import (
"bytes"
"os"
"strings"
"testing"

"github.com/openrelik/openrelik-go-client"
Expand Down Expand Up @@ -103,6 +105,40 @@ func TestDynamicWorkerCommands(t *testing.T) {
}
}

func TestRunCmdOutputDirValidation(t *testing.T) {
tmpDir := t.TempDir()
config.SetBaseDir(tmpDir)
defer config.SetBaseDir("")

testWorkers := []openrelik.Worker{
{TaskName: "openrelik-worker-strings.tasks.strings", DisplayName: "Strings"},
}
if err := config.SaveWorkersCache(testWorkers); err != nil {
t.Fatalf("Failed to save workers cache: %v", err)
}

os.Setenv("OPENRELIK_API_KEY", "test-key")
os.Setenv("OPENRELIK_SERVER_URL", "http://localhost:19999")
defer func() {
os.Unsetenv("OPENRELIK_API_KEY")
os.Unsetenv("OPENRELIK_SERVER_URL")
}()

root := NewRootCmd()
buf := new(bytes.Buffer)
root.SetOut(buf)
root.SetErr(buf)
root.SetArgs([]string{"run", "-o", "/nonexistent/output/dir", "strings", "123"})

err := root.Execute()
if err == nil {
t.Fatal("expected error for non-existent output directory, got nil")
}
if !strings.Contains(err.Error(), "does not exist") {
t.Errorf("expected 'does not exist' in error, got: %v", err)
}
}

func TestSliceArgs(t *testing.T) {
tests := []struct {
name string
Expand Down
8 changes: 4 additions & 4 deletions cmd/openrelik/internal/cli/users_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ func TestMeCmd(t *testing.T) {

output := buf.String()
expectedFields := []string{
"ID : 1",
"Username : testuser",
"DisplayName : Test User",
"IsAdmin : true",
"ID 1",
"Username testuser",
"Display Name Test User",
"IsAdmin true",
}

for _, field := range expectedFields {
Expand Down
10 changes: 2 additions & 8 deletions cmd/openrelik/internal/cli/workers.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ func newWorkerCmd() *cobra.Command {
}

func newListWorkersCmd() *cobra.Command {
var refresh bool

cmd := &cobra.Command{
Use: "list",
Short: "List registered workers",
Expand All @@ -46,17 +44,13 @@ func newListWorkersCmd() *cobra.Command {
return err
}

if refresh {
if err := config.SaveWorkersCache(workers); err != nil {
return err
}
if err := config.SaveWorkersCache(workers); err != nil {
return err
}

return formatAndPrint(cmd, workers)
},
}

cmd.Flags().BoolVar(&refresh, "refresh", false, "Update the local workers cache")

return cmd
}
18 changes: 14 additions & 4 deletions cmd/openrelik/internal/cli/workers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,16 @@ import (
"os"
"strings"
"testing"

"github.com/openrelik/openrelik-go-client/cmd/cli/internal/config"
)

func TestWorkerListCmd(t *testing.T) {
// Use a temp dir so cache writes don't touch the real home directory.
tmpDir := t.TempDir()
config.SetBaseDir(tmpDir)
defer config.SetBaseDir("")

// Mock API server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/v1/taskqueue/tasks/registered" {
Expand Down Expand Up @@ -64,10 +71,8 @@ func TestWorkerListCmd(t *testing.T) {

output := buf.String()
expectedFields := []string{
"TaskName : test-task",
"QueueName : test-queue",
"DisplayName : Test Task",
"Description : A test task",
"Test Task",
"test-queue",
}

for _, field := range expectedFields {
Expand All @@ -78,6 +83,11 @@ func TestWorkerListCmd(t *testing.T) {
}

func TestWorkerListCmdJSON(t *testing.T) {
// Use a temp dir so cache writes don't touch the real home directory.
tmpDir := t.TempDir()
config.SetBaseDir(tmpDir)
defer config.SetBaseDir("")

// Mock API server
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/api/v1/taskqueue/tasks/registered" {
Expand Down
8 changes: 4 additions & 4 deletions cmd/openrelik/internal/cli/workflows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,22 @@ func TestWorkflowCmd(t *testing.T) {
{
name: "info",
args: []string{"workflow", "info", "123"},
expected: "ID : 123",
expected: "ID 123",
},
{
name: "status",
args: []string{"workflow", "status", "123"},
expected: "Status : completed",
expected: "Status completed",
},
{
name: "run",
args: []string{"workflow", "run", "123"},
expected: "DisplayName : Running Workflow",
expected: "Display Name Running Workflow",
},
{
name: "create",
args: []string{"workflow", "create", "--file", "456"},
expected: "ID : 124",
expected: "ID 124",
},
}

Expand Down
52 changes: 48 additions & 4 deletions cmd/openrelik/internal/config/config.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package config

import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"time"

"github.com/openrelik/openrelik-go-client"
)
Expand All @@ -16,8 +19,21 @@ const (
workersCacheFile = "workers_cache.json"
dirPerm = 0700
filePerm = 0600
workersCacheTTL = time.Hour
)

// ErrCacheMissing is returned when the workers cache file does not exist.
var ErrCacheMissing = errors.New("workers cache does not exist")

// ErrCacheStale is returned when the workers cache is older than the TTL.
var ErrCacheStale = errors.New("workers cache is stale")

// workersCacheEntry is the on-disk representation of the workers cache.
type workersCacheEntry struct {
Workers []openrelik.Worker `json:"workers"`
SavedAt time.Time `json:"saved_at"`
}

type Settings struct {
ServerURL string `json:"server_url"`
}
Expand Down Expand Up @@ -120,6 +136,8 @@ func SaveCredentials(c *Credentials) error {
return saveAtomic(path, data)
}

// LoadWorkersCache reads the workers cache from disk. Returns ErrCacheMissing if the
// cache file does not exist, or ErrCacheStale if it is older than workersCacheTTL.
func LoadWorkersCache() ([]openrelik.Worker, error) {
dir, err := GetConfigDir()
if err != nil {
Expand All @@ -128,28 +146,54 @@ func LoadWorkersCache() ([]openrelik.Worker, error) {
path := filepath.Join(dir, workersCacheFile)
data, err := os.ReadFile(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil, ErrCacheMissing
}
return nil, fmt.Errorf("failed to read workers cache file %s: %w", path, err)
}
var w []openrelik.Worker
if err := json.Unmarshal(data, &w); err != nil {
var entry workersCacheEntry
if err := json.Unmarshal(data, &entry); err != nil {
return nil, fmt.Errorf("failed to unmarshal workers cache file %s: %w", path, err)
}
return w, nil
if time.Since(entry.SavedAt) > workersCacheTTL {
return nil, ErrCacheStale
}
return entry.Workers, nil
}

// SaveWorkersCache writes workers to the cache file with the current timestamp.
func SaveWorkersCache(w []openrelik.Worker) error {
dir, err := EnsureConfigDir()
if err != nil {
return err
}
data, err := json.MarshalIndent(w, "", " ")
entry := workersCacheEntry{Workers: w, SavedAt: time.Now()}
data, err := json.MarshalIndent(entry, "", " ")
if err != nil {
return fmt.Errorf("failed to marshal workers cache: %w", err)
}
path := filepath.Join(dir, workersCacheFile)
return saveAtomic(path, data)
}

// LoadOrRefreshWorkersCache loads the workers cache. If the cache is missing or stale,
// it calls refresh to fetch fresh data, saves it, and returns it.
func LoadOrRefreshWorkersCache(ctx context.Context, refresh func(context.Context) ([]openrelik.Worker, error)) ([]openrelik.Worker, error) {
workers, err := LoadWorkersCache()
if err == nil {
return workers, nil
}
if !errors.Is(err, ErrCacheMissing) && !errors.Is(err, ErrCacheStale) {
return nil, err
}
workers, err = refresh(ctx)
if err != nil {
return nil, err
}
_ = SaveWorkersCache(workers) // best-effort; callers get the data regardless
return workers, nil
}

// saveAtomic writes data to a temporary file and then renames it to the target path
// to ensure the write is atomic.
func saveAtomic(path string, data []byte) error {
Expand Down
Loading
Loading