diff --git a/.github/workflows/real-e2e.yml b/.github/workflows/real-e2e.yml index 3e1c85d3..9d098b18 100644 --- a/.github/workflows/real-e2e.yml +++ b/.github/workflows/real-e2e.yml @@ -35,6 +35,8 @@ jobs: pip install uv - name: Run tests + env: + OPENSANDBOX_SANDBOX_DEFAULT_IMAGE: opensandbox/code-interpreter:latest run: | set -e diff --git a/components/execd/pkg/runtime/bash_session.go b/components/execd/pkg/runtime/bash_session.go new file mode 100644 index 00000000..58b8aeb3 --- /dev/null +++ b/components/execd/pkg/runtime/bash_session.go @@ -0,0 +1,465 @@ +// Copyright 2026 Alibaba Group Holding Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+//go:build !windows
+// +build !windows
+
+package runtime
+
+import (
+	"bufio"
+	"context"
+	"errors"
+	"fmt"
+	"os"
+	"os/exec"
+	"sort"
+	"strconv"
+	"strings"
+	"syscall"
+	"time"
+
+	"github.com/alibaba/opensandbox/execd/pkg/jupyter/execute"
+	"github.com/alibaba/opensandbox/execd/pkg/log"
+	"github.com/google/uuid"
+)
+
+// Markers printed by the wrapper script (see buildWrappedScript) so that run
+// can tell user output apart from the session-state epilogue.
+const (
+	envDumpStartMarker = "__ENV_DUMP_START__"
+	envDumpEndMarker   = "__ENV_DUMP_END__"
+	exitMarkerPrefix   = "__EXIT_CODE__:"
+	pwdMarkerPrefix    = "__PWD__:"
+)
+
+// createBashSession starts a new bash session, registers it in
+// bashSessionClientMap and returns its ID. The request is currently unused.
+func (c *Controller) createBashSession(_ *CreateContextRequest) (string, error) {
+	session := newBashSession(nil)
+	if err := session.start(); err != nil {
+		return "", fmt.Errorf("failed to start bash session: %w", err)
+	}
+
+	c.bashSessionClientMap.Store(session.config.Session, session)
+	log.Info("created bash session %s", session.config.Session)
+	return session.config.Session, nil
+}
+
+// runBashSession executes request in the session named by request.Context.
+// Without a context it uses the default session registered for
+// request.Language, creating the default bash session on first use.
+// It returns ErrContextNotFound when the target session is not registered.
+func (c *Controller) runBashSession(_ context.Context, request *ExecuteCodeRequest) error {
+	targetSessionID := request.Context
+	if targetSessionID == "" {
+		if c.getDefaultLanguageSession(request.Language) == "" {
+			if err := c.createDefaultBashSession(); err != nil {
+				return err
+			}
+		}
+		targetSessionID = c.getDefaultLanguageSession(request.Language)
+	}
+
+	session := c.getBashSession(targetSessionID)
+	if session == nil {
+		return ErrContextNotFound
+	}
+
+	return session.run(request)
+}
+
+// createDefaultBashSession creates a session and records it as the default
+// session for the Bash language.
+func (c *Controller) createDefaultBashSession() error {
+	session, err := c.createBashSession(&CreateContextRequest{})
+	if err != nil {
+		return err
+	}
+
+	c.setDefaultLanguageSession(Bash, session)
+	return nil
+}
+
+// getBashSession returns the registered session with the given ID, or nil when
+// there is none (or the stored value is not a *bashSession).
+func (c *Controller) getBashSession(sessionId string) *bashSession {
+	if v, ok := c.bashSessionClientMap.Load(sessionId); ok {
+		if s, ok := v.(*bashSession); ok {
+			return s
+		}
+	}
+	return nil
+}
+
+// closeBashSession closes the session and removes it from the registry.
+// It returns ErrContextNotFound for an unknown ID.
+func (c *Controller) closeBashSession(sessionId string) error {
+	session := c.getBashSession(sessionId)
+	if session == nil {
+		return ErrContextNotFound
+	}
+
+	err := session.close()
+	if err != nil {
+		return err
+	}
+
+	c.bashSessionClientMap.Delete(sessionId)
+	return nil
+}
+
+// listBashSessions returns the IDs of all registered bash sessions.
+//
+// nolint:unused
+func (c *Controller) listBashSessions() []string {
+	sessions := make([]string, 0)
+	c.bashSessionClientMap.Range(func(key, _ any) bool {
+		sessionID, _ := key.(string)
+		sessions = append(sessions, sessionID)
+		return true
+	})
+
+	return sessions
+}
+
+// newBashSession builds a session (pipe-based, no PTY) whose initial
+// environment is a copy of the current process environment. A nil config is
+// replaced by a fresh UUID session ID and a 5 second startup timeout.
+func newBashSession(config *bashSessionConfig) *bashSession {
+	if config == nil {
+		config = &bashSessionConfig{
+			Session:        uuidString(),
+			StartupTimeout: 5 * time.Second,
+		}
+	}
+
+	env := make(map[string]string)
+	for _, kv := range os.Environ() {
+		if k, v, ok := splitEnvPair(kv); ok {
+			env[k] = v
+		}
+	}
+
+	return &bashSession{
+		config: config,
+		env:    env,
+	}
+}
+
+// start marks the session as usable. No process is spawned here; every run
+// launches its own bash. Starting twice is an error.
+func (s *bashSession) start() error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.started {
+		return errors.New("session already started")
+	}
+
+	s.started = true
+	return nil
+}
+
+// run executes request.Code in a fresh bash process primed with the session's
+// persisted environment and working directory, streams every output line
+// (stderr is merged into stdout) to request.Hooks.OnExecuteStdout, and stores
+// the resulting environment and cwd back on the session.
+//
+// A non-zero exit status of the user code is reported through
+// request.Hooks.OnExecuteError and is NOT returned as an error. The returned
+// error covers infrastructure failures (script file, pipe, start, read) and
+// timeouts. OnExecuteComplete fires only for a zero exit status.
+//
+//nolint:gocognit
+func (s *bashSession) run(request *ExecuteCodeRequest) error {
+	s.mu.Lock()
+	if !s.started {
+		s.mu.Unlock()
+		return errors.New("session not started")
+	}
+
+	envSnapshot := copyEnvMap(s.env)
+
+	cwd := s.cwd
+	// override original cwd if specified
+	if request.Cwd != "" {
+		cwd = request.Cwd
+	}
+	sessionID := s.config.Session
+	s.mu.Unlock()
+
+	startAt := time.Now()
+	if request.Hooks.OnExecuteInit != nil {
+		request.Hooks.OnExecuteInit(sessionID)
+	}
+
+	wait := request.Timeout
+	if wait <= 0 {
+		wait = 24 * 3600 * time.Second // max to 24 hours
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), wait)
+	defer cancel()
+
+	script := buildWrappedScript(request.Code, envSnapshot, cwd)
+	scriptFile, err := os.CreateTemp("", "execd_bash_*.sh")
+	if err != nil {
+		return fmt.Errorf("create script file: %w", err)
+	}
+	scriptPath := scriptFile.Name()
+	// The script is single-use; without this every command leaks a temp file.
+	defer func() { _ = os.Remove(scriptPath) }()
+	if _, err := scriptFile.WriteString(script); err != nil {
+		_ = scriptFile.Close()
+		return fmt.Errorf("write script file: %w", err)
+	}
+	if err := scriptFile.Close(); err != nil {
+		return fmt.Errorf("close script file: %w", err)
+	}
+
+	cmd := exec.CommandContext(ctx, "bash", "--noprofile", "--norc", scriptPath)
+	// Run bash in its own process group so a timeout can take down the whole
+	// tree (see the watcher goroutine below).
+	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
+	// Do not pass envSnapshot via cmd.Env to avoid "argument list too long" when session env is large.
+	// Child inherits parent env (nil => default in Go). The script file already has "export K=V" for
+	// all session vars at the top, so the session environment is applied when the script runs.
+	stdout, err := cmd.StdoutPipe()
+	if err != nil {
+		return fmt.Errorf("stdout pipe: %w", err)
+	}
+	cmd.Stderr = cmd.Stdout
+
+	if err := cmd.Start(); err != nil {
+		log.Error("start bash session failed: %v (command: %q)", err, request.Code)
+		return fmt.Errorf("start bash: %w", err)
+	}
+
+	// exec.CommandContext only kills bash itself. Children that inherited the
+	// output pipe (e.g. a running `sleep`) would keep the scanner below blocked
+	// long after the deadline, so on cancellation the process group is killed.
+	pgid := cmd.Process.Pid
+	stopWatch := make(chan struct{})
+	watchDone := make(chan struct{})
+	go func() {
+		defer close(watchDone)
+		select {
+		case <-ctx.Done():
+			// A negative pid addresses the process group created via Setpgid.
+			_ = syscall.Kill(-pgid, syscall.SIGKILL)
+		case <-stopWatch:
+		}
+	}()
+
+	scanner := bufio.NewScanner(stdout)
+	scanner.Buffer(make([]byte, 0, 64*1024), 16*1024*1024)
+
+	var (
+		envLines []string
+		pwdLine  string
+		exitCode *int
+		inEnv    bool
+	)
+
+	for scanner.Scan() {
+		line := scanner.Text()
+		switch {
+		case line == envDumpStartMarker:
+			inEnv = true
+		case line == envDumpEndMarker:
+			inEnv = false
+		case strings.HasPrefix(line, exitMarkerPrefix):
+			if code, err := strconv.Atoi(strings.TrimPrefix(line, exitMarkerPrefix)); err == nil {
+				exitCode = &code //nolint:ineffassign
+			}
+		case strings.HasPrefix(line, pwdMarkerPrefix):
+			pwdLine = strings.TrimPrefix(line, pwdMarkerPrefix)
+		default:
+			if inEnv {
+				envLines = append(envLines, line)
+				continue
+			}
+			if request.Hooks.OnExecuteStdout != nil {
+				request.Hooks.OnExecuteStdout(line)
+			}
+		}
+	}
+
+	scanErr := scanner.Err()
+	if scanErr != nil {
+		// Nobody reads the pipe any more; stop the script so Wait cannot hang
+		// on a process that is blocked writing to it.
+		cancel()
+	}
+	waitErr := cmd.Wait()
+	close(stopWatch)
+	<-watchDone
+
+	if scanErr != nil {
+		log.Error("read stdout failed: %v (command: %q)", scanErr, request.Code)
+		return fmt.Errorf("read stdout: %w", scanErr)
+	}
+
+	if errors.Is(ctx.Err(), context.DeadlineExceeded) {
+		log.Error("timeout after %s while running command: %q", wait, request.Code)
+		return fmt.Errorf("timeout after %s while running command %q", wait, request.Code)
+	}
+
+	// No exit marker means the script ended before the epilogue (e.g. `exit`
+	// or `exec` in user code); fall back to the process exit status.
+	if exitCode == nil && cmd.ProcessState != nil {
+		code := cmd.ProcessState.ExitCode() //nolint:staticcheck
+		exitCode = &code                    //nolint:ineffassign
+	}
+
+	updatedEnv := parseExportDump(envLines)
+	s.mu.Lock()
+	if len(updatedEnv) > 0 {
+		s.env = updatedEnv
+	}
+	if pwdLine != "" {
+		s.cwd = pwdLine
+	}
+	s.mu.Unlock()
+
+	var exitErr *exec.ExitError
+	if waitErr != nil && !errors.As(waitErr, &exitErr) {
+		log.Error("command wait failed: %v (command: %q)", waitErr, request.Code)
+		return waitErr
+	}
+
+	userExitCode := 0
+	if exitCode != nil {
+		userExitCode = *exitCode
+	}
+
+	if userExitCode != 0 {
+		errMsg := fmt.Sprintf("command exited with code %d", userExitCode)
+		if waitErr != nil {
+			errMsg = waitErr.Error()
+		}
+		if request.Hooks.OnExecuteError != nil {
+			request.Hooks.OnExecuteError(&execute.ErrorOutput{
+				EName:     "CommandExecError",
+				EValue:    strconv.Itoa(userExitCode),
+				Traceback: []string{errMsg},
+			})
+		}
+		log.Error("CommandExecError: %s (command: %q)", errMsg, request.Code)
+		return nil
+	}
+
+	if request.Hooks.OnExecuteComplete != nil {
+		request.Hooks.OnExecuteComplete(time.Since(startAt))
+	}
+
+	return nil
+}
+
+// buildWrappedScript renders the script executed for one run: exports for the
+// persisted session variables (sorted, invalid/oversized/prompt keys dropped),
+// an optional cd into cwd, the user command, and an epilogue that prints the
+// environment dump, the final working directory and the user exit code between
+// the marker lines that run parses.
+func buildWrappedScript(command string, env map[string]string, cwd string) string {
+	var b strings.Builder
+
+	keys := make([]string, 0, len(env))
+	for k := range env {
+		v := env[k]
+		if isValidEnvKey(k) && !envKeysNotPersisted[k] && len(v) <= maxPersistedEnvValueSize {
+			keys = append(keys, k)
+		}
+	}
+	sort.Strings(keys)
+	for _, k := range keys {
+		b.WriteString("export ")
+		b.WriteString(k)
+		b.WriteString("=")
+		b.WriteString(shellEscape(env[k]))
+		b.WriteString("\n")
+	}
+
+	if cwd != "" {
+		b.WriteString("cd ")
+		b.WriteString(shellEscape(cwd))
+		b.WriteString("\n")
+	}
+
+	b.WriteString(command)
+	if !strings.HasSuffix(command, "\n") {
+		b.WriteString("\n")
+	}
+
+	// Capture the status first, then switch xtrace off with stderr muted so a
+	// user-issued `set -x` does not echo the epilogue (and its markers) into
+	// the user's output.
+	b.WriteString("{ __USER_EXIT_CODE__=$?; set +x; } 2>/dev/null\n")
+	// The leading newline terminates a last user line that lacks one, so the
+	// start marker always sits on a line of its own.
+	b.WriteString("printf \"\\n%s\\n\" \"" + envDumpStartMarker + "\"\n")
+	b.WriteString("export -p\n")
+	b.WriteString("printf \"%s\\n\" \"" + envDumpEndMarker + "\"\n")
+	b.WriteString("printf \"" + pwdMarkerPrefix + "%s\\n\" \"$(pwd)\"\n")
+	b.WriteString("printf \"" + exitMarkerPrefix + "%s\\n\" \"$__USER_EXIT_CODE__\"\n")
+	b.WriteString("exit \"$__USER_EXIT_CODE__\"\n")
+
+	return b.String()
+}
+
+// envKeysNotPersisted are not carried across runs (prompt/display vars).
+var envKeysNotPersisted = map[string]bool{
+	"PS1": true, "PS2": true, "PS3": true, "PS4": true,
+	"PROMPT_COMMAND": true,
+}
+
+// maxPersistedEnvValueSize caps single env value length as a safeguard.
+const maxPersistedEnvValueSize = 8 * 1024
+
+// parseExportDump turns the lines printed by `export -p` into an environment
+// map. Lines that are not `declare -x` statements, non-persisted keys and
+// oversized values are skipped. It returns nil for an empty dump.
+func parseExportDump(lines []string) map[string]string {
+	if len(lines) == 0 {
+		return nil
+	}
+	env := make(map[string]string, len(lines))
+	for _, line := range lines {
+		k, v, ok := parseExportLine(line)
+		if !ok || envKeysNotPersisted[k] || len(v) > maxPersistedEnvValueSize {
+			continue
+		}
+		env[k] = v
+	}
+	return env
+}
+
+// parseExportLine parses one `declare -x NAME="value"` line. A declaration
+// without `=` yields an empty value. ok is false when the line has a different
+// shape or NAME is not a valid variable name.
+func parseExportLine(line string) (string, string, bool) {
+	const prefix = "declare -x "
+	if !strings.HasPrefix(line, prefix) {
+		return "", "", false
+	}
+	rest := strings.TrimSpace(strings.TrimPrefix(line, prefix))
+	if rest == "" {
+		return "", "", false
+	}
+	name, value := rest, ""
+	if eq := strings.Index(rest, "="); eq >= 0 {
+		name = rest[:eq]
+		value = unquoteExportValue(rest[eq+1:])
+	}
+	if !isValidEnvKey(name) {
+		return "", "", false
+	}
+	return name, value, true
+}
+
+// unquoteExportValue decodes the value part of a `declare -x` line.
+//
+// Bash double-quotes the value and backslash-escapes only `"`, `\`, `$` and
+// the backtick. Go's strconv.Unquote rejects `\$` and `\``, which used to leave
+// the backslashes in the stored value and made them multiply on every run, so
+// the bash rules are applied directly. Input that is not fully double-quoted
+// (e.g. the first line of a multi-line value) only has its quote characters
+// trimmed.
+func unquoteExportValue(raw string) string {
+	if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' {
+		return strings.Trim(raw, `"`)
+	}
+
+	inner := raw[1 : len(raw)-1]
+	if !strings.Contains(inner, `\`) {
+		return inner
+	}
+
+	var b strings.Builder
+	b.Grow(len(inner))
+	for i := 0; i < len(inner); i++ {
+		ch := inner[i]
+		if ch == '\\' && i+1 < len(inner) {
+			switch inner[i+1] {
+			case '"', '\\', '$', '`':
+				i++
+				ch = inner[i]
+			}
+		}
+		b.WriteByte(ch)
+	}
+	return b.String()
+}
+
+// shellEscape wraps value in single quotes for use as one shell word; embedded
+// single quotes are emitted as '"'"'.
+func shellEscape(value string) string {
+	return "'" + strings.ReplaceAll(value, "'", `'"'"'`) + "'"
+}
+
+// isValidEnvKey reports whether key matches [A-Za-z_][A-Za-z0-9_]*.
+func isValidEnvKey(key string) bool {
+	if key == "" {
+		return false
+	}
+
+	for i, r := range key {
+		switch {
+		case r == '_', r >= 'A' && r <= 'Z', r >= 'a' && r <= 'z':
+			// always allowed
+		case i > 0 && r >= '0' && r <= '9':
+			// digits are allowed after the first character
+		default:
+			return false
+		}
+	}
+
+	return true
+}
+
+// copyEnvMap returns an independent copy of src; a nil src yields an empty map.
+func copyEnvMap(src map[string]string) map[string]string {
+	if src == nil {
+		return map[string]string{}
+	}
+
+	dst := make(map[string]string, len(src))
+	for k, v := range src {
+		dst[k] = v
+	}
+	return dst
+}
+
+// splitEnvPair splits a "KEY=VALUE" entry at the first '='. ok is false when
+// there is no '=' or KEY is not a valid variable name.
+func splitEnvPair(kv string) (string, string, bool) {
+	key, value, found := strings.Cut(kv, "=")
+	if !found || !isValidEnvKey(key) {
+		return "", "", false
+	}
+	return key, value, true
+}
+
+// close marks the session as stopped and drops its persisted state. Closing a
+// session that is not started is a no-op.
+func (s *bashSession) close() error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if !s.started {
+		return nil
+	}
+	s.started = false
+	s.env = nil
+	s.cwd = ""
+	return nil
+}
+
+// uuidString returns a new random UUID in its canonical string form.
+func uuidString() string {
+	return uuid.New().String()
+}
diff --git a/components/execd/pkg/runtime/bash_session_test.go b/components/execd/pkg/runtime/bash_session_test.go
new file mode 100644
index 00000000..2a6e7482
--- /dev/null
+++ b/components/execd/pkg/runtime/bash_session_test.go
@@ -0,0 +1,613 @@
+// Copyright 2026 Alibaba Group Holding Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build !windows
+// +build !windows
+
+package runtime
+
+import (
+	"context"
+	"fmt"
+	"os/exec"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/alibaba/opensandbox/execd/pkg/jupyter/execute"
+)
+
+// TestBashSession_NonZeroExitEmitsError checks that a script exiting with a
+// non-zero status triggers OnExecuteError (carrying the exit code) and does
+// not trigger OnExecuteComplete.
+func TestBashSession_NonZeroExitEmitsError(t *testing.T) {
+	requireBash(t)
+
+	c := NewController("", "")
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	var (
+		sessionID  string
+		stdoutLine string
+		errCh      = make(chan *execute.ErrorOutput, 1)
+		completeCh = make(chan struct{}, 1)
+	)
+
+	req := &ExecuteCodeRequest{
+		Language: Bash,
+		Code:     `echo "before"; exit 7`,
+		Cwd:      t.TempDir(),
+		Timeout:  5 * time.Second,
+		Hooks: ExecuteResultHook{
+			OnExecuteInit:   func(s string) { sessionID = s },
+			OnExecuteStdout: func(s string) { stdoutLine = s },
+			OnExecuteError:  func(err *execute.ErrorOutput) { errCh <- err },
+			OnExecuteComplete: func(_ time.Duration) {
+				completeCh <- struct{}{}
+			},
+		},
+	}
+
+	if err := c.runBashSession(ctx, req); err != nil {
+		t.Fatalf("runBashSession returned error: %v", err)
+	}
+
+	var gotErr *execute.ErrorOutput
+	select {
+	case gotErr = <-errCh:
+	case <-time.After(2 * time.Second):
+		t.Fatalf("expected error hook to be called")
+	}
+
+	if gotErr == nil {
+		t.Fatalf("expected non-nil error output")
+	}
+	if gotErr.EName != "CommandExecError" || gotErr.EValue != "7" {
+		t.Fatalf("unexpected error payload: %+v", gotErr)
+	}
+
+	if sessionID == "" {
+		t.Fatalf("expected session id to be set")
+	}
+	if stdoutLine != "before" {
+		t.Fatalf("unexpected stdout: %q", stdoutLine)
+	}
+
+	select {
+	case <-completeCh:
+		t.Fatalf("did not expect completion hook on non-zero exit")
+	default:
+	}
+}
+
+// TestBashSession_envAndExitCode checks that exported variables survive across
+// runs and that `$?` works inside a single run.
+func TestBashSession_envAndExitCode(t *testing.T) {
+	session := newStartedBashSession(t)
+
+	var (
+		initCalls     int
+		completeCalls int
+		stdoutLines   []string
+	)
+
+	hooks := ExecuteResultHook{
+		OnExecuteInit: func(ctx string) {
+			if ctx != session.config.Session {
+				t.Fatalf("unexpected session in OnExecuteInit: %s", ctx)
+			}
+			initCalls++
+		},
+		OnExecuteStdout: func(text string) {
+			t.Log(text)
+			stdoutLines = append(stdoutLines, text)
+		},
+		OnExecuteComplete: func(_ time.Duration) {
+			completeCalls++
+		},
+	}
+
+	// 1) export an env var
+	request := &ExecuteCodeRequest{
+		Code:    "export FOO=hello",
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	}
+	if err := session.run(request); err != nil {
+		t.Fatalf("runCommand(export) error = %v", err)
+	}
+	exportStdoutCount := len(stdoutLines)
+
+	// 2) verify env is persisted
+	request = &ExecuteCodeRequest{
+		Code:    "echo $FOO",
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	}
+	if err := session.run(request); err != nil {
+		t.Fatalf("runCommand(echo) error = %v", err)
+	}
+	echoLines := stdoutLines[exportStdoutCount:]
+	if !containsLine(echoLines, "hello") {
+		t.Fatalf("expected echo $FOO to output 'hello', got %v", echoLines)
+	}
+
+	// 3) ensure exit code of previous command is reflected in shell state
+	request = &ExecuteCodeRequest{
+		Code:    "false; echo EXIT:$?",
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	}
+	prevCount := len(stdoutLines)
+	if err := session.run(request); err != nil {
+		t.Fatalf("runCommand(exitcode) error = %v", err)
+	}
+	exitLines := stdoutLines[prevCount:]
+	foundExit := false
+	for _, line := range exitLines {
+		if strings.Contains(line, "EXIT:1") {
+			foundExit = true
+			break
+		}
+	}
+	if !foundExit {
+		t.Fatalf("expected exit code output 'EXIT:1', got %v", exitLines)
+	}
+
+	if initCalls != 3 {
+		t.Fatalf("OnExecuteInit expected 3 calls, got %d", initCalls)
+	}
+	if completeCalls != 3 {
+		t.Fatalf("OnExecuteComplete expected 3 calls, got %d", completeCalls)
+	}
+}
+
+// TestBashSession_envLargeOutputChained checks that environment changes chain
+// across three runs that each produce a larger amount of output.
+func TestBashSession_envLargeOutputChained(t *testing.T) {
+	session := newStartedBashSession(t)
+
+	var (
+		initCalls     int
+		completeCalls int
+		stdoutLines   []string
+	)
+
+	hooks := ExecuteResultHook{
+		OnExecuteInit: func(ctx string) {
+			if ctx != session.config.Session {
+				t.Fatalf("unexpected session in OnExecuteInit: %s", ctx)
+			}
+			initCalls++
+		},
+		OnExecuteStdout: func(text string) {
+			t.Log(text)
+			stdoutLines = append(stdoutLines, text)
+		},
+		OnExecuteComplete: func(_ time.Duration) {
+			completeCalls++
+		},
+	}
+
+	runAndCollect := func(cmd string) []string {
+		start := len(stdoutLines)
+		request := &ExecuteCodeRequest{
+			Code:    cmd,
+			Hooks:   hooks,
+			Timeout: 10 * time.Second,
+		}
+		if err := session.run(request); err != nil {
+			t.Fatalf("runCommand(%q) error = %v", cmd, err)
+		}
+		return append([]string(nil), stdoutLines[start:]...)
+	}
+
+	lines1 := runAndCollect("export FOO=hello1; for i in $(seq 1 60); do echo A${i}:$FOO; done")
+	if len(lines1) < 60 {
+		t.Fatalf("expected >=60 lines for cmd1, got %d", len(lines1))
+	}
+	if !containsLine(lines1, "A1:hello1") || !containsLine(lines1, "A60:hello1") {
+		t.Fatalf("env not reflected in cmd1 output, got %v", lines1[:3])
+	}
+
+	lines2 := runAndCollect("export FOO=${FOO}_next; export BAR=bar1; for i in $(seq 1 60); do echo B${i}:$FOO:$BAR; done")
+	if len(lines2) < 60 {
+		t.Fatalf("expected >=60 lines for cmd2, got %d", len(lines2))
+	}
+	if !containsLine(lines2, "B1:hello1_next:bar1") || !containsLine(lines2, "B60:hello1_next:bar1") {
+		t.Fatalf("env not propagated to cmd2 output, sample %v", lines2[:3])
+	}
+
+	lines3 := runAndCollect("export BAR=${BAR}_last; for i in $(seq 1 60); do echo C${i}:$FOO:$BAR; done; echo FINAL_FOO=$FOO; echo FINAL_BAR=$BAR")
+	if len(lines3) < 62 { // 60 lines + 2 finals
+		t.Fatalf("expected >=62 lines for cmd3, got %d", len(lines3))
+	}
+	if !containsLine(lines3, "C1:hello1_next:bar1_last") || !containsLine(lines3, "C60:hello1_next:bar1_last") {
+		t.Fatalf("env not propagated to cmd3 output, sample %v", lines3[:3])
+	}
+	if !containsLine(lines3, "FINAL_FOO=hello1_next") || !containsLine(lines3, "FINAL_BAR=bar1_last") {
+		t.Fatalf("final env lines missing, got %v", lines3[len(lines3)-5:])
+	}
+
+	if initCalls != 3 {
+		t.Fatalf("OnExecuteInit expected 3 calls, got %d", initCalls)
+	}
+	if completeCalls != 3 {
+		t.Fatalf("OnExecuteComplete expected 3 calls, got %d", completeCalls)
+	}
+}
+
+// TestBashSession_cwdPersistsWithoutOverride checks that a `cd` in one run is
+// still in effect in the next run when the request carries no Cwd.
+func TestBashSession_cwdPersistsWithoutOverride(t *testing.T) {
+	session := newStartedBashSession(t)
+
+	targetDir := t.TempDir()
+	var stdoutLines []string
+	hooks := ExecuteResultHook{
+		OnExecuteStdout: func(line string) {
+			stdoutLines = append(stdoutLines, line)
+		},
+	}
+
+	runAndCollect := func(req *ExecuteCodeRequest) []string {
+		start := len(stdoutLines)
+		if err := session.run(req); err != nil {
+			t.Fatalf("runCommand(%q) error = %v", req.Code, err)
+		}
+		return append([]string(nil), stdoutLines[start:]...)
+	}
+
+	firstRunLines := runAndCollect(&ExecuteCodeRequest{
+		Code:    fmt.Sprintf("cd %s\npwd", targetDir),
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	})
+	if !containsLine(firstRunLines, targetDir) {
+		t.Fatalf("expected cd to update cwd to %q, got %v", targetDir, firstRunLines)
+	}
+
+	secondRunLines := runAndCollect(&ExecuteCodeRequest{
+		Code:    "pwd",
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	})
+	if !containsLine(secondRunLines, targetDir) {
+		t.Fatalf("expected subsequent run to inherit cwd %q, got %v", targetDir, secondRunLines)
+	}
+
+	session.mu.Lock()
+	finalCwd := session.cwd
+	session.mu.Unlock()
+	if finalCwd != targetDir {
+		t.Fatalf("expected session cwd to stay at %q, got %q", targetDir, finalCwd)
+	}
+}
+
+// TestBashSession_requestCwdOverridesAfterCd checks that an explicit request
+// Cwd wins over the cwd persisted by an earlier `cd`, and becomes the new
+// session cwd.
+func TestBashSession_requestCwdOverridesAfterCd(t *testing.T) {
+	session := newStartedBashSession(t)
+
+	initialDir := t.TempDir()
+	overrideDir := t.TempDir()
+
+	var stdoutLines []string
+	hooks := ExecuteResultHook{
+		OnExecuteStdout: func(line string) {
+			stdoutLines = append(stdoutLines, line)
+		},
+	}
+
+	runAndCollect := func(req *ExecuteCodeRequest) []string {
+		start := len(stdoutLines)
+		if err := session.run(req); err != nil {
+			t.Fatalf("runCommand(%q) error = %v", req.Code, err)
+		}
+		return append([]string(nil), stdoutLines[start:]...)
+	}
+
+	// First request: change session cwd via script.
+	firstRunLines := runAndCollect(&ExecuteCodeRequest{
+		Code:    fmt.Sprintf("cd %s\npwd", initialDir),
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	})
+	if !containsLine(firstRunLines, initialDir) {
+		t.Fatalf("expected cd to update cwd to %q, got %v", initialDir, firstRunLines)
+	}
+
+	// Second request: explicit Cwd overrides session cwd.
+	secondRunLines := runAndCollect(&ExecuteCodeRequest{
+		Code:    "pwd",
+		Cwd:     overrideDir,
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	})
+	if !containsLine(secondRunLines, overrideDir) {
+		t.Fatalf("expected command to run in override cwd %q, got %v", overrideDir, secondRunLines)
+	}
+
+	session.mu.Lock()
+	finalCwd := session.cwd
+	session.mu.Unlock()
+	if finalCwd != overrideDir {
+		t.Fatalf("expected session cwd updated to override dir %q, got %q", overrideDir, finalCwd)
+	}
+}
+
+// TestBashSession_envDumpNotLeakedWhenNoTrailingNewline checks that output
+// without a final newline is delivered intact and that no part of the
+// environment dump reaches the stdout hook.
+func TestBashSession_envDumpNotLeakedWhenNoTrailingNewline(t *testing.T) {
+	session := newStartedBashSession(t)
+
+	var stdoutLines []string
+	hooks := ExecuteResultHook{
+		OnExecuteStdout: func(line string) {
+			stdoutLines = append(stdoutLines, line)
+		},
+	}
+
+	request := &ExecuteCodeRequest{
+		Code:    `set +x; printf '{"foo":1}'`,
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	}
+
+	if err := session.run(request); err != nil {
+		t.Fatalf("runCommand(no-trailing-newline) error = %v", err)
+	}
+
+	if len(stdoutLines) != 1 {
+		t.Fatalf("expected exactly one stdout line, got %v", stdoutLines)
+	}
+	if strings.TrimSpace(stdoutLines[0]) != `{"foo":1}` {
+		t.Fatalf("unexpected stdout content %q", stdoutLines[0])
+	}
+	for _, line := range stdoutLines {
+		if strings.Contains(line, envDumpStartMarker) || strings.Contains(line, "declare -x") {
+			t.Fatalf("env dump leaked into stdout: %v", stdoutLines)
+		}
+	}
+}
+
+// TestBashSession_envDumpNotLeakedWhenNoOutput checks that a silent command
+// yields at most one empty stdout line and no environment dump.
+func TestBashSession_envDumpNotLeakedWhenNoOutput(t *testing.T) {
+	session := newStartedBashSession(t)
+
+	var stdoutLines []string
+	hooks := ExecuteResultHook{
+		OnExecuteStdout: func(line string) {
+			stdoutLines = append(stdoutLines, line)
+		},
+	}
+
+	request := &ExecuteCodeRequest{
+		Code:    `set +x; true`,
+		Hooks:   hooks,
+		Timeout: 3 * time.Second,
+	}
+
+	if err := session.run(request); err != nil {
+		t.Fatalf("runCommand(no-output) error = %v", err)
+	}
+
+	if len(stdoutLines) > 1 {
+		t.Fatalf("expected at most one stdout line, got %v", stdoutLines)
+	}
+	if len(stdoutLines) == 1 && strings.TrimSpace(stdoutLines[0]) != "" {
+		t.Fatalf("expected empty stdout, got %q", stdoutLines[0])
+	}
+	for _, line := range stdoutLines {
+		if strings.Contains(line, envDumpStartMarker) || strings.Contains(line, "declare -x") {
+			t.Fatalf("env dump leaked into stdout: %v", stdoutLines)
+		}
+	}
+}
+
+// TestBashSession_heredoc runs a script containing a heredoc through
+// Controller.Execute and then checks the default session still accepts work.
+func TestBashSession_heredoc(t *testing.T) {
+	requireBash(t)
+
+	// All files live in the per-test directory so runs cannot collide on
+	// fixed /tmp paths and nothing is left behind.
+	rewardDir := t.TempDir()
+	controller := NewController("", "")
+
+	hooks := ExecuteResultHook{
+		OnExecuteStdout: func(line string) {
+			fmt.Printf("[stdout] %s\n", line)
+		},
+		OnExecuteComplete: func(d time.Duration) {
+			fmt.Printf("[complete] %s\n", d)
+		},
+	}
+
+	// First run: heredoc + reward file write.
+	script := fmt.Sprintf(`
+set -x
+reward_dir=%q
+mkdir -p "$reward_dir"
+
+cat > "$reward_dir/repro_script.sh" <<'SHEOF'
+#!/usr/bin/env sh
+echo "hello heredoc"
+SHEOF
+
+chmod +x "$reward_dir/repro_script.sh"
+"$reward_dir/repro_script.sh"
+echo "after heredoc"
+echo 1 > "$reward_dir/reward.txt"
+cat "$reward_dir/reward.txt"
+`, rewardDir)
+
+	// t.Fatalf instead of os.Exit: exiting the process would abort the whole
+	// test binary and skip cleanup.
+	if err := controller.Execute(&ExecuteCodeRequest{
+		Language: Bash,
+		Timeout:  10 * time.Second,
+		Code:     script,
+		Hooks:    hooks,
+	}); err != nil {
+		t.Fatalf("first Execute failed: %v", err)
+	}
+
+	// Second run: ensure the session keeps working.
+	if err := controller.Execute(&ExecuteCodeRequest{
+		Language: Bash,
+		Timeout:  5 * time.Second,
+		Code:     "echo 'second command works'",
+		Hooks:    hooks,
+	}); err != nil {
+		t.Fatalf("second Execute failed: %v", err)
+	}
+}
+
+// TestBashSession_execReplacesShell checks that user code which replaces the
+// shell via `exec` still completes and leaves the session usable.
+func TestBashSession_execReplacesShell(t *testing.T) {
+	session := newStartedBashSession(t)
+
+	var stdoutLines []string
+	hooks := ExecuteResultHook{
+		OnExecuteStdout: func(line string) {
+			stdoutLines = append(stdoutLines, line)
+		},
+	}
+
+	script := fmt.Sprintf(`
+child=%q
+cat > "$child" <<'EOF'
+echo "child says hi"
+EOF
+chmod +x "$child"
+exec "$child"
+`, t.TempDir()+"/exec_child.sh")
+
+	request := &ExecuteCodeRequest{
+		Code:    script,
+		Hooks:   hooks,
+		Timeout: 5 * time.Second,
+	}
+	err := session.run(request)
+	if err != nil {
+		t.Fatalf("expected exec to complete without killing the session, got %v", err)
+	}
+	if !containsLine(stdoutLines, "child says hi") {
+		t.Fatalf("expected child output, got %v", stdoutLines)
+	}
+
+	// Subsequent run should still work because we restart bash per run.
+	request = &ExecuteCodeRequest{
+		Code:    "echo still-alive",
+		Hooks:   hooks,
+		Timeout: 2 * time.Second,
+	}
+	stdoutLines = nil
+	if err := session.run(request); err != nil {
+		t.Fatalf("expected run to succeed after exec replaced the shell, got %v", err)
+	}
+	if !containsLine(stdoutLines, "still-alive") {
+		t.Fatalf("expected follow-up output, got %v", stdoutLines)
+	}
+}
+
+// TestBashSession_complexExec checks a script that redirects its output via
+// `exec` and process substitution, then restores the descriptors.
+func TestBashSession_complexExec(t *testing.T) {
+	session := newStartedBashSession(t)
+
+	var stdoutLines []string
+	hooks := ExecuteResultHook{
+		OnExecuteStdout: func(line string) {
+			stdoutLines = append(stdoutLines, line)
+		},
+	}
+
+	// The log file lives in the per-test directory instead of a leaked mktemp file.
+	script := fmt.Sprintf(`
+LOG_FILE=%q
+export LOG_FILE
+exec 3>&1 4>&2
+exec > >(tee "$LOG_FILE") 2>&1
+
+set -x
+echo "from-complex-exec"
+exec 1>&3 2>&4 # step record
+echo "after-restore"
+`, t.TempDir()+"/complex_exec.log")
+
+	request := &ExecuteCodeRequest{
+		Code:    script,
+		Hooks:   hooks,
+		Timeout: 5 * time.Second,
+	}
+	err := session.run(request)
+	if err != nil {
+		t.Fatalf("expected complex exec to finish, got %v", err)
+	}
+	if !containsLine(stdoutLines, "from-complex-exec") || !containsLine(stdoutLines, "after-restore") {
+		t.Fatalf("expected exec outputs, got %v", stdoutLines)
+	}
+
+	// Session should still be usable.
+	request = &ExecuteCodeRequest{
+		Code:    "echo still-alive",
+		Hooks:   hooks,
+		Timeout: 2 * time.Second,
+	}
+	stdoutLines = nil
+	if err := session.run(request); err != nil {
+		t.Fatalf("expected run to succeed after complex exec, got %v", err)
+	}
+	if !containsLine(stdoutLines, "still-alive") {
+		t.Fatalf("expected follow-up output, got %v", stdoutLines)
+	}
+}
+
+// requireBash skips the calling test when no bash binary is on PATH.
+func requireBash(t *testing.T) {
+	t.Helper()
+	if _, err := exec.LookPath("bash"); err != nil {
+		t.Skip("bash not found in PATH")
+	}
+}
+
+// newStartedBashSession skips the test when bash is unavailable; otherwise it
+// returns a started session that is closed automatically at test cleanup.
+func newStartedBashSession(t *testing.T) *bashSession {
+	t.Helper()
+	requireBash(t)
+
+	session := newBashSession(nil)
+	t.Cleanup(func() { _ = session.close() })
+
+	if err := session.start(); err != nil {
+		t.Fatalf("Start() error = %v", err)
+	}
+	return session
+}
+
+// containsLine reports whether any line equals target after trimming
+// surrounding whitespace.
+func containsLine(lines []string, target string) bool {
+	for _, l := range lines {
+		if strings.TrimSpace(l) == target {
+			return true
+		}
+	}
+	return false
+}
diff --git a/components/execd/pkg/runtime/bash_session_windows.go b/components/execd/pkg/runtime/bash_session_windows.go
new file mode 100644
index 00000000..9b1bac42
--- /dev/null
+++ b/components/execd/pkg/runtime/bash_session_windows.go
@@ -0,0 +1,67 @@
+// Copyright 2026 Alibaba Group Holding Ltd.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+//go:build windows
+// +build windows
+
+package runtime
+
+import (
+	"context"
+	"errors"
+)
+
+// errBashSessionNotSupported is returned by every bash-session entry point on
+// Windows, where the pipe-based bash session is not available.
+var errBashSessionNotSupported = errors.New("bash session is not supported on windows")
+
+// The stubs below mirror the signatures of the !windows implementation so that
+// platform-independent callers compile unchanged on Windows.
+
+// createBashSession always fails on Windows.
+func (c *Controller) createBashSession(_ *CreateContextRequest) (string, error) {
+	return "", errBashSessionNotSupported
+}
+
+// runBashSession always fails on Windows.
+func (c *Controller) runBashSession(_ context.Context, _ *ExecuteCodeRequest) error { //nolint:revive
+	return errBashSessionNotSupported
+}
+
+// createDefaultBashSession always fails on Windows.
+func (c *Controller) createDefaultBashSession() error { //nolint:revive
+	return errBashSessionNotSupported
+}
+
+// getBashSession never finds a session on Windows. It returns a bare pointer
+// (nil) to match the !windows signature.
+func (c *Controller) getBashSession(_ string) *bashSession { //nolint:revive
+	return nil
+}
+
+// closeBashSession always fails on Windows.
+func (c *Controller) closeBashSession(_ string) error { //nolint:revive
+	return errBashSessionNotSupported
+}
+
+// listBashSessions reports no sessions on Windows.
+func (c *Controller) listBashSessions() []string { //nolint:revive
+	return nil
+}
+
+// newBashSession returns an inert session wrapper; start and run on it fail.
+func newBashSession(config *bashSessionConfig) *bashSession {
+	return &bashSession{config: config}
+}
+
+// start always fails on Windows.
+func (s *bashSession) start() error {
+	return errBashSessionNotSupported
+}
+
+// run always fails on Windows. It takes the same request type as the !windows
+// implementation.
+func (s *bashSession) run(_ *ExecuteCodeRequest) error {
+	return errBashSessionNotSupported
+}
+
+// close is a no-op on Windows.
+func (s *bashSession) close() error {
+	return nil
+}
diff --git a/components/execd/pkg/runtime/command_common.go b/components/execd/pkg/runtime/command_common.go
index 633efa35..4f49ebbc 100644
--- a/components/execd/pkg/runtime/command_common.go
+++ b/components/execd/pkg/runtime/command_common.go
@@ -45,18 +45,17 @@ func (c *Controller) tailStdPipe(file string, onExecute func(text string), done
 
 // getCommandKernel retrieves a command execution context.
func (c *Controller) getCommandKernel(sessionID string) *commandKernel { - c.mu.RLock() - defer c.mu.RUnlock() - - return c.commandClientMap[sessionID] + if v, ok := c.commandClientMap.Load(sessionID); ok { + if kernel, ok := v.(*commandKernel); ok { + return kernel + } + } + return nil } // storeCommandKernel registers a command execution context. func (c *Controller) storeCommandKernel(sessionID string, kernel *commandKernel) { - c.mu.Lock() - defer c.mu.Unlock() - - c.commandClientMap[sessionID] = kernel + c.commandClientMap.Store(sessionID, kernel) } // stdLogDescriptor creates temporary files for capturing command output. diff --git a/components/execd/pkg/runtime/command_status.go b/components/execd/pkg/runtime/command_status.go index 97f112b1..6dbc6d4f 100644 --- a/components/execd/pkg/runtime/command_status.go +++ b/components/execd/pkg/runtime/command_status.go @@ -40,11 +40,11 @@ type CommandOutput struct { } func (c *Controller) commandSnapshot(session string) *commandKernel { - c.mu.RLock() - defer c.mu.RUnlock() - - kernel, ok := c.commandClientMap[session] - if !ok || kernel == nil { + var kernel *commandKernel + if v, ok := c.commandClientMap.Load(session); ok { + kernel, _ = v.(*commandKernel) + } + if kernel == nil { return nil } @@ -116,8 +116,11 @@ func (c *Controller) markCommandFinished(session string, exitCode int, errMsg st c.mu.Lock() defer c.mu.Unlock() - kernel, ok := c.commandClientMap[session] - if !ok || kernel == nil { + var kernel *commandKernel + if v, ok := c.commandClientMap.Load(session); ok { + kernel, _ = v.(*commandKernel) + } + if kernel == nil { return } diff --git a/components/execd/pkg/runtime/context.go b/components/execd/pkg/runtime/context.go index a1135507..6e7ea870 100644 --- a/components/execd/pkg/runtime/context.go +++ b/components/execd/pkg/runtime/context.go @@ -32,6 +32,11 @@ import ( // CreateContext provisions a kernel-backed session and returns its ID. 
func (c *Controller) CreateContext(req *CreateContextRequest) (string, error) { + if req.Language == Bash { + return c.createBashSession(req) + } + + // Create a new Jupyter session. var ( client *jupyter.Client session *jupytersession.Session @@ -42,7 +47,7 @@ func (c *Controller) CreateContext(req *CreateContextRequest) (string, error) { log.Error("failed to create session, retrying: %v", err) return err != nil }, func() error { - client, session, err = c.createContext(*req) + client, session, err = c.createJupyterContext(*req) return err }) if err != nil { @@ -116,15 +121,8 @@ func (c *Controller) deleteSessionAndCleanup(session string) error { return err } - c.mu.Lock() - defer c.mu.Unlock() - - delete(c.jupyterClientMap, session) - for lang, id := range c.defaultLanguageJupyterSessions { - if id == session { - delete(c.defaultLanguageJupyterSessions, lang) - } - } + c.jupyterClientMap.Delete(session) + c.deleteDefaultSessionByID(session) return nil } @@ -143,8 +141,12 @@ func (c *Controller) newIpynbPath(sessionID, cwd string) (string, error) { return filepath.Join(cwd, fmt.Sprintf("%s.ipynb", sessionID)), nil } -// createDefaultLanguageContext prewarms a session for stateless execution. -func (c *Controller) createDefaultLanguageContext(language Language) error { +// createDefaultLanguageJupyterContext prewarms a session for stateless execution. 
+func (c *Controller) createDefaultLanguageJupyterContext(language Language) error { + if c.getDefaultLanguageSession(language) != "" { + return nil + } + var ( client *jupyter.Client session *jupytersession.Session @@ -154,7 +156,7 @@ func (c *Controller) createDefaultLanguageContext(language Language) error { log.Error("failed to create context, retrying: %v", err) return err != nil }, func() error { - client, session, err = c.createContext(CreateContextRequest{ + client, session, err = c.createJupyterContext(CreateContextRequest{ Language: language, Cwd: "", }) @@ -164,20 +166,17 @@ func (c *Controller) createDefaultLanguageContext(language Language) error { return err } - c.mu.Lock() - defer c.mu.Unlock() - - c.defaultLanguageJupyterSessions[language] = session.ID - c.jupyterClientMap[session.ID] = &jupyterKernel{ + c.setDefaultLanguageSession(language, session.ID) + c.jupyterClientMap.Store(session.ID, &jupyterKernel{ kernelID: session.Kernel.ID, client: client, language: language, - } + }) return nil } -// createContext performs the actual context creation workflow. -func (c *Controller) createContext(request CreateContextRequest) (*jupyter.Client, *jupytersession.Session, error) { +// createJupyterContext performs the actual context creation workflow. +func (c *Controller) createJupyterContext(request CreateContextRequest) (*jupyter.Client, *jupytersession.Session, error) { client := c.jupyterClient() kernel, err := c.searchKernel(client, request.Language) @@ -217,10 +216,7 @@ func (c *Controller) createContext(request CreateContextRequest) (*jupyter.Clien // storeJupyterKernel caches a session -> kernel mapping. 
func (c *Controller) storeJupyterKernel(sessionID string, kernel *jupyterKernel) { - c.mu.Lock() - defer c.mu.Unlock() - - c.jupyterClientMap[sessionID] = kernel + c.jupyterClientMap.Store(sessionID, kernel) } func (c *Controller) jupyterClient() *jupyter.Client { @@ -236,49 +232,63 @@ func (c *Controller) jupyterClient() *jupyter.Client { jupyter.WithHTTPClient(httpClient)) } -func (c *Controller) listAllContexts() ([]CodeContext, error) { - c.mu.RLock() - defer c.mu.RUnlock() +func (c *Controller) getDefaultLanguageSession(language Language) string { + if v, ok := c.defaultLanguageSessions.Load(language); ok { + if session, ok := v.(string); ok { + return session + } + } + return "" +} + +func (c *Controller) setDefaultLanguageSession(language Language, sessionID string) { + c.defaultLanguageSessions.Store(language, sessionID) +} +func (c *Controller) deleteDefaultSessionByID(sessionID string) { + c.defaultLanguageSessions.Range(func(key, value any) bool { + if s, ok := value.(string); ok && s == sessionID { + c.defaultLanguageSessions.Delete(key) + } + return true + }) +} + +func (c *Controller) listAllContexts() ([]CodeContext, error) { contexts := make([]CodeContext, 0) - for session, kernel := range c.jupyterClientMap { - if kernel != nil { - contexts = append(contexts, CodeContext{ - ID: session, - Language: kernel.language, - }) + c.jupyterClientMap.Range(func(key, value any) bool { + session, _ := key.(string) + if kernel, ok := value.(*jupyterKernel); ok && kernel != nil { + contexts = append(contexts, CodeContext{ID: session, Language: kernel.language}) } - } + return true + }) - for language, defaultContext := range c.defaultLanguageJupyterSessions { - contexts = append(contexts, CodeContext{ - ID: defaultContext, - Language: language, - }) - } + c.defaultLanguageSessions.Range(func(key, value any) bool { + lang, _ := key.(Language) + session, _ := value.(string) + if session == "" { + return true + } + contexts = append(contexts, CodeContext{ID: 
session, Language: lang}) + return true + }) return contexts, nil } func (c *Controller) listLanguageContexts(language Language) ([]CodeContext, error) { - c.mu.RLock() - defer c.mu.RUnlock() - contexts := make([]CodeContext, 0) - for session, kernel := range c.jupyterClientMap { - if kernel != nil && kernel.language == language { - contexts = append(contexts, CodeContext{ - ID: session, - Language: language, - }) + c.jupyterClientMap.Range(func(key, value any) bool { + session, _ := key.(string) + if kernel, ok := value.(*jupyterKernel); ok && kernel != nil && kernel.language == language { + contexts = append(contexts, CodeContext{ID: session, Language: language}) } - } + return true + }) - if defaultContext := c.defaultLanguageJupyterSessions[language]; defaultContext != "" { - contexts = append(contexts, CodeContext{ - ID: defaultContext, - Language: language, - }) + if defaultContext := c.getDefaultLanguageSession(language); defaultContext != "" { + contexts = append(contexts, CodeContext{ID: defaultContext, Language: language}) } return contexts, nil diff --git a/components/execd/pkg/runtime/context_test.go b/components/execd/pkg/runtime/context_test.go index 6a27ad18..43efe81c 100644 --- a/components/execd/pkg/runtime/context_test.go +++ b/components/execd/pkg/runtime/context_test.go @@ -26,8 +26,9 @@ import ( func TestListContextsAndNewIpynbPath(t *testing.T) { c := NewController("http://example", "token") - c.jupyterClientMap["session-python"] = &jupyterKernel{language: Python} - c.defaultLanguageJupyterSessions[Go] = "session-go-default" + + c.jupyterClientMap.Store("session-python", &jupyterKernel{language: Python}) + c.setDefaultLanguageSession(Go, "session-go-default") pyContexts, err := c.listLanguageContexts(Python) if err != nil { @@ -128,8 +129,8 @@ func TestDeleteContext_RemovesCacheOnSuccess(t *testing.T) { defer server.Close() c := NewController(server.URL, "token") - c.jupyterClientMap[sessionID] = &jupyterKernel{language: Python} - 
c.defaultLanguageJupyterSessions[Python] = sessionID + c.jupyterClientMap.Store(sessionID, &jupyterKernel{language: Python}) + c.setDefaultLanguageSession(Python, sessionID) if err := c.DeleteContext(sessionID); err != nil { t.Fatalf("DeleteContext returned error: %v", err) @@ -138,7 +139,7 @@ func TestDeleteContext_RemovesCacheOnSuccess(t *testing.T) { if kernel := c.getJupyterKernel(sessionID); kernel != nil { t.Fatalf("expected cache to be cleared, found: %+v", kernel) } - if _, ok := c.defaultLanguageJupyterSessions[Python]; ok { + if c.getDefaultLanguageSession(Python) != "" { t.Fatalf("expected default session entry to be removed") } } @@ -166,21 +167,21 @@ func TestDeleteLanguageContext_RemovesCacheOnSuccess(t *testing.T) { defer server.Close() c := NewController(server.URL, "token") - c.jupyterClientMap[session1] = &jupyterKernel{language: lang} - c.jupyterClientMap[session2] = &jupyterKernel{language: lang} - c.defaultLanguageJupyterSessions[lang] = session2 + c.jupyterClientMap.Store(session1, &jupyterKernel{language: lang}) + c.jupyterClientMap.Store(session2, &jupyterKernel{language: lang}) + c.setDefaultLanguageSession(lang, session2) if err := c.DeleteLanguageContext(lang); err != nil { t.Fatalf("DeleteLanguageContext returned error: %v", err) } - if _, ok := c.jupyterClientMap[session1]; ok { + if v, ok := c.jupyterClientMap.Load(session1); ok && v != nil { t.Fatalf("expected session1 removed from cache") } - if _, ok := c.jupyterClientMap[session2]; ok { + if v, ok := c.jupyterClientMap.Load(session2); ok && v != nil { t.Fatalf("expected session2 removed from cache") } - if _, ok := c.defaultLanguageJupyterSessions[lang]; ok { + if c.getDefaultLanguageSession(lang) != "" { t.Fatalf("expected default entry removed") } if deleteCalls[session1] != 1 || deleteCalls[session2] != 1 { diff --git a/components/execd/pkg/runtime/ctrl.go b/components/execd/pkg/runtime/ctrl.go index 20bbecc6..2bb1967b 100644 --- a/components/execd/pkg/runtime/ctrl.go +++ 
b/components/execd/pkg/runtime/ctrl.go @@ -35,14 +35,15 @@ var kernelWaitingBackoff = wait.Backoff{ // Controller manages code execution across runtimes. type Controller struct { - baseURL string - token string - mu sync.RWMutex - jupyterClientMap map[string]*jupyterKernel - defaultLanguageJupyterSessions map[Language]string - commandClientMap map[string]*commandKernel - db *sql.DB - dbOnce sync.Once + baseURL string + token string + mu sync.RWMutex + jupyterClientMap sync.Map // sessionID -> *jupyterKernel + defaultLanguageSessions sync.Map // Language -> sessionID + commandClientMap sync.Map // sessionID -> *commandKernel + bashSessionClientMap sync.Map // sessionID -> *bashSession + db *sql.DB + dbOnce sync.Once } type jupyterKernel struct { @@ -71,9 +72,10 @@ func NewController(baseURL, token string) *Controller { baseURL: baseURL, token: token, - jupyterClientMap: make(map[string]*jupyterKernel), - defaultLanguageJupyterSessions: make(map[Language]string), - commandClientMap: make(map[string]*commandKernel), + jupyterClientMap: sync.Map{}, + defaultLanguageSessions: sync.Map{}, + commandClientMap: sync.Map{}, + bashSessionClientMap: sync.Map{}, } } @@ -93,10 +95,12 @@ func (c *Controller) Execute(request *ExecuteCodeRequest) error { return c.runCommand(ctx, request) case BackgroundCommand: return c.runBackgroundCommand(ctx, request) - case Bash, Python, Java, JavaScript, TypeScript, Go: + case Python, Java, JavaScript, TypeScript, Go: return c.runJupyter(ctx, request) case SQL: return c.runSQL(ctx, request) + case Bash: + return c.runBashSession(ctx, request) default: return fmt.Errorf("unknown language: %s", request.Language) } diff --git a/components/execd/pkg/runtime/interrupt.go b/components/execd/pkg/runtime/interrupt.go index 1a9515fa..67902a3d 100644 --- a/components/execd/pkg/runtime/interrupt.go +++ b/components/execd/pkg/runtime/interrupt.go @@ -38,6 +38,8 @@ func (c *Controller) Interrupt(sessionID string) error { case c.getCommandKernel(sessionID) 
!= nil: kernel := c.getCommandKernel(sessionID) return c.killPid(kernel.pid) + case c.getBashSession(sessionID) != nil: + return c.closeBashSession(sessionID) default: return errors.New("no such session") } diff --git a/components/execd/pkg/runtime/jupyter.go b/components/execd/pkg/runtime/jupyter.go index cdc0a6cc..9ea33b13 100644 --- a/components/execd/pkg/runtime/jupyter.go +++ b/components/execd/pkg/runtime/jupyter.go @@ -29,9 +29,8 @@ func (c *Controller) runJupyter(ctx context.Context, request *ExecuteCodeRequest return errors.New("language runtime server not configured, please check your image runtime") } if request.Context == "" { - if _, exists := c.defaultLanguageJupyterSessions[request.Language]; !exists { - err := c.createDefaultLanguageContext(request.Language) - if err != nil { + if c.getDefaultLanguageSession(request.Language) == "" { + if err := c.createDefaultLanguageJupyterContext(request.Language); err != nil { return err } } @@ -39,7 +38,7 @@ func (c *Controller) runJupyter(ctx context.Context, request *ExecuteCodeRequest var targetSessionID string if request.Context == "" { - targetSessionID = c.defaultLanguageJupyterSessions[request.Language] + targetSessionID = c.getDefaultLanguageSession(request.Language) } else { targetSessionID = request.Context } @@ -135,10 +134,12 @@ func (c *Controller) setWorkingDir(_ *jupyterKernel, _ *CreateContextRequest) er // getJupyterKernel retrieves a kernel connection from the session map. func (c *Controller) getJupyterKernel(sessionID string) *jupyterKernel { - c.mu.RLock() - defer c.mu.RUnlock() - - return c.jupyterClientMap[sessionID] + if v, ok := c.jupyterClientMap.Load(sessionID); ok { + if kernel, ok := v.(*jupyterKernel); ok { + return kernel + } + } + return nil } // searchKernel finds a kernel spec name for the given language. 
diff --git a/components/execd/pkg/runtime/types.go b/components/execd/pkg/runtime/types.go index cb82a11b..5dcf44c8 100644 --- a/components/execd/pkg/runtime/types.go +++ b/components/execd/pkg/runtime/types.go @@ -16,6 +16,7 @@ package runtime import ( "fmt" + "sync" "time" "github.com/alibaba/opensandbox/execd/pkg/jupyter/execute" @@ -80,3 +81,22 @@ type CodeContext struct { ID string `json:"id,omitempty"` Language Language `json:"language"` } + +// bashSessionConfig holds bash session configuration. +type bashSessionConfig struct { + // StartupSource is a list of scripts sourced on startup. + StartupSource []string + // Session is the session identifier. + Session string + // StartupTimeout is the startup timeout. + StartupTimeout time.Duration +} + +// bashSession represents a bash session. +type bashSession struct { + config *bashSessionConfig + mu sync.Mutex + started bool + env map[string]string + cwd string +} diff --git a/sandboxes/code-interpreter/Dockerfile b/sandboxes/code-interpreter/Dockerfile index 2a1aa013..280f1b30 100644 --- a/sandboxes/code-interpreter/Dockerfile +++ b/sandboxes/code-interpreter/Dockerfile @@ -24,7 +24,7 @@ RUN set -euo pipefail \ echo "Setting up ipykernel for Python $version" \ && . 
/opt/opensandbox/code-interpreter-env.sh python $version \ && python3 --version \ - && python3 -m pip install ipykernel jupyter bash_kernel --break-system-packages; \ + && python3 -m pip install ipykernel jupyter --break-system-packages; \ done \ && echo "Setting up ipykernel complete" diff --git a/sandboxes/code-interpreter/README.md b/sandboxes/code-interpreter/README.md index 2b7943a2..b17072e1 100644 --- a/sandboxes/code-interpreter/README.md +++ b/sandboxes/code-interpreter/README.md @@ -144,7 +144,6 @@ The image comes with pre-configured Jupyter kernels for all supported languages: - **Java**: IJava kernel - **TypeScript/JavaScript**: tslab kernel - **Go**: gonb kernel -- **Bash**: bash_kernel ### Starting Jupyter diff --git a/sandboxes/code-interpreter/README_zh.md b/sandboxes/code-interpreter/README_zh.md index e68b2584..9a9bfacb 100644 --- a/sandboxes/code-interpreter/README_zh.md +++ b/sandboxes/code-interpreter/README_zh.md @@ -142,7 +142,6 @@ source /opt/opensandbox/code-interpreter-env.sh go - **Java**:IJava 内核 - **TypeScript/JavaScript**:tslab 内核 - **Go**:gonb 内核 -- **Bash**:bash_kernel ### 启动 Jupyter diff --git a/sandboxes/code-interpreter/scripts/code-interpreter.sh b/sandboxes/code-interpreter/scripts/code-interpreter.sh index 11968c1d..6d03691c 100755 --- a/sandboxes/code-interpreter/scripts/code-interpreter.sh +++ b/sandboxes/code-interpreter/scripts/code-interpreter.sh @@ -93,12 +93,6 @@ setup_go() { } } -setup_bash() { - time { - python3 -m bash_kernel.install - } -} - # export go bin path export PATH="$(go env GOPATH)/bin:$PATH" if [ -n "${EXECD_ENVS:-}" ]; then @@ -114,7 +108,5 @@ setup_node & pids+=($!) setup_go & pids+=($!) -setup_bash & -pids+=($!) 
jupyter notebook --ip=127.0.0.1 --port="${JUPYTER_PORT:-44771}" --allow-root --no-browser --NotebookApp.token="${JUPYTER_TOKEN:-opensandboxcodeinterpreterjupyter}" >/opt/opensandbox/jupyter.log diff --git a/tests/javascript/tests/test_code_interpreter_e2e.test.ts b/tests/javascript/tests/test_code_interpreter_e2e.test.ts index 686fc407..6ae00a41 100644 --- a/tests/javascript/tests/test_code_interpreter_e2e.test.ts +++ b/tests/javascript/tests/test_code_interpreter_e2e.test.ts @@ -267,3 +267,58 @@ test("07 interrupt code execution + fake id", async () => { await expect(ci0.codes.interrupt(`fake-${Date.now()}`)).rejects.toBeTruthy(); }); + +test("08 bash env propagation across sequential executions", async () => { + if (!ci) throw new Error("not initialized"); + + const stdout: string[] = []; + const stderr: string[] = []; + const errors: string[] = []; + + const handlers: ExecutionHandlers = { + onStdout: (m) => { + if (m.text) stdout.push(m.text.trim()); + }, + onStderr: (m) => { + if (m.text) stderr.push(m.text.trim()); + }, + onError: (e) => { + errors.push(e.name); + }, + }; + + const code1 = "export FOO=hello\nexport BAR=world\n"; + const code2 = 'printf "step1:$FOO:$BAR\\n"\n'; + const code3 = + "export FOO=${FOO}_next\n" + + 'printf "step2:$FOO:$BAR\\n"\n' + + "export BAR=${BAR}_next\n" + + 'printf "step3:$FOO:$BAR\\n"\n'; + + const r1 = await ci.codes.run(code1, { + language: SupportedLanguages.BASH, + handlers, + }); + expect(r1.id).toBeTruthy(); + expect(r1.error).toBeUndefined(); + + const r2 = await ci.codes.run(code2, { + language: SupportedLanguages.BASH, + handlers, + }); + expect(r2.id).toBeTruthy(); + expect(r2.error).toBeUndefined(); + + const r3 = await ci.codes.run(code3, { + language: SupportedLanguages.BASH, + handlers, + }); + expect(r3.id).toBeTruthy(); + expect(r3.error).toBeUndefined(); + + expect(stdout).toContain("step1:hello:world"); + expect(stdout).toContain("step2:hello_next:world"); + 
expect(stdout).toContain("step3:hello_next:world_next"); + expect(errors).toHaveLength(0); + expect(stderr.filter((s) => s.length > 0)).toHaveLength(0); +}); diff --git a/tests/python/tests/test_code_interpreter_e2e.py b/tests/python/tests/test_code_interpreter_e2e.py index 9a2548de..88075f79 100644 --- a/tests/python/tests/test_code_interpreter_e2e.py +++ b/tests/python/tests/test_code_interpreter_e2e.py @@ -998,12 +998,12 @@ async def test_09_context_management_endpoints(self): code_interpreter = TestCodeInterpreterE2E.code_interpreter assert code_interpreter is not None - language = SupportedLanguage.BASH + language = SupportedLanguage.PYTHON logger.info("=" * 80) logger.info("TEST 9: Context management endpoints (%s)", language) logger.info("=" * 80) - # Ensure clean slate for bash contexts to avoid interference with other tests. + # Ensure clean slate for python contexts to avoid interference with other tests. await code_interpreter.codes.delete_contexts(language) ctx1 = await code_interpreter.codes.create_context(language) @@ -1012,14 +1012,14 @@ async def test_09_context_management_endpoints(self): assert ctx2.id is not None and ctx2.id.strip() assert ctx1.language == language assert ctx2.language == language - logger.info("✓ Created two bash contexts: %s, %s", ctx1.id, ctx2.id) + logger.info("✓ Created two python contexts: %s, %s", ctx1.id, ctx2.id) listed = await code_interpreter.codes.list_contexts(language) - bash_context_ids = {c.id for c in listed if c.id} - assert ctx1.id in bash_context_ids - assert ctx2.id in bash_context_ids + python_context_ids = {c.id for c in listed if c.id} + assert ctx1.id in python_context_ids + assert ctx2.id in python_context_ids assert all(c.language == language for c in listed) - logger.info("✓ list_contexts returned expected bash contexts") + logger.info("✓ list_contexts returned expected python contexts") fetched = await code_interpreter.codes.get_context(ctx1.id) assert fetched.id == ctx1.id @@ -1038,5 +1038,5 @@ async 
def test_09_context_management_endpoints(self): c for c in await code_interpreter.codes.list_contexts(language) if c.id ] assert len(final_contexts) == 0 - logger.info("✓ delete_contexts removed all bash contexts") + logger.info("✓ delete_contexts removed all python contexts") diff --git a/tests/python/tests/test_code_interpreter_e2e_sync.py b/tests/python/tests/test_code_interpreter_e2e_sync.py index 95d8f9b7..9bae9f2e 100644 --- a/tests/python/tests/test_code_interpreter_e2e_sync.py +++ b/tests/python/tests/test_code_interpreter_e2e_sync.py @@ -893,3 +893,95 @@ def test_09_context_management_endpoints(self): assert len(final_contexts) == 0 logger.info("✓ delete_contexts removed all bash contexts") + @pytest.mark.timeout(300) + @pytest.mark.order(10) + def test_10_bash_env_propagation(self): + """Ensure bash commands share env/vars across sequential executions.""" + TestCodeInterpreterE2ESync._ensure_code_interpreter_created() + code_interpreter = TestCodeInterpreterE2ESync.code_interpreter + assert code_interpreter is not None + + stdout_messages: list[OutputMessage] = [] + stderr_messages: list[OutputMessage] = [] + errors: list[ExecutionError] = [] + completed_events: list[ExecutionComplete] = [] + init_events: list[ExecutionInit] = [] + + def on_stdout(msg: OutputMessage): + stdout_messages.append(msg) + + def on_stderr(msg: OutputMessage): + stderr_messages.append(msg) + + def on_error(err: ExecutionError): + errors.append(err) + + def on_complete(evt: ExecutionComplete): + completed_events.append(evt) + + def on_init(evt: ExecutionInit): + init_events.append(evt) + + handlers = ExecutionHandlersSync( + on_stdout=on_stdout, + on_stderr=on_stderr, + on_result=None, + on_error=on_error, + on_execution_complete=on_complete, + on_init=on_init, + ) + + # Send three sequential commands in the same session, validating env propagation. 
+ code1 = ( + "export FOO=hello\n" + "export BAR=world\n" + ) + code2 = ( + "printf \"step1:$FOO:$BAR\\n\"\n" + ) + code3 = ( + "export FOO=${FOO}_next\n" + "printf \"step2:$FOO:$BAR\\n\"\n" + "export BAR=${BAR}_next\n" + "printf \"step3:$FOO:$BAR\\n\"\n" + ) + + # export envs + result1 = code_interpreter.codes.run( + code1, + language=SupportedLanguage.BASH, + handlers=handlers, + ) + + assert result1 is not None + assert result1.id is not None and str(result1.id).strip() + assert result1.error is None + + # print env + result2 = code_interpreter.codes.run( + code2, + language=SupportedLanguage.BASH, + handlers=handlers, + ) + + assert result2 is not None + assert result2.id is not None and str(result2.id).strip() + assert result2.error is None + + # print env + result3 = code_interpreter.codes.run( + code3, + language=SupportedLanguage.BASH, + handlers=handlers, + ) + assert result3 is not None + assert result3.id is not None and str(result3.id).strip() + assert result3.error is None + + # Expect at least three stdout lines with propagated env values. + stdout_texts = [m.text.strip() for m in stdout_messages if m.text] + assert "step1:hello:world" in stdout_texts + assert "step2:hello_next:world" in stdout_texts + assert "step3:hello_next:world_next" in stdout_texts + for m in stdout_messages[:3]: + _assert_recent_timestamp_ms(m.timestamp)