From 5edd6e4b94e1407caeb14c584eaa5cc52dcdf6c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AB=98=E7=84=B6?= Date: Sun, 8 Mar 2026 15:44:48 +0800 Subject: [PATCH 1/4] chore(execd): complete the work for #332, and fix the linter errors. --- components/execd/main.go | 3 +- components/execd/pkg/runtime/command.go | 29 ++++++++----------- components/execd/pkg/runtime/types.go | 19 ++++++------ .../execd/pkg/web/controller/command.go | 4 +++ .../execd/pkg/web/model/codeinterpreting.go | 12 +++++++- .../pkg/web/model/codeinterpreting_test.go | 22 ++++++++++++++ 6 files changed, 61 insertions(+), 28 deletions(-) diff --git a/components/execd/main.go b/components/execd/main.go index 8e09e90f..f53f198d 100644 --- a/components/execd/main.go +++ b/components/execd/main.go @@ -17,6 +17,8 @@ package main import ( "fmt" + "github.com/alibaba/opensandbox/internal/version" + _ "go.uber.org/automaxprocs/maxprocs" "github.com/alibaba/opensandbox/execd/pkg/flag" @@ -24,7 +26,6 @@ import ( _ "github.com/alibaba/opensandbox/execd/pkg/util/safego" "github.com/alibaba/opensandbox/execd/pkg/web" "github.com/alibaba/opensandbox/execd/pkg/web/controller" - "github.com/alibaba/opensandbox/internal/version" ) // main initializes and starts the execd server. diff --git a/components/execd/pkg/runtime/command.go b/components/execd/pkg/runtime/command.go index c300a97a..cec6213c 100644 --- a/components/execd/pkg/runtime/command.go +++ b/components/execd/pkg/runtime/command.go @@ -37,7 +37,7 @@ import ( func buildCredential(uid, gid *uint32) (*syscall.Credential, error) { if uid == nil && gid == nil { - return nil, nil + return nil, nil //nolint:nilnil } cred := &syscall.Credential{} @@ -95,9 +95,20 @@ func (c *Controller) runCommand(ctx context.Context, request *ExecuteCodeRequest log.Info("received command: %v", request.Code) cmd := exec.CommandContext(ctx, "bash", "-c", request.Code) + // Configure credentials and process group + cred, err := buildCredential(request.Uid, request.Gid) + if err != nil { + return fmt.Errorf("failed to build credential: %w", err) + } + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + Credential: cred, + } + cmd.Stdout = stdout cmd.Stderr = stderr cmd.Env = mergeEnvs(os.Environ(), loadExtraEnvFromFile()) + cmd.Dir = request.Cwd done := make(chan struct{}, 1) var wg sync.WaitGroup @@ -111,20 +122,7 @@ func (c *Controller) runCommand(ctx context.Context, request *ExecuteCodeRequest c.tailStdPipe(stderrPath, request.Hooks.OnExecuteStderr, done) }) - cmd.Dir = request.Cwd - - // Configure credentials and process group - cred, err := buildCredential(request.Uid, request.Gid) - if err != nil { - log.Error("failed to build credentials: %v", err) - } - cmd.SysProcAttr = &syscall.SysProcAttr{ - Setpgid: true, - Credential: cred, - } - err = cmd.Start() - if err != nil { request.Hooks.OnExecuteInit(session) request.Hooks.OnExecuteError(&execute.ErrorOutput{EName: "CommandExecError", EValue: err.Error()}) @@ -219,9 +217,7 @@ func (c *Controller) runBackgroundCommand(ctx context.Context, cancel context.Ca startAt := time.Now() log.Info("received command: %v", request.Code) cmd := exec.CommandContext(ctx, "bash", "-c", request.Code) - cmd.Dir = request.Cwd - // Configure credentials and process group cred, err := buildCredential(request.Uid, request.Gid) if err != nil { @@ -233,7 +229,6 @@ func (c *Controller) runBackgroundCommand(ctx context.Context, cancel context.Ca } cmd.Stdout = pipe - cmd.Stderr = pipe cmd.Env = mergeEnvs(os.Environ(), loadExtraEnvFromFile()) diff --git a/components/execd/pkg/runtime/types.go b/components/execd/pkg/runtime/types.go index b4322f8d..4dc459b3 100644 --- a/components/execd/pkg/runtime/types.go +++ b/components/execd/pkg/runtime/types.go @@ -34,16 +34,17 @@ type ExecuteResultHook struct { // ExecuteCodeRequest represents a code execution request with context and hooks. type ExecuteCodeRequest struct { - Language Language `json:"language"` - Code string `json:"code"` - Context string `json:"context"` - Timeout time.Duration `json:"timeout"` - Cwd string `json:"cwd"` - Envs map[string]string `json:"envs"` - Uid *uint32 `json:"uid,omitempty"` - Gid *uint32 `json:"gid,omitempty"` - Hooks ExecuteResultHook + Language Language `json:"language"` + Code string `json:"code"` + Context string `json:"context"` + Timeout time.Duration `json:"timeout"` + Cwd string `json:"cwd"` + Envs map[string]string `json:"envs"` + Uid *uint32 `json:"uid,omitempty"` + Gid *uint32 `json:"gid,omitempty"` + Hooks ExecuteResultHook } + // SetDefaultHooks installs stdout logging fallbacks for unset hooks. func (req *ExecuteCodeRequest) SetDefaultHooks() { if req.Hooks.OnExecuteResult == nil { diff --git a/components/execd/pkg/web/controller/command.go b/components/execd/pkg/web/controller/command.go index 9d61308b..98590ae4 100644 --- a/components/execd/pkg/web/controller/command.go +++ b/components/execd/pkg/web/controller/command.go @@ -133,6 +133,8 @@ func (c *CodeInterpretingController) buildExecuteCommandRequest(request model.Ru Code: request.Command, Cwd: request.Cwd, Timeout: timeout, + Gid: request.Gid, + Uid: request.Uid, } } else { return &runtime.ExecuteCodeRequest{ @@ -140,6 +142,8 @@ func (c *CodeInterpretingController) buildExecuteCommandRequest(request model.Ru Code: request.Command, Cwd: request.Cwd, Timeout: timeout, + Gid: request.Gid, + Uid: request.Uid, } } } diff --git a/components/execd/pkg/web/model/codeinterpreting.go b/components/execd/pkg/web/model/codeinterpreting.go index adc6452c..74901fbb 100644 --- a/components/execd/pkg/web/model/codeinterpreting.go +++ b/components/execd/pkg/web/model/codeinterpreting.go @@ -16,6 +16,7 @@ package model import ( "encoding/json" + "errors" "fmt" "strings" @@ -53,11 +54,20 @@ type RunCommandRequest struct { Background bool `json:"background,omitempty"` // TimeoutMs caps execution duration; 0 uses server default. TimeoutMs int64 `json:"timeout,omitempty" validate:"omitempty,gte=1"` + + Uid *uint32 `json:"uid,omitempty"` + Gid *uint32 `json:"gid,omitempty"` } func (r *RunCommandRequest) Validate() error { validate := validator.New() - return validate.Struct(r) + if err := validate.Struct(r); err != nil { + return err + } + if r.Gid != nil && r.Uid == nil { + return errors.New("uid is required when gid is provided") + } + return nil } type ServerStreamEventType string diff --git a/components/execd/pkg/web/model/codeinterpreting_test.go b/components/execd/pkg/web/model/codeinterpreting_test.go index 64eee35d..46999470 100644 --- a/components/execd/pkg/web/model/codeinterpreting_test.go +++ b/components/execd/pkg/web/model/codeinterpreting_test.go @@ -60,6 +60,28 @@ func TestRunCommandRequestValidate(t *testing.T) { } } +func ptr32(v uint32) *uint32 { return &v } + +func TestRunCommandRequestValidateUidGid(t *testing.T) { + // uid-only: valid + req := RunCommandRequest{Command: "id", Uid: ptr32(1000)} + if err := req.Validate(); err != nil { + t.Fatalf("expected success with uid only: %v", err) + } + + // uid + gid: valid + req = RunCommandRequest{Command: "id", Uid: ptr32(1000), Gid: ptr32(1000)} + if err := req.Validate(); err != nil { + t.Fatalf("expected success with uid and gid: %v", err) + } + + // gid-only: must be rejected + req = RunCommandRequest{Command: "id", Gid: ptr32(1000)} + if err := req.Validate(); err == nil { + t.Fatalf("expected validation error when gid is set without uid") + } +} + func TestServerStreamEventToJSON(t *testing.T) { event := ServerStreamEvent{ Type: StreamEventTypeStdout, From fca9fab72e86dc644f492a4d9c5c1d0254416c5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AB=98=E7=84=B6?= Date: Thu, 12 Mar 2026 23:00:39 +0800 Subject: [PATCH 2/4] fix(execd): fallback to sh when bash unavailable --- components/execd/bootstrap.sh | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/components/execd/bootstrap.sh b/components/execd/bootstrap.sh index a820ac26..c77c1e2c 100755 --- a/components/execd/bootstrap.sh +++ b/components/execd/bootstrap.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/bin/sh # Copyright 2025 Alibaba Group Holding Ltd. # @@ -45,13 +45,25 @@ elif [ $# -ge 1 ] && [ "$1" = "-c" ]; then CMD="$*" fi +SHELL_BIN="${BOOTSTRAP_SHELL:-}" +if [ -z "$SHELL_BIN" ]; then + if command -v bash >/dev/null 2>&1; then + SHELL_BIN="$(command -v bash)" + elif command -v sh >/dev/null 2>&1; then + SHELL_BIN="$(command -v sh)" + else + echo "error: neither bash nor sh found in PATH" >&2 + exit 1 + fi +fi + set -x if [ "$CMD" != "" ]; then - exec bash -c "$CMD" + exec "$SHELL_BIN" -c "$CMD" fi if [ $# -eq 0 ]; then - exec bash + exec "$SHELL_BIN" fi exec "$@" From a796b11b06a9b457d4257ea37e0a4da6e319b5a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AB=98=E7=84=B6?= Date: Thu, 12 Mar 2026 23:10:09 +0800 Subject: [PATCH 3/4] feat(execd): support env passthrough for runCommand API --- components/execd/pkg/runtime/command.go | 6 ++- .../execd/pkg/runtime/command_windows.go | 6 ++- components/execd/pkg/runtime/env.go | 17 ++++++++ components/execd/pkg/runtime/env_test.go | 30 +++++++++++++ .../execd/pkg/web/controller/command.go | 2 + .../execd/pkg/web/controller/command_test.go | 43 +++++++++++++++++++ .../execd/pkg/web/model/codeinterpreting.go | 5 ++- 7 files changed, 103 insertions(+), 6 deletions(-) diff --git a/components/execd/pkg/runtime/command.go b/components/execd/pkg/runtime/command.go index 54b115b4..63c87148 100644 --- a/components/execd/pkg/runtime/command.go +++ b/components/execd/pkg/runtime/command.go @@ -117,7 +117,8 @@ func (c *Controller) runCommand(ctx context.Context, request *ExecuteCodeRequest cmd.Stdout = stdout cmd.Stderr = stderr - cmd.Env = mergeEnvs(os.Environ(), loadExtraEnvFromFile()) + extraEnv := mergeExtraEnvs(loadExtraEnvFromFile(), request.Envs) + cmd.Env = mergeEnvs(os.Environ(), extraEnv) cmd.Dir = request.Cwd done := make(chan struct{}, 1) @@ -241,7 +242,8 @@ func (c *Controller) runBackgroundCommand(ctx context.Context, cancel context.Ca cmd.Stdout = pipe cmd.Stderr = pipe - cmd.Env = mergeEnvs(os.Environ(), loadExtraEnvFromFile()) + extraEnv := mergeExtraEnvs(loadExtraEnvFromFile(), request.Envs) + cmd.Env = mergeEnvs(os.Environ(), extraEnv) // use DevNull as stdin so interactive programs exit immediately. cmd.Stdin = os.NewFile(uintptr(syscall.Stdin), os.DevNull) diff --git a/components/execd/pkg/runtime/command_windows.go b/components/execd/pkg/runtime/command_windows.go index 3c9aef11..888bd5e8 100644 --- a/components/execd/pkg/runtime/command_windows.go +++ b/components/execd/pkg/runtime/command_windows.go @@ -48,7 +48,8 @@ func (c *Controller) runCommand(ctx context.Context, request *ExecuteCodeRequest cmd.Stdout = stdout cmd.Stderr = stderr cmd.Dir = request.Cwd - cmd.Env = mergeEnvs(os.Environ(), loadExtraEnvFromFile()) + extraEnv := mergeExtraEnvs(loadExtraEnvFromFile(), request.Envs) + cmd.Env = mergeEnvs(os.Environ(), extraEnv) done := make(chan struct{}, 1) safego.Go(func() { @@ -121,7 +122,8 @@ func (c *Controller) runBackgroundCommand(ctx context.Context, cancel context.Ca cmd.Dir = request.Cwd cmd.Stdout = pipe cmd.Stderr = pipe - cmd.Env = mergeEnvs(os.Environ(), loadExtraEnvFromFile()) + extraEnv := mergeExtraEnvs(loadExtraEnvFromFile(), request.Envs) + cmd.Env = mergeEnvs(os.Environ(), extraEnv) devNull, _ := os.OpenFile(os.DevNull, os.O_RDWR, 0) // best-effort, ignore error cmd.Stdin = devNull diff --git a/components/execd/pkg/runtime/env.go b/components/execd/pkg/runtime/env.go index ccea644b..bf28b525 100644 --- a/components/execd/pkg/runtime/env.go +++ b/components/execd/pkg/runtime/env.go @@ -79,3 +79,20 @@ func mergeEnvs(base []string, extra map[string]string) []string { return out } + +// mergeExtraEnvs merges environment maps from file and request-level overrides. +func mergeExtraEnvs(fromFile, fromRequest map[string]string) map[string]string { + if len(fromRequest) == 0 { + return fromFile + } + + merged := make(map[string]string, len(fromFile)+len(fromRequest)) + for k, v := range fromFile { + merged[k] = v + } + for k, v := range fromRequest { + merged[k] = v + } + + return merged +} diff --git a/components/execd/pkg/runtime/env_test.go b/components/execd/pkg/runtime/env_test.go index 31778054..d932b51b 100644 --- a/components/execd/pkg/runtime/env_test.go +++ b/components/execd/pkg/runtime/env_test.go @@ -100,3 +100,33 @@ func TestMergeEnvsOverlaysExtra(t *testing.T) { t.Fatalf("C mismatch, got %q", got["C"]) } } + +func TestMergeExtraEnvsMergesAndOverrides(t *testing.T) { + fromFile := map[string]string{"A": "1", "B": "2"} + fromRequest := map[string]string{"B": "override", "C": "3"} + + got := mergeExtraEnvs(fromFile, fromRequest) + + if len(got) != 3 { + t.Fatalf("expected 3 entries, got %#v", got) + } + if got["A"] != "1" { + t.Fatalf("A mismatch, got %q", got["A"]) + } + if got["B"] != "override" { + t.Fatalf("B mismatch, got %q", got["B"]) + } + if got["C"] != "3" { + t.Fatalf("C mismatch, got %q", got["C"]) + } +} + +func TestMergeExtraEnvsHandlesNilFromFile(t *testing.T) { + fromRequest := map[string]string{"ONLY": "request"} + + got := mergeExtraEnvs(nil, fromRequest) + + if len(got) != 1 || got["ONLY"] != "request" { + t.Fatalf("unexpected merge result: %#v", got) + } +} diff --git a/components/execd/pkg/web/controller/command.go b/components/execd/pkg/web/controller/command.go index 98590ae4..d4da90df 100644 --- a/components/execd/pkg/web/controller/command.go +++ b/components/execd/pkg/web/controller/command.go @@ -135,6 +135,7 @@ func (c *CodeInterpretingController) buildExecuteCommandRequest(request model.Ru Timeout: timeout, Gid: request.Gid, Uid: request.Uid, + Envs: request.Envs, } } else { return &runtime.ExecuteCodeRequest{ @@ -144,6 +145,7 @@ func (c *CodeInterpretingController) buildExecuteCommandRequest(request model.Ru Timeout: timeout, Gid: request.Gid, Uid: request.Uid, + Envs: request.Envs, } } } diff --git a/components/execd/pkg/web/controller/command_test.go b/components/execd/pkg/web/controller/command_test.go index 29455783..439b1a10 100644 --- a/components/execd/pkg/web/controller/command_test.go +++ b/components/execd/pkg/web/controller/command_test.go @@ -18,11 +18,54 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "reflect" "testing" + "github.com/alibaba/opensandbox/execd/pkg/runtime" "github.com/alibaba/opensandbox/execd/pkg/web/model" ) +func TestBuildExecuteCommandRequestForwardsEnvs(t *testing.T) { + ctrl := &CodeInterpretingController{} + envs := map[string]string{"FOO": "bar", "BAZ": "qux"} + req := model.RunCommandRequest{ + Command: "echo hi", + Cwd: "/tmp", + Envs: envs, + } + + execReq := ctrl.buildExecuteCommandRequest(req) + + if execReq.Language != runtime.Command { + t.Fatalf("expected runtime.Command, got %s", execReq.Language) + } + if !reflect.DeepEqual(execReq.Envs, envs) { + t.Fatalf("expected envs to be forwarded, got %#v", execReq.Envs) + } + if execReq.Cwd != "/tmp" { + t.Fatalf("expected Cwd to be forwarded, got %s", execReq.Cwd) + } +} + +func TestBuildExecuteCommandRequestForwardsEnvsBackground(t *testing.T) { + ctrl := &CodeInterpretingController{} + envs := map[string]string{"FOO": "bar"} + req := model.RunCommandRequest{ + Command: "echo hi", + Background: true, + Envs: envs, + } + + execReq := ctrl.buildExecuteCommandRequest(req) + + if execReq.Language != runtime.BackgroundCommand { + t.Fatalf("expected runtime.BackgroundCommand, got %s", execReq.Language) + } + if !reflect.DeepEqual(execReq.Envs, envs) { + t.Fatalf("expected envs to be forwarded, got %#v", execReq.Envs) + } +} + func setupCommandController(method, path string) (*CodeInterpretingController, *httptest.ResponseRecorder) { ctx, w := newTestContext(method, path, nil) ctrl := NewCodeInterpretingController(ctx) diff --git a/components/execd/pkg/web/model/codeinterpreting.go b/components/execd/pkg/web/model/codeinterpreting.go index 74901fbb..771b6d75 100644 --- a/components/execd/pkg/web/model/codeinterpreting.go +++ b/components/execd/pkg/web/model/codeinterpreting.go @@ -55,8 +55,9 @@ type RunCommandRequest struct { // TimeoutMs caps execution duration; 0 uses server default. TimeoutMs int64 `json:"timeout,omitempty" validate:"omitempty,gte=1"` - Uid *uint32 `json:"uid,omitempty"` - Gid *uint32 `json:"gid,omitempty"` + Uid *uint32 `json:"uid,omitempty"` + Gid *uint32 `json:"gid,omitempty"` + Envs map[string]string `json:"envs,omitempty"` } func (r *RunCommandRequest) Validate() error { From e25876118dc62f275be0097101958cb911e3124d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=AB=98=E7=84=B6?= Date: Thu, 12 Mar 2026 23:26:17 +0800 Subject: [PATCH 4/4] test(execd): refactor unit tests to testify require/assert --- .../execd/pkg/runtime/command_status_test.go | 86 ++++------- components/execd/pkg/runtime/command_test.go | 114 ++++---------- components/execd/pkg/runtime/context_test.go | 140 ++++++------------ components/execd/pkg/runtime/env_test.go | 70 +++------ components/execd/pkg/runtime/helpers_test.go | 6 +- components/execd/pkg/runtime/sql_test.go | 61 +++----- components/execd/pkg/runtime/types_test.go | 17 +-- .../execd/pkg/web/controller/basic_test.go | 58 ++------ .../web/controller/codeinterpreting_test.go | 46 ++---- .../execd/pkg/web/controller/command_test.go | 53 ++----- .../controller/filesystem_download_test.go | 15 -- .../pkg/web/controller/filesystem_test.go | 71 +++------ .../web/controller/filesystem_upload_test.go | 15 -- .../execd/pkg/web/controller/utils_test.go | 70 +++------ .../pkg/web/model/codeinterpreting_test.go | 51 ++----- 15 files changed, 255 insertions(+), 618 deletions(-) delete mode 100644 components/execd/pkg/web/controller/filesystem_download_test.go delete mode 100644 components/execd/pkg/web/controller/filesystem_upload_test.go diff --git a/components/execd/pkg/runtime/command_status_test.go b/components/execd/pkg/runtime/command_status_test.go index 8eb8a6d6..12dae236 100644 --- a/components/execd/pkg/runtime/command_status_test.go +++ b/components/execd/pkg/runtime/command_status_test.go @@ -21,14 +21,15 @@ import ( "strings" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestGetCommandStatus_NotFound(t *testing.T) { c := NewController("", "") - if _, err := c.GetCommandStatus("missing"); err == nil { - t.Fatalf("expected error for missing session") - } + _, err := c.GetCommandStatus("missing") + require.Error(t, err, "expected error for missing session") } func TestGetCommandStatus_Running(t *testing.T) { @@ -45,12 +46,8 @@ func TestGetCommandStatus_Running(t *testing.T) { } ctx, cancel := context.WithCancel(context.Background()) - if err := c.runBackgroundCommand(ctx, cancel, req); err != nil { - t.Fatalf("runBackgroundCommand error: %v", err) - } - if session == "" { - t.Fatalf("session should be set by OnExecuteInit") - } + require.NoError(t, c.runBackgroundCommand(ctx, cancel, req)) + require.NotEmpty(t, session, "session should be set by OnExecuteInit") // Poll until status is registered (runBackgroundCommand stores kernel asynchronously). deadline := time.Now().Add(5 * time.Second) @@ -67,24 +64,15 @@ func TestGetCommandStatus_Running(t *testing.T) { time.Sleep(50 * time.Millisecond) continue } - t.Fatalf("GetCommandStatus unexpected error: %v", err) - } - if err != nil { - t.Fatalf("GetCommandStatus error after retry: %v", err) + require.NoError(t, err, "GetCommandStatus unexpected error") } + require.NoError(t, err, "GetCommandStatus error after retry") - if status == nil || !status.Running { - t.Fatalf("expected running=true") - } - if status.ExitCode != nil { - t.Fatalf("expected exitCode to be nil while running") - } - if status.FinishedAt != nil { - t.Fatalf("expected finishedAt to be nil while running") - } - if status.StartedAt.IsZero() { - t.Fatalf("expected startedAt to be set") - } + require.NotNil(t, status) + require.True(t, status.Running, "expected running=true") + require.Nil(t, status.ExitCode, "expected exitCode to be nil while running") + require.Nil(t, status.FinishedAt, "expected finishedAt to be nil while running") + require.False(t, status.StartedAt.IsZero(), "expected startedAt to be set") t.Log(status) } @@ -96,9 +84,7 @@ func TestSeekBackgroundCommandOutput_Completed(t *testing.T) { stdoutPath := filepath.Join(tmpDir, session+".stdout") stdoutContent := "hello stdout" - if err := os.WriteFile(stdoutPath, []byte(stdoutContent), 0o644); err != nil { - t.Fatalf("write stdout: %v", err) - } + require.NoError(t, os.WriteFile(stdoutPath, []byte(stdoutContent), 0o644)) started := time.Now().Add(-2 * time.Second) finished := time.Now() @@ -116,16 +102,10 @@ func TestSeekBackgroundCommandOutput_Completed(t *testing.T) { c.storeCommandKernel(session, kernel) output, cursor, err := c.SeekBackgroundCommandOutput(session, 0) - if err != nil { - t.Fatalf("GetCommandOutput error: %v", err) - } + require.NoError(t, err, "GetCommandOutput error") - if cursor <= 0 { - t.Fatalf("expected cursor>=0") - } - if string(output) != stdoutContent { - t.Fatalf("expected output=%s, got %s", stdoutContent, string(output)) - } + require.Greater(t, cursor, int64(0), "expected cursor>=0") + require.Equal(t, stdoutContent, string(output)) } func TestSeekBackgroundCommandOutput_WithRunBackgroundCommand(t *testing.T) { @@ -144,12 +124,8 @@ func TestSeekBackgroundCommandOutput_WithRunBackgroundCommand(t *testing.T) { } ctx, cancel := context.WithCancel(context.Background()) - if err := c.runBackgroundCommand(ctx, cancel, req); err != nil { - t.Fatalf("runBackgroundCommand error: %v", err) - } - if session == "" { - t.Fatalf("session should be set by OnExecuteInit") - } + require.NoError(t, c.runBackgroundCommand(ctx, cancel, req)) + require.NotEmpty(t, session, "session should be set by OnExecuteInit") var ( output []byte @@ -165,25 +141,13 @@ func TestSeekBackgroundCommandOutput_WithRunBackgroundCommand(t *testing.T) { } time.Sleep(100 * time.Millisecond) } - if err != nil { - t.Fatalf("SeekBackgroundCommandOutput error: %v", err) - } - if string(output) != expected { - t.Fatalf("unexpected output: %q", string(output)) - } - if cursor < int64(len(expected)) { - t.Fatalf("cursor should advance to end of file, got %d", cursor) - } + require.NoError(t, err, "SeekBackgroundCommandOutput error") + require.Equal(t, expected, string(output)) + require.GreaterOrEqual(t, cursor, int64(len(expected)), "cursor should advance to end of file") // incremental seek from current cursor should return empty data and same-or-higher cursor output2, cursor2, err := c.SeekBackgroundCommandOutput(session, cursor) - if err != nil { - t.Fatalf("SeekBackgroundCommandOutput (second call) error: %v", err) - } - if len(output2) != 0 { - t.Fatalf("expected no new output, got %q", string(output2)) - } - if cursor2 < cursor { - t.Fatalf("cursor should not move backwards: got %d < %d", cursor2, cursor) - } + require.NoError(t, err, "SeekBackgroundCommandOutput (second call) error") + require.Empty(t, output2, "expected no new output") + require.GreaterOrEqual(t, cursor2, cursor, "cursor should not move backwards") } diff --git a/components/execd/pkg/runtime/command_test.go b/components/execd/pkg/runtime/command_test.go index 8f09c9d5..e282d40a 100644 --- a/components/execd/pkg/runtime/command_test.go +++ b/components/execd/pkg/runtime/command_test.go @@ -28,6 +28,7 @@ import ( "github.com/alibaba/opensandbox/execd/pkg/jupyter/execute" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestReadFromPos_SplitsOnCRAndLF(t *testing.T) { @@ -37,46 +38,32 @@ func TestReadFromPos_SplitsOnCRAndLF(t *testing.T) { mutex := &sync.Mutex{} initial := "line1\nprog 10%\rprog 20%\rprog 30%\nlast\n" - if err := os.WriteFile(logFile, []byte(initial), 0o644); err != nil { - t.Fatalf("write initial file: %v", err) - } + require.NoError(t, os.WriteFile(logFile, []byte(initial), 0o644)) var got []string c := &Controller{} nextPos := c.readFromPos(mutex, logFile, 0, func(s string) { got = append(got, s) }, false) want := []string{"line1", "prog 10%", "prog 20%", "prog 30%", "last"} - if len(got) != len(want) { - t.Fatalf("unexpected token count: got %d want %d", len(got), len(want)) - } + require.Len(t, got, len(want)) for i := range want { - if got[i] != want[i] { - t.Fatalf("token[%d]: got %q want %q", i, got[i], want[i]) - } + require.Equal(t, want[i], got[i], "token[%d] mismatch", i) } // append more content and ensure incremental read only yields the new part appendPart := "tail1\r\ntail2\n" f, err := os.OpenFile(logFile, os.O_APPEND|os.O_WRONLY, 0o644) - if err != nil { - t.Fatalf("open append: %v", err) - } - if _, err := f.WriteString(appendPart); err != nil { - f.Close() - t.Fatalf("append write: %v", err) - } + require.NoError(t, err) + _, err = f.WriteString(appendPart) + require.NoError(t, err, "append write") _ = f.Close() got = got[:0] c.readFromPos(mutex, logFile, nextPos, func(s string) { got = append(got, s) }, false) want = []string{"tail1", "tail2"} - if len(got) != len(want) { - t.Fatalf("incremental token count: got %d want %d", len(got), len(want)) - } + require.Len(t, got, len(want)) for i := range want { - if got[i] != want[i] { - t.Fatalf("incremental token[%d]: got %q want %q", i, got[i], want[i]) - } + require.Equal(t, want[i], got[i], "incremental token[%d] mismatch", i) } } @@ -86,20 +73,14 @@ func TestReadFromPos_LongLine(t *testing.T) { // construct a single line larger than the default 64KB, but under 5MB longLine := strings.Repeat("x", 256*1024) + "\n" // 256KB - if err := os.WriteFile(logFile, []byte(longLine), 0o644); err != nil { - t.Fatalf("write long line: %v", err) - } + require.NoError(t, os.WriteFile(logFile, []byte(longLine), 0o644)) var got []string c := &Controller{} c.readFromPos(&sync.Mutex{}, logFile, 0, func(s string) { got = append(got, s) }, false) - if len(got) != 1 { - t.Fatalf("expected one token, got %d", len(got)) - } - if got[0] != strings.TrimSuffix(longLine, "\n") { - t.Fatalf("long line mismatch: got %d chars want %d chars", len(got[0]), len(longLine)-1) - } + require.Len(t, got, 1, "expected one token") + require.Equal(t, strings.TrimSuffix(longLine, "\n"), got[0], "long line mismatch") } func TestReadFromPos_FlushesTrailingLine(t *testing.T) { @@ -159,7 +140,7 @@ func TestRunCommand_Echo(t *testing.T) { stderrLines = append(stderrLines, s) }, OnExecuteError: func(err *execute.ErrorOutput) { - t.Fatalf("unexpected error hook: %+v", err) + require.Failf(t, "unexpected error hook", "%+v", err) }, OnExecuteComplete: func(_ time.Duration) { completeCh <- struct{}{} @@ -167,25 +148,17 @@ func TestRunCommand_Echo(t *testing.T) { }, } - if err := c.runCommand(ctx, req); err != nil { - t.Fatalf("runCommand returned error: %v", err) - } + require.NoError(t, c.runCommand(ctx, req)) select { case <-completeCh: case <-time.After(2 * time.Second): - t.Fatalf("timeout waiting for completion hook") + require.Fail(t, "timeout waiting for completion hook") } - if sessionID == "" { - t.Fatalf("expected session id to be set") - } - if len(stdoutLines) != 1 || stdoutLines[0] != "hello" { - t.Fatalf("unexpected stdout: %#v", stdoutLines) - } - if len(stderrLines) != 1 || stderrLines[0] != "errline" { - t.Fatalf("unexpected stderr: %#v", stderrLines) - } + require.NotEmpty(t, sessionID, "expected session id to be set") + require.Equal(t, []string{"hello"}, stdoutLines) + require.Equal(t, []string{"errline"}, stderrLines) } func TestRunCommand_Error(t *testing.T) { @@ -227,31 +200,20 @@ func TestRunCommand_Error(t *testing.T) { }, } - if err := c.runCommand(ctx, req); err != nil { - t.Fatalf("runCommand returned error: %v", err) - } + require.NoError(t, c.runCommand(ctx, req)) select { case <-completeCh: case <-time.After(2 * time.Second): - t.Fatalf("timeout waiting for completion hook") + require.Fail(t, "timeout waiting for completion hook") } - if sessionID == "" { - t.Fatalf("expected session id to be set") - } - if len(stdoutLines) == 0 || stdoutLines[0] != "before" { - t.Fatalf("unexpected stdout: %#v", stdoutLines) - } - if len(stderrLines) != 0 { - t.Fatalf("expected no stderr, got %#v", stderrLines) - } - if gotErr == nil { - t.Fatalf("expected error hook to be called") - } - if gotErr.EName != "CommandExecError" || gotErr.EValue != "3" { - t.Fatalf("unexpected error payload: %+v", gotErr) - } + require.NotEmpty(t, sessionID, "expected session id to be set") + require.Equal(t, []string{"before"}, stdoutLines) + require.Empty(t, stderrLines, "expected no stderr") + require.NotNil(t, gotErr, "expected error hook to be called") + require.Equal(t, "CommandExecError", gotErr.EName) + require.Equal(t, "3", gotErr.EValue) } // TestStdLogDescriptor_AutoCreatesTempDir verifies that stdLogDescriptor @@ -268,20 +230,14 @@ func TestStdLogDescriptor_AutoCreatesTempDir(t *testing.T) { c := NewController("", "") stdout, stderr, err := c.stdLogDescriptor("test-session") - if err != nil { - t.Fatalf("stdLogDescriptor failed with missing temp dir: %v", err) - } + require.NoError(t, err) stdout.Close() stderr.Close() // The directory must have been created. info, err := os.Stat(missingDir) - if err != nil { - t.Fatalf("expected temp dir to be created, stat error: %v", err) - } - if !info.IsDir() { - t.Fatalf("expected %s to be a directory", missingDir) - } + require.NoError(t, err, "expected temp dir to be created, stat error") + require.True(t, info.IsDir(), "expected %s to be a directory", missingDir) } // TestCombinedOutputDescriptor_AutoCreatesTempDir verifies that @@ -297,16 +253,10 @@ func TestCombinedOutputDescriptor_AutoCreatesTempDir(t *testing.T) { c := NewController("", "") f, err := c.combinedOutputDescriptor("test-session") - if err != nil { - t.Fatalf("combinedOutputDescriptor failed with missing temp dir: %v", err) - } + require.NoError(t, err) f.Close() info, err := os.Stat(missingDir) - if err != nil { - t.Fatalf("expected temp dir to be created, stat error: %v", err) - } - if !info.IsDir() { - t.Fatalf("expected %s to be a directory", missingDir) - } + require.NoError(t, err, "expected temp dir to be created, stat error") + require.True(t, info.IsDir(), "expected %s to be a directory", missingDir) } diff --git a/components/execd/pkg/runtime/context_test.go b/components/execd/pkg/runtime/context_test.go index 34eb956c..9a0376d0 100644 --- a/components/execd/pkg/runtime/context_test.go +++ b/components/execd/pkg/runtime/context_test.go @@ -15,13 +15,14 @@ package runtime import ( - "errors" "net/http" "net/http/httptest" "os" "path/filepath" "strings" "testing" + + "github.com/stretchr/testify/require" ) func TestListContextsAndNewIpynbPath(t *testing.T) { @@ -30,33 +31,22 @@ func TestListContextsAndNewIpynbPath(t *testing.T) { c.defaultLanguageJupyterSessions[Go] = "session-go-default" pyContexts, err := c.listLanguageContexts(Python) - if err != nil { - t.Fatalf("listLanguageContexts returned error: %v", err) - } - if len(pyContexts) != 1 || pyContexts[0].ID != "session-python" || pyContexts[0].Language != Python { - t.Fatalf("unexpected python contexts: %#v", pyContexts) - } + require.NoError(t, err) + require.Len(t, pyContexts, 1) + require.Equal(t, "session-python", pyContexts[0].ID) + require.Equal(t, Python, pyContexts[0].Language) allContexts, err := c.listAllContexts() - if err != nil { - t.Fatalf("listAllContexts returned error: %v", err) - } - if len(allContexts) != 2 { - t.Fatalf("expected two contexts, got %d", len(allContexts)) - } + require.NoError(t, err) + require.Len(t, allContexts, 2) tmpDir := filepath.Join(t.TempDir(), "nested") path, err := c.newIpynbPath("abc123", tmpDir) - if err != nil { - t.Fatalf("newIpynbPath error: %v", err) - } - if _, statErr := os.Stat(tmpDir); statErr != nil { - t.Fatalf("expected directory to be created: %v", statErr) - } + require.NoError(t, err) + _, statErr := os.Stat(tmpDir) + require.NoError(t, statErr, "expected directory to be created") expected := filepath.Join(tmpDir, "abc123.ipynb") - if path != expected { - t.Fatalf("unexpected ipynb path: got %s want %s", path, expected) - } + require.Equal(t, expected, path) } func TestNewContextID_UniqueAndLength(t *testing.T) { @@ -64,64 +54,45 @@ func TestNewContextID_UniqueAndLength(t *testing.T) { id1 := c.newContextID() id2 := c.newContextID() - if id1 == "" || id2 == "" { - t.Fatalf("expected non-empty ids") - } - if id1 == id2 { - t.Fatalf("expected unique ids, got identical: %s", id1) - } - if len(id1) != 32 || len(id2) != 32 { - t.Fatalf("expected 32-char ids, got %d and %d", len(id1), len(id2)) - } + require.NotEmpty(t, id1) + require.NotEmpty(t, id2) + require.NotEqual(t, id1, id2, "expected unique ids") + require.Len(t, id1, 32) + require.Len(t, id2, 32) } func TestNewIpynbPath_ErrorWhenCwdIsFile(t *testing.T) { c := NewController("", "") tmpFile := filepath.Join(t.TempDir(), "file.txt") - if err := os.WriteFile(tmpFile, []byte("x"), 0o644); err != nil { - t.Fatalf("prepare file: %v", err) - } + require.NoError(t, os.WriteFile(tmpFile, []byte("x"), 0o644)) - if _, err := c.newIpynbPath("abc", tmpFile); err == nil { - t.Fatalf("expected error when cwd is a file") - } + _, err := c.newIpynbPath("abc", tmpFile) + require.Error(t, err, "expected error when cwd is a file") } func TestListContextUnsupportedLanguage(t *testing.T) { c := NewController("", "") _, err := c.ListContext(Command.String()) - if err == nil { - t.Fatalf("expected error for command language") - } - if _, err := c.ListContext(BackgroundCommand.String()); err == nil { - t.Fatalf("expected error for background-command language") - } - if _, err := c.ListContext(SQL.String()); err == nil { - t.Fatalf("expected error for sql language") - } + require.Error(t, err, "expected error for command language") + _, err = c.ListContext(BackgroundCommand.String()) + require.Error(t, err, "expected error for background-command language") + _, err = c.ListContext(SQL.String()) + require.Error(t, err, "expected error for sql language") } func TestDeleteContext_NotFound(t *testing.T) { c := NewController("", "") err := c.DeleteContext("missing") - if err == nil { - t.Fatalf("expected ErrContextNotFound") - } - if !errors.Is(err, ErrContextNotFound) { - t.Fatalf("unexpected error: %v", err) - } + require.Error(t, err, "expected ErrContextNotFound") + require.ErrorIs(t, err, ErrContextNotFound) } func TestGetContext_NotFound(t *testing.T) { c := NewController("", "") _, err := c.GetContext("missing") - if err == nil { - t.Fatalf("expected ErrContextNotFound") - } - if !errors.Is(err, ErrContextNotFound) { - t.Fatalf("unexpected error: %v", err) - } + require.Error(t, err, "expected ErrContextNotFound") + require.ErrorIs(t, err, ErrContextNotFound) } func TestDeleteContext_RemovesCacheOnSuccess(t *testing.T) { @@ -129,12 +100,8 @@ func TestDeleteContext_RemovesCacheOnSuccess(t *testing.T) { // mock jupyter server that accepts DELETE server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodDelete { - t.Fatalf("unexpected method: %s", r.Method) - } - if !strings.HasSuffix(r.URL.Path, "/api/sessions/"+sessionID) { - t.Fatalf("unexpected path: %s", r.URL.Path) - } + require.Equal(t, http.MethodDelete, r.Method, "unexpected method") + require.True(t, strings.HasSuffix(r.URL.Path, "/api/sessions/"+sessionID), "unexpected path: %s", r.URL.Path) w.WriteHeader(http.StatusNoContent) })) defer server.Close() @@ -143,16 +110,11 @@ func TestDeleteContext_RemovesCacheOnSuccess(t *testing.T) { c.jupyterClientMap[sessionID] = &jupyterKernel{language: Python} c.defaultLanguageJupyterSessions[Python] = sessionID - if err := c.DeleteContext(sessionID); err != nil { - t.Fatalf("DeleteContext returned error: %v", err) - } + require.NoError(t, c.DeleteContext(sessionID)) - if kernel := c.getJupyterKernel(sessionID); kernel != nil { - t.Fatalf("expected cache to be cleared, found: %+v", kernel) - } - if _, ok := c.defaultLanguageJupyterSessions[Python]; ok { - t.Fatalf("expected default session entry to be removed") - } + require.Nil(t, c.getJupyterKernel(sessionID), "expected cache to be cleared") + _, ok := c.defaultLanguageJupyterSessions[Python] + require.False(t, ok, "expected default session entry to be removed") } func TestDeleteLanguageContext_RemovesCacheOnSuccess(t *testing.T) { @@ -163,15 +125,13 @@ func TestDeleteLanguageContext_RemovesCacheOnSuccess(t *testing.T) { // mock jupyter server to accept two deletes deleteCalls := make(map[string]int) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodDelete { - t.Fatalf("unexpected method: %s", r.Method) - } + require.Equal(t, http.MethodDelete, r.Method, "unexpected method") if strings.Contains(r.URL.Path, session1) { deleteCalls[session1]++ } else if strings.Contains(r.URL.Path, session2) { deleteCalls[session2]++ } else { - t.Fatalf("unexpected path: %s", r.URL.Path) + require.Failf(t, "unexpected path", "%s", r.URL.Path) } w.WriteHeader(http.StatusNoContent) })) @@ -182,20 +142,14 @@ func TestDeleteLanguageContext_RemovesCacheOnSuccess(t *testing.T) { c.jupyterClientMap[session2] = &jupyterKernel{language: lang} c.defaultLanguageJupyterSessions[lang] = session2 - if err := c.DeleteLanguageContext(lang); err != nil { - t.Fatalf("DeleteLanguageContext returned error: %v", err) - } - - if _, ok := c.jupyterClientMap[session1]; ok { - t.Fatalf("expected session1 removed from cache") - } - if _, ok := c.jupyterClientMap[session2]; ok { - t.Fatalf("expected session2 removed from cache") - } - if _, ok := c.defaultLanguageJupyterSessions[lang]; ok { - t.Fatalf("expected default entry removed") - } - if deleteCalls[session1] != 1 || deleteCalls[session2] != 1 { - t.Fatalf("unexpected delete calls: %+v", deleteCalls) - } + require.NoError(t, c.DeleteLanguageContext(lang)) + + _, ok := c.jupyterClientMap[session1] + require.False(t, ok, "expected session1 removed from cache") + _, ok = c.jupyterClientMap[session2] + require.False(t, ok, "expected session2 removed from cache") + _, ok = c.defaultLanguageJupyterSessions[lang] + require.False(t, ok, "expected default entry removed") + require.Equal(t, 1, deleteCalls[session1]) + require.Equal(t, 1, deleteCalls[session2]) } diff --git a/components/execd/pkg/runtime/env_test.go b/components/execd/pkg/runtime/env_test.go index d932b51b..98014a29 100644 --- a/components/execd/pkg/runtime/env_test.go +++ b/components/execd/pkg/runtime/env_test.go @@ -19,13 +19,13 @@ import ( "path/filepath" "strings" "testing" + + "github.com/stretchr/testify/require" ) func TestLoadExtraEnvFromFileUnset(t *testing.T) { t.Setenv("EXECD_ENVS", "") - if got := loadExtraEnvFromFile(); got != nil { - t.Fatalf("expected nil when EXECD_ENVS unset, got %#v", got) - } + require.Nil(t, loadExtraEnvFromFile(), "expected nil when EXECD_ENVS unset") } func TestLoadExtraEnvFromFileParsesAndExpands(t *testing.T) { @@ -44,24 +44,15 @@ func TestLoadExtraEnvFromFileParsesAndExpands(t *testing.T) { "", }, "\n") - if err := os.WriteFile(envFile, []byte(content), 0o644); err != nil { - t.Fatalf("write env file: %v", err) - } + require.NoError(t, os.WriteFile(envFile, []byte(content), 0o644)) got := loadExtraEnvFromFile() - if len(got) != 3 { - t.Fatalf("expected 3 entries, got %#v", got) - } - - if got["FOO"] != "bar" { - t.Fatalf("FOO mismatch, got %q", got["FOO"]) - } - if got["PATH"] != "/opt/base/bin" { - t.Fatalf("PATH expansion mismatch, got %q", got["PATH"]) - } - if val, ok := got["EMPTY"]; !ok || val != "" { - t.Fatalf("EMPTY mismatch, got %q (present=%v)", val, ok) - } + require.Len(t, got, 3) + require.Equal(t, "bar", got["FOO"]) + require.Equal(t, "/opt/base/bin", got["PATH"]) + val, ok := got["EMPTY"] + require.True(t, ok) + require.Equal(t, "", val) } func TestLoadExtraEnvFromFileMissingFile(t *testing.T) { @@ -69,9 +60,7 @@ func TestLoadExtraEnvFromFileMissingFile(t *testing.T) { envFile := filepath.Join(dir, "does-not-exist") t.Setenv("EXECD_ENVS", envFile) - if got := loadExtraEnvFromFile(); got != nil { - t.Fatalf("expected nil for missing file, got %#v", got) - } + require.Nil(t, loadExtraEnvFromFile(), "expected nil for missing file") } func TestMergeEnvsOverlaysExtra(t *testing.T) { @@ -87,18 +76,10 @@ func TestMergeEnvsOverlaysExtra(t *testing.T) { } } - if len(got) != 3 { - t.Fatalf("expected 3 entries, got %#v", got) - } - if got["A"] != "1" { - t.Fatalf("A mismatch, got %q", got["A"]) - } - if got["B"] != "override" { - t.Fatalf("B mismatch, got %q", got["B"]) - } - if got["C"] != "3" { - t.Fatalf("C mismatch, got %q", got["C"]) - } + require.Len(t, got, 3) + require.Equal(t, "1", got["A"]) + require.Equal(t, "override", got["B"]) + require.Equal(t, "3", got["C"]) } func TestMergeExtraEnvsMergesAndOverrides(t *testing.T) { @@ -107,18 +88,10 @@ func TestMergeExtraEnvsMergesAndOverrides(t *testing.T) { got := mergeExtraEnvs(fromFile, fromRequest) - if len(got) != 3 { - t.Fatalf("expected 3 entries, got %#v", got) - } - if got["A"] != "1" { - t.Fatalf("A mismatch, got %q", got["A"]) - } - if got["B"] != "override" { - t.Fatalf("B mismatch, got %q", got["B"]) - } - if got["C"] != "3" { - t.Fatalf("C mismatch, got %q", got["C"]) - } + require.Len(t, got, 3) + require.Equal(t, "1", got["A"]) + require.Equal(t, "override", got["B"]) + require.Equal(t, "3", got["C"]) } func TestMergeExtraEnvsHandlesNilFromFile(t *testing.T) { @@ -126,7 +99,6 @@ func TestMergeExtraEnvsHandlesNilFromFile(t *testing.T) { got := mergeExtraEnvs(nil, fromRequest) - if len(got) != 1 || got["ONLY"] != "request" { - t.Fatalf("unexpected merge result: %#v", got) - } + require.Len(t, got, 1) + require.Equal(t, "request", got["ONLY"]) } diff --git a/components/execd/pkg/runtime/helpers_test.go b/components/execd/pkg/runtime/helpers_test.go index 1f50a4a2..843c5673 100644 --- a/components/execd/pkg/runtime/helpers_test.go +++ b/components/execd/pkg/runtime/helpers_test.go @@ -24,6 +24,8 @@ import ( "sync/atomic" "testing" "time" + + "github.com/stretchr/testify/require" ) type stubDriver struct { @@ -109,8 +111,6 @@ func newStubDB(t *testing.T, d *stubDriver) *sql.DB { driverName := fmt.Sprintf("stub-%d", time.Now().UnixNano()) sql.Register(driverName, &stubConnector{d: d}) db, err := sql.Open(driverName, "") - if err != nil { - t.Fatalf("open stub db: %v", err) - } + require.NoError(t, err) return db } diff --git a/components/execd/pkg/runtime/sql_test.go b/components/execd/pkg/runtime/sql_test.go index a8eca895..bb103fb6 100644 --- a/components/execd/pkg/runtime/sql_test.go +++ b/components/execd/pkg/runtime/sql_test.go @@ -22,6 +22,7 @@ import ( "time" "github.com/alibaba/opensandbox/execd/pkg/jupyter/execute" + "github.com/stretchr/testify/require" ) func TestExecuteSelectSQLQuery_Success(t *testing.T) { @@ -58,32 +59,20 @@ func TestExecuteSelectSQLQuery_Success(t *testing.T) { }, } - if err := c.executeSelectSQLQuery(context.Background(), req); err != nil { - t.Fatalf("executeSelectSQLQuery returned error: %v", err) - } + require.NoError(t, c.executeSelectSQLQuery(context.Background(), req)) - if gotError != nil { - t.Fatalf("unexpected error hook: %+v", gotError) - } - if !completed { - t.Fatalf("expected completion hook to be triggered") - } + require.Nil(t, gotError, "unexpected error hook") + require.True(t, completed, "expected completion hook to be triggered") raw, ok := gotResult["text/plain"] - if !ok { - t.Fatalf("expected text/plain payload") - } + require.True(t, ok, "expected text/plain payload") var qr QueryResult - if err := json.Unmarshal([]byte(raw.(string)), &qr); err != nil { - t.Fatalf("unmarshal result: %v", err) - } + require.NoError(t, json.Unmarshal([]byte(raw.(string)), &qr)) - if len(qr.Columns) != 2 || qr.Columns[0] != "id" || qr.Columns[1] != "name" { - t.Fatalf("unexpected columns: %#v", qr.Columns) - } - if len(qr.Rows) != 2 || qr.Rows[0][0] != "1" || qr.Rows[1][1] != "bob" { - t.Fatalf("unexpected rows: %#v", qr.Rows) - } + require.Equal(t, []string{"id", "name"}, qr.Columns, "unexpected columns") + require.Len(t, qr.Rows, 2, "unexpected rows") + require.Equal(t, "1", qr.Rows[0][0]) + require.Equal(t, "bob", qr.Rows[1][1]) } func TestExecuteUpdateSQLQuery_Success(t *testing.T) { @@ -116,30 +105,18 @@ func TestExecuteUpdateSQLQuery_Success(t *testing.T) { }, } - if err := c.executeUpdateSQLQuery(context.Background(), req); err != nil { - t.Fatalf("executeUpdateSQLQuery returned error: %v", err) - } + require.NoError(t, c.executeUpdateSQLQuery(context.Background(), req)) - if gotError != nil { - t.Fatalf("unexpected error hook: %+v", gotError) - } - if !completed { - t.Fatalf("expected completion hook to be triggered") - } + require.Nil(t, gotError, "unexpected error hook") + require.True(t, completed, "expected completion hook to be triggered") raw, ok := gotResult["text/plain"] - if !ok { - t.Fatalf("expected text/plain payload") - } + require.True(t, ok, "expected text/plain payload") var qr QueryResult - if err := json.Unmarshal([]byte(raw.(string)), &qr); err != nil { - t.Fatalf("unmarshal result: %v", err) - } + require.NoError(t, json.Unmarshal([]byte(raw.(string)), &qr)) - if len(qr.Columns) != 1 || qr.Columns[0] != "affected_rows" { - t.Fatalf("unexpected columns: %#v", qr.Columns) - } - if len(qr.Rows) != 1 || len(qr.Rows[0]) != 1 || qr.Rows[0][0] != float64(3) { - t.Fatalf("unexpected affected rows: %#v", qr.Rows) - } + require.Equal(t, []string{"affected_rows"}, qr.Columns, "unexpected columns") + require.Len(t, qr.Rows, 1, "unexpected rows length") + require.Len(t, qr.Rows[0], 1, "unexpected row entry length") + require.Equal(t, float64(3), qr.Rows[0][0]) } diff --git a/components/execd/pkg/runtime/types_test.go b/components/execd/pkg/runtime/types_test.go index 6f84bbfd..5b7afeb9 100644 --- a/components/execd/pkg/runtime/types_test.go +++ b/components/execd/pkg/runtime/types_test.go @@ -17,6 +17,8 @@ package runtime import ( "reflect" "testing" + + "github.com/stretchr/testify/require" ) func TestExecuteCodeRequest_SetDefaultHooks(t *testing.T) { @@ -30,13 +32,10 @@ func TestExecuteCodeRequest_SetDefaultHooks(t *testing.T) { req.SetDefaultHooks() - if req.Hooks.OnExecuteStdout == nil || req.Hooks.OnExecuteStderr == nil || req.Hooks.OnExecuteError == nil { - t.Fatalf("expected default hooks to be populated") - } - if req.Hooks.OnExecuteResult == nil { - t.Fatalf("expected OnExecuteResult to remain set") - } - if reflect.ValueOf(req.Hooks.OnExecuteResult).Pointer() != reflect.ValueOf(customResult).Pointer() { - t.Fatalf("default hooks should not override existing ones") - } + require.NotNil(t, req.Hooks.OnExecuteStdout) + require.NotNil(t, req.Hooks.OnExecuteStderr) + require.NotNil(t, req.Hooks.OnExecuteError) + require.NotNil(t, req.Hooks.OnExecuteResult, "expected OnExecuteResult to remain set") + require.Equal(t, reflect.ValueOf(customResult).Pointer(), reflect.ValueOf(req.Hooks.OnExecuteResult).Pointer(), + "default hooks should not override existing ones") } diff --git a/components/execd/pkg/web/controller/basic_test.go b/components/execd/pkg/web/controller/basic_test.go index 7bebba77..545cd098 100644 --- a/components/execd/pkg/web/controller/basic_test.go +++ b/components/execd/pkg/web/controller/basic_test.go @@ -21,6 +21,7 @@ import ( "testing" "github.com/alibaba/opensandbox/execd/pkg/web/model" + "github.com/stretchr/testify/require" ) func TestBasicControllerRespondSuccess(t *testing.T) { @@ -30,16 +31,10 @@ func TestBasicControllerRespondSuccess(t *testing.T) { payload := map[string]string{"status": "ok"} ctrl.RespondSuccess(payload) - if rec.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", rec.Code) - } + require.Equal(t, http.StatusOK, rec.Code) var resp map[string]string - if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - if resp["status"] != "ok" { - t.Fatalf("unexpected body: %#v", resp) - } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, "ok", resp["status"]) } func TestBasicControllerRespondError(t *testing.T) { @@ -48,16 +43,11 @@ func TestBasicControllerRespondError(t *testing.T) { ctrl.RespondError(http.StatusBadRequest, model.ErrorCodeInvalidRequest, "boom") - if rec.Code != http.StatusBadRequest { - t.Fatalf("expected status 400, got %d", rec.Code) - } + require.Equal(t, http.StatusBadRequest, rec.Code) var resp model.ErrorResponse - if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - if resp.Code != model.ErrorCodeInvalidRequest || resp.Message != "boom" { - t.Fatalf("unexpected body: %#v", resp) - } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, model.ErrorCodeInvalidRequest, resp.Code) + require.Equal(t, "boom", resp.Message) } func setupBasicController(method string) (*basicController, *httptest.ResponseRecorder) { @@ -72,16 +62,10 @@ func TestRespondSuccessWritesPayload(t *testing.T) { payload := map[string]string{"status": "ok"} ctrl.RespondSuccess(payload) - if w.Code != http.StatusOK { - t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code) - } + require.Equal(t, http.StatusOK, w.Code) var got map[string]string - if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil { - t.Fatalf("failed to unmarshal body: %v", err) - } - if got["status"] != "ok" { - t.Fatalf("unexpected response body: %#v", got) - } + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got)) + require.Equal(t, "ok", got["status"]) } func TestRespondErrorAddsCodeAndMessage(t *testing.T) { @@ -89,19 +73,11 @@ func TestRespondErrorAddsCodeAndMessage(t *testing.T) { ctrl.RespondError(http.StatusBadRequest, model.ErrorCodeInvalidRequest, "invalid payload") - if w.Code != http.StatusBadRequest { - t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code) - } + require.Equal(t, http.StatusBadRequest, w.Code) var got model.ErrorResponse - if err := json.Unmarshal(w.Body.Bytes(), &got); err != nil { - t.Fatalf("failed to unmarshal error body: %v", err) - } - if got.Code != model.ErrorCodeInvalidRequest { - t.Fatalf("unexpected code: %s", got.Code) - } - if got.Message != "invalid payload" { - t.Fatalf("unexpected message: %s", got.Message) - } + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got)) + require.Equal(t, model.ErrorCodeInvalidRequest, got.Code) + require.Equal(t, "invalid payload", got.Message) } func TestQueryInt64(t *testing.T) { @@ -122,9 +98,7 @@ func TestQueryInt64(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := ctrl.QueryInt64(tt.query, tt.def) - if got != tt.expected { - t.Fatalf("QueryInt64(%q, %d) = %d, want %d", tt.query, tt.def, got, tt.expected) - } + require.Equalf(t, tt.expected, got, "QueryInt64(%q, %d)", tt.query, tt.def) }) } } diff --git a/components/execd/pkg/web/controller/codeinterpreting_test.go b/components/execd/pkg/web/controller/codeinterpreting_test.go index 6b2b7ead..d4f11dcc 100644 --- a/components/execd/pkg/web/controller/codeinterpreting_test.go +++ b/components/execd/pkg/web/controller/codeinterpreting_test.go @@ -23,6 +23,7 @@ import ( "github.com/alibaba/opensandbox/execd/pkg/runtime" "github.com/alibaba/opensandbox/execd/pkg/web/model" + "github.com/stretchr/testify/require" ) func TestBuildExecuteCodeRequestDefaultsToCommand(t *testing.T) { @@ -37,12 +38,9 @@ func TestBuildExecuteCodeRequestDefaultsToCommand(t *testing.T) { execReq := ctrl.buildExecuteCodeRequest(req) - if execReq.Language != runtime.Command { - t.Fatalf("expected default language %s, got %s", runtime.Command, execReq.Language) - } - if execReq.Context != "session-1" || execReq.Code != "echo 1" { - t.Fatalf("unexpected execute request: %#v", execReq) - } + require.Equal(t, runtime.Command, execReq.Language, "expected default language") + require.Equal(t, "session-1", execReq.Context) + require.Equal(t, "echo 1", execReq.Code) } func TestBuildExecuteCodeRequestRespectsLanguage(t *testing.T) { @@ -59,9 +57,7 @@ func TestBuildExecuteCodeRequestRespectsLanguage(t *testing.T) { execReq := ctrl.buildExecuteCodeRequest(req) - if execReq.Language != runtime.Language("python") { - t.Fatalf("expected python language, got %s", execReq.Language) - } + require.Equal(t, runtime.Language("python"), execReq.Language) } func TestGetContext_NotFoundReturns404(t *testing.T) { @@ -75,20 +71,12 @@ func TestGetContext_NotFoundReturns404(t *testing.T) { ctrl.GetContext() - if w.Code != http.StatusNotFound { - t.Fatalf("expected status %d, got %d", http.StatusNotFound, w.Code) - } + require.Equal(t, http.StatusNotFound, w.Code) var resp model.ErrorResponse - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - if resp.Code != model.ErrorCodeContextNotFound { - t.Fatalf("unexpected error code: %s", resp.Code) - } - if resp.Message != "context missing not found" { - t.Fatalf("unexpected message: %s", resp.Message) - } + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + require.Equal(t, model.ErrorCodeContextNotFound, resp.Code) + require.Equal(t, "context missing not found", resp.Message) } func TestGetContext_MissingIDReturns400(t *testing.T) { @@ -97,18 +85,10 @@ func TestGetContext_MissingIDReturns400(t *testing.T) { ctrl.GetContext() - if w.Code != http.StatusBadRequest { - t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code) - } + require.Equal(t, http.StatusBadRequest, w.Code) var resp model.ErrorResponse - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to decode response: %v", err) - } - if resp.Code != model.ErrorCodeMissingQuery { - t.Fatalf("unexpected error code: %s", resp.Code) - } - if resp.Message != "missing path parameter 'contextId'" { - t.Fatalf("unexpected message: %s", resp.Message) - } + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + require.Equal(t, model.ErrorCodeMissingQuery, resp.Code) + require.Equal(t, "missing path parameter 'contextId'", resp.Message) } diff --git a/components/execd/pkg/web/controller/command_test.go b/components/execd/pkg/web/controller/command_test.go index 439b1a10..74735963 100644 --- a/components/execd/pkg/web/controller/command_test.go +++ b/components/execd/pkg/web/controller/command_test.go @@ -23,6 +23,7 @@ import ( "github.com/alibaba/opensandbox/execd/pkg/runtime" "github.com/alibaba/opensandbox/execd/pkg/web/model" + "github.com/stretchr/testify/require" ) func TestBuildExecuteCommandRequestForwardsEnvs(t *testing.T) { @@ -36,15 +37,9 @@ func TestBuildExecuteCommandRequestForwardsEnvs(t *testing.T) { execReq := ctrl.buildExecuteCommandRequest(req) - if execReq.Language != runtime.Command { - t.Fatalf("expected runtime.Command, got %s", execReq.Language) - } - if !reflect.DeepEqual(execReq.Envs, envs) { - t.Fatalf("expected envs to be forwarded, got %#v", execReq.Envs) - } - if execReq.Cwd != "/tmp" { - t.Fatalf("expected Cwd to be forwarded, got %s", execReq.Cwd) - } + require.Equal(t, runtime.Command, execReq.Language) + require.True(t, reflect.DeepEqual(execReq.Envs, envs), "expected envs to be forwarded") + require.Equal(t, "/tmp", execReq.Cwd) } func TestBuildExecuteCommandRequestForwardsEnvsBackground(t *testing.T) { @@ -58,12 +53,8 @@ func TestBuildExecuteCommandRequestForwardsEnvsBackground(t *testing.T) { execReq := ctrl.buildExecuteCommandRequest(req) - if execReq.Language != runtime.BackgroundCommand { - t.Fatalf("expected runtime.BackgroundCommand, got %s", execReq.Language) - } - if !reflect.DeepEqual(execReq.Envs, envs) { - t.Fatalf("expected envs to be forwarded, got %#v", execReq.Envs) - } + require.Equal(t, runtime.BackgroundCommand, execReq.Language) + require.True(t, reflect.DeepEqual(execReq.Envs, envs), "expected envs to be forwarded") } func setupCommandController(method, path string) (*CodeInterpretingController, *httptest.ResponseRecorder) { @@ -77,20 +68,12 @@ func TestGetCommandStatus_MissingID(t *testing.T) { ctrl.GetCommandStatus() - if w.Code != http.StatusBadRequest { - t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code) - } + require.Equal(t, http.StatusBadRequest, w.Code) var resp model.ErrorResponse - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to parse response: %v", err) - } - if resp.Code != model.ErrorCodeInvalidRequest { - t.Fatalf("unexpected error code: %s", resp.Code) - } - if resp.Message != "missing command execution id" { - t.Fatalf("unexpected message: %s", resp.Message) - } + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + require.Equal(t, model.ErrorCodeInvalidRequest, resp.Code) + require.Equal(t, "missing command execution id", resp.Message) } func TestGetBackgroundCommandOutput_MissingID(t *testing.T) { @@ -98,18 +81,10 @@ func TestGetBackgroundCommandOutput_MissingID(t *testing.T) { ctrl.GetBackgroundCommandOutput() - if w.Code != http.StatusBadRequest { - t.Fatalf("expected status %d, got %d", http.StatusBadRequest, w.Code) - } + require.Equal(t, http.StatusBadRequest, w.Code) var resp model.ErrorResponse - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("failed to parse response: %v", err) - } - if resp.Code != model.ErrorCodeMissingQuery { - t.Fatalf("unexpected error code: %s", resp.Code) - } - if resp.Message != "missing command execution id" { - t.Fatalf("unexpected message: %s", resp.Message) - } + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + require.Equal(t, model.ErrorCodeMissingQuery, resp.Code) + require.Equal(t, "missing command execution id", resp.Message) } diff --git a/components/execd/pkg/web/controller/filesystem_download_test.go b/components/execd/pkg/web/controller/filesystem_download_test.go deleted file mode 100644 index cb3bd469..00000000 --- a/components/execd/pkg/web/controller/filesystem_download_test.go +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2025 Alibaba Group Holding Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package controller diff --git a/components/execd/pkg/web/controller/filesystem_test.go b/components/execd/pkg/web/controller/filesystem_test.go index 15bf8c43..804e60df 100644 --- a/components/execd/pkg/web/controller/filesystem_test.go +++ b/components/execd/pkg/web/controller/filesystem_test.go @@ -25,6 +25,7 @@ import ( "testing" "github.com/alibaba/opensandbox/execd/pkg/web/model" + "github.com/stretchr/testify/require" ) func newFilesystemController(t *testing.T, method, rawURL string, body []byte) (*FilesystemController, *httptest.ResponseRecorder) { @@ -37,65 +38,45 @@ func newFilesystemController(t *testing.T, method, rawURL string, body []byte) ( func TestFilesystemControllerGetFilesInfo(t *testing.T) { tmpDir := t.TempDir() target := filepath.Join(tmpDir, "foo.txt") - if err := os.WriteFile(target, []byte("demo"), 0o644); err != nil { - t.Fatalf("write temp file: %v", err) - } + require.NoError(t, os.WriteFile(target, []byte("demo"), 0o644)) query := fmt.Sprintf("/files/info?path=%s", url.QueryEscape(target)) ctrl, rec := newFilesystemController(t, http.MethodGet, query, nil) ctrl.GetFilesInfo() - if rec.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", rec.Code) - } + require.Equal(t, http.StatusOK, rec.Code) var resp map[string]model.FileInfo - if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { - t.Fatalf("decode response: %v", err) - } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) info, ok := resp[target] - if !ok { - t.Fatalf("response missing entry for %s", target) - } - if info.Path == "" || info.Size == 0 { - t.Fatalf("unexpected file info: %#v", info) - } + require.True(t, ok, "response missing entry for %s", target) + require.NotEmpty(t, info.Path) + require.NotZero(t, info.Size) } func TestFilesystemControllerSearchFiles(t *testing.T) { tmpDir := t.TempDir() a := filepath.Join(tmpDir, "alpha.txt") b := filepath.Join(tmpDir, "beta.log") - if err := os.WriteFile(a, []byte("alpha"), 0o644); err != nil { - t.Fatalf("write alpha: %v", err) - } - if err := os.WriteFile(b, []byte("beta"), 0o644); err != nil { - t.Fatalf("write beta: %v", err) - } + require.NoError(t, os.WriteFile(a, []byte("alpha"), 0o644)) + require.NoError(t, os.WriteFile(b, []byte("beta"), 0o644)) rawURL := fmt.Sprintf("/files/search?path=%s&pattern=%s", url.QueryEscape(tmpDir), url.QueryEscape("*.txt")) ctrl, rec := newFilesystemController(t, http.MethodGet, rawURL, nil) ctrl.SearchFiles() - if rec.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", rec.Code) - } + require.Equal(t, http.StatusOK, rec.Code) var files []model.FileInfo - if err := json.Unmarshal(rec.Body.Bytes(), &files); err != nil { - t.Fatalf("decode response: %v", err) - } - if len(files) != 1 || files[0].Path != a { - t.Fatalf("expected only %s, got %#v", a, files) - } + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &files)) + require.Len(t, files, 1) + require.Equal(t, a, files[0].Path) } func TestFilesystemControllerReplaceContent(t *testing.T) { tmpDir := t.TempDir() target := filepath.Join(tmpDir, "content.txt") - if err := os.WriteFile(target, []byte("hello world"), 0o644); err != nil { - t.Fatalf("write temp file: %v", err) - } + require.NoError(t, os.WriteFile(target, []byte("hello world"), 0o644)) body, err := json.Marshal(map[string]model.ReplaceFileContentItem{ target: { @@ -103,24 +84,16 @@ func TestFilesystemControllerReplaceContent(t *testing.T) { New: "universe", }, }) - if err != nil { - t.Fatalf("marshal body: %v", err) - } + require.NoError(t, err) ctrl, rec := newFilesystemController(t, http.MethodPost, "/files/replace", body) ctrl.ReplaceContent() - if rec.Code != http.StatusOK { - t.Fatalf("expected status 200, got %d", rec.Code) - } + require.Equal(t, http.StatusOK, rec.Code) data, err := os.ReadFile(target) - if err != nil { - t.Fatalf("read file: %v", err) - } - if string(data) != "hello universe" { - t.Fatalf("unexpected content: %s", string(data)) - } + require.NoError(t, err) + require.Equal(t, "hello universe", string(data)) } func TestFilesystemControllerSearchFilesHandlesAbsentDir(t *testing.T) { @@ -129,9 +102,7 @@ func TestFilesystemControllerSearchFilesHandlesAbsentDir(t *testing.T) { ctrl.SearchFiles() - if rec.Code != http.StatusNotFound { - t.Fatalf("expected 404, got %d", rec.Code) - } + require.Equal(t, http.StatusNotFound, rec.Code) } func TestReplaceContentFailsUnknownFile(t *testing.T) { @@ -145,7 +116,5 @@ func TestReplaceContentFailsUnknownFile(t *testing.T) { ctrl.ReplaceContent() - if rec.Code != http.StatusNotFound && rec.Code != http.StatusInternalServerError { - t.Fatalf("expected failure status, got %d", rec.Code) - } + require.Contains(t, []int{http.StatusNotFound, http.StatusInternalServerError}, rec.Code, "expected failure status") } diff --git a/components/execd/pkg/web/controller/filesystem_upload_test.go b/components/execd/pkg/web/controller/filesystem_upload_test.go deleted file mode 100644 index cb3bd469..00000000 --- a/components/execd/pkg/web/controller/filesystem_upload_test.go +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2025 Alibaba Group Holding Ltd. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package controller diff --git a/components/execd/pkg/web/controller/utils_test.go b/components/execd/pkg/web/controller/utils_test.go index 92a969f7..3f1a2182 100644 --- a/components/execd/pkg/web/controller/utils_test.go +++ b/components/execd/pkg/web/controller/utils_test.go @@ -21,54 +21,38 @@ import ( "testing" "github.com/alibaba/opensandbox/execd/pkg/web/model" + "github.com/stretchr/testify/require" ) func TestDeleteFile(t *testing.T) { tmp := t.TempDir() file := filepath.Join(tmp, "sample.txt") - if err := os.WriteFile(file, []byte("hello"), 0o644); err != nil { - t.Fatalf("write file: %v", err) - } + require.NoError(t, os.WriteFile(file, []byte("hello"), 0o644)) - if err := DeleteFile(file); err != nil { - t.Fatalf("DeleteFile returned error: %v", err) - } - if _, err := os.Stat(file); !os.IsNotExist(err) { - t.Fatalf("expected file removed, got err=%v", err) - } + require.NoError(t, DeleteFile(file)) + _, err := os.Stat(file) + require.True(t, os.IsNotExist(err), "expected file removed, got err=%v", err) // removing a non-existent file should be a no-op - if err := DeleteFile(file); err != nil { - t.Fatalf("expected no error deleting missing file, got %v", err) - } + require.NoError(t, DeleteFile(file), "expected no error deleting missing file") } func TestRenameFile(t *testing.T) { tmp := t.TempDir() src := filepath.Join(tmp, "src.txt") - if err := os.WriteFile(src, []byte("data"), 0o644); err != nil { - t.Fatalf("write file: %v", err) - } + require.NoError(t, os.WriteFile(src, []byte("data"), 0o644)) dst := filepath.Join(tmp, "nested", "renamed.txt") - if err := RenameFile(model.RenameFileItem{Src: src, Dest: dst}); err != nil { - t.Fatalf("RenameFile returned error: %v", err) - } + require.NoError(t, RenameFile(model.RenameFileItem{Src: src, Dest: dst})) - if _, err := os.Stat(dst); err != nil { - t.Fatalf("expected destination file, got %v", err) - } - if _, err := os.Stat(src); !os.IsNotExist(err) { - t.Fatalf("expected source removed, got err=%v", err) - } + _, err := os.Stat(dst) + require.NoError(t, err) + _, err = os.Stat(src) + require.True(t, os.IsNotExist(err), "expected source removed, got err=%v", err) // destination exists -> expect error - if err := os.WriteFile(src, []byte("data"), 0o644); err != nil { - t.Fatalf("rewrite src: %v", err) - } - if err := RenameFile(model.RenameFileItem{Src: src, Dest: dst}); err == nil { - t.Fatalf("expected error when destination already exists") - } + require.NoError(t, os.WriteFile(src, []byte("data"), 0o644)) + require.Error(t, RenameFile(model.RenameFileItem{Src: src, Dest: dst}), "expected error when destination already exists") } func TestSearchFileMetadata(t *testing.T) { @@ -78,16 +62,12 @@ func TestSearchFileMetadata(t *testing.T) { } path, info, ok := SearchFileMetadata(metadata, "/any/notes.txt") - if !ok { - t.Fatalf("expected metadata entry") - } - if path != "/tmp/a/notes.txt" || info.Path != "/tmp/a/notes.txt" { - t.Fatalf("unexpected match path=%s info=%v", path, info) - } + require.True(t, ok, "expected metadata entry") + require.Equal(t, "/tmp/a/notes.txt", path) + require.Equal(t, "/tmp/a/notes.txt", info.Path) - if _, _, ok := SearchFileMetadata(metadata, "/foo/unknown.txt"); ok { - t.Fatalf("expected no match") - } + _, _, ok = SearchFileMetadata(metadata, "/foo/unknown.txt") + require.False(t, ok, "expected no match") } func TestParseRange(t *testing.T) { @@ -122,17 +102,11 @@ func TestParseRange(t *testing.T) { t.Run(tt.name, func(t *testing.T) { got, err := ParseRange(tt.header, tt.size) if tt.expectErr { - if err == nil { - t.Fatalf("expected error") - } + require.Error(t, err) return } - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if !reflect.DeepEqual(got, tt.want) { - t.Fatalf("got %+v want %+v", got, tt.want) - } + require.NoError(t, err) + require.True(t, reflect.DeepEqual(got, tt.want), "got %+v want %+v", got, tt.want) }) } } diff --git a/components/execd/pkg/web/model/codeinterpreting_test.go b/components/execd/pkg/web/model/codeinterpreting_test.go index 46999470..f0903bf0 100644 --- a/components/execd/pkg/web/model/codeinterpreting_test.go +++ b/components/execd/pkg/web/model/codeinterpreting_test.go @@ -20,44 +20,33 @@ import ( "testing" "github.com/alibaba/opensandbox/execd/pkg/jupyter/execute" + "github.com/stretchr/testify/require" ) func TestRunCodeRequestValidate(t *testing.T) { req := RunCodeRequest{ Code: "print('hi')", } - if err := req.Validate(); err != nil { - t.Fatalf("expected validation success: %v", err) - } + require.NoError(t, req.Validate()) req.Code = "" - if err := req.Validate(); err == nil { - t.Fatalf("expected validation error when code is empty") - } + require.Error(t, req.Validate(), "expected validation error when code is empty") } func TestRunCommandRequestValidate(t *testing.T) { req := RunCommandRequest{Command: "ls"} - if err := req.Validate(); err != nil { - t.Fatalf("expected command validation success: %v", err) - } + require.NoError(t, req.Validate(), "expected command validation success") req.TimeoutMs = -100 - if err := req.Validate(); err == nil { - t.Fatalf("expected validation error when timeout is negative") - } + require.Error(t, req.Validate(), "expected validation error when timeout is negative") req.TimeoutMs = 0 req.Command = "ls" - if err := req.Validate(); err != nil { - t.Fatalf("expected success when timeout is omitted/zero: %v", err) - } + require.NoError(t, req.Validate(), "expected success when timeout is omitted/zero") req.TimeoutMs = 10 req.Command = "" - if err := req.Validate(); err == nil { - t.Fatalf("expected validation error when command is empty") - } + require.Error(t, req.Validate(), "expected validation error when command is empty") } func ptr32(v uint32) *uint32 { return &v } @@ -65,21 +54,15 @@ func ptr32(v uint32) *uint32 { return &v } func TestRunCommandRequestValidateUidGid(t *testing.T) { // uid-only: valid req := RunCommandRequest{Command: "id", Uid: ptr32(1000)} - if err := req.Validate(); err != nil { - t.Fatalf("expected success with uid only: %v", err) - } + require.NoError(t, req.Validate(), "expected success with uid only") // uid + gid: valid req = RunCommandRequest{Command: "id", Uid: ptr32(1000), Gid: ptr32(1000)} - if err := req.Validate(); err != nil { - t.Fatalf("expected success with uid and gid: %v", err) - } + require.NoError(t, req.Validate(), "expected success with uid and gid") // gid-only: must be rejected req = RunCommandRequest{Command: "id", Gid: ptr32(1000)} - if err := req.Validate(); err == nil { - t.Fatalf("expected validation error when gid is set without uid") - } + require.Error(t, req.Validate(), "expected validation error when gid is set without uid") } func TestServerStreamEventToJSON(t *testing.T) { @@ -91,12 +74,10 @@ func TestServerStreamEventToJSON(t *testing.T) { data := event.ToJSON() var decoded ServerStreamEvent - if err := json.Unmarshal(data, &decoded); err != nil { - t.Fatalf("failed to unmarshal event: %v", err) - } - if decoded.Type != event.Type || decoded.Text != event.Text || decoded.ExecutionCount != event.ExecutionCount { - t.Fatalf("unexpected decoded event: %#v", decoded) - } + require.NoError(t, json.Unmarshal(data, &decoded)) + require.Equal(t, event.Type, decoded.Type) + require.Equal(t, event.Text, decoded.Text) + require.Equal(t, event.ExecutionCount, decoded.ExecutionCount) } func TestServerStreamEventSummary(t *testing.T) { @@ -134,9 +115,7 @@ func TestServerStreamEventSummary(t *testing.T) { t.Run(tt.name, func(t *testing.T) { summary := tt.event.Summary() for _, want := range tt.contains { - if !strings.Contains(summary, want) { - t.Fatalf("summary missing %q, got: %s", want, summary) - } + require.Containsf(t, summary, want, "summary missing %q", want) } }) }