diff --git a/internal/server/handlers_instances.go b/internal/server/handlers_instances.go index 5fe0d5c..de750c0 100644 --- a/internal/server/handlers_instances.go +++ b/internal/server/handlers_instances.go @@ -2,6 +2,7 @@ package server import ( "encoding/json" + "fmt" "net/http" "os" "os/exec" @@ -9,6 +10,7 @@ import ( "strings" "github.com/nbitslabs/flock/internal/db/sqlc" + "github.com/nbitslabs/flock/internal/ssh" ) func expandHome(path string) string { @@ -68,6 +70,10 @@ func cloneOrGetRepo(basePath, org, repo string) (string, error) { return "", err } + if err := ssh.EnsureGitHubHostKey(); err != nil { + return "", fmt.Errorf("failed to ensure GitHub SSH host key: %w", err) + } + orgPath := filepath.Join(basePath, "github.com", org) if err := os.MkdirAll(orgPath, 0755); err != nil { return "", err diff --git a/internal/ssh/known_hosts.go b/internal/ssh/known_hosts.go new file mode 100644 index 0000000..c77e30b --- /dev/null +++ b/internal/ssh/known_hosts.go @@ -0,0 +1,109 @@ +package ssh + +import ( + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "time" +) + +var ( + ensureHostKeyOnce sync.Once + ensureHostKeyError error +) + +const ( + knownHostsFileName = "known_hosts" + sshDirName = "ssh" + githubHost = "github.com" + keyScanTimeout = 10 * time.Second +) + +func EnsureGitHubHostKey() error { + ensureHostKeyOnce.Do(func() { + ensureHostKeyError = ensureGitHubHostKey() + }) + return ensureHostKeyError +} + +func ensureGitHubHostKey() error { + knownHostsPath, err := getKnownHostsPath() + if err != nil { + return fmt.Errorf("failed to get known_hosts path: %w", err) + } + + if _, err := os.Stat(knownHostsPath); err == nil { + if containsHost(knownHostsPath, githubHost) { + return nil + } + } + + if err := addGitHubHostKey(knownHostsPath); err != nil { + return fmt.Errorf("failed to add GitHub host key: %w", err) + } + + return nil +} + +func getKnownHostsPath() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + + sshDir := filepath.Join(home, ".flock", sshDirName) + if err := os.MkdirAll(sshDir, 0700); err != nil { + return "", fmt.Errorf("failed to create SSH directory: %w", err) + } + + return filepath.Join(sshDir, knownHostsFileName), nil +} + +func containsHost(knownHostsPath, host string) bool { + data, err := os.ReadFile(knownHostsPath) + if err != nil { + return false + } + + for _, line := range strings.Split(string(data), "\n") { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + if strings.HasPrefix(line, host+" ") || strings.HasPrefix(line, host+",") { + return true + } + } + return false +} + +func addGitHubHostKey(knownHostsPath string) error { + if _, err := exec.LookPath("ssh-keyscan"); err != nil { + return fmt.Errorf("ssh-keyscan not found: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), keyScanTimeout) + defer cancel() + + cmd := exec.CommandContext(ctx, "ssh-keyscan", "-t", "rsa,dsa,ecdsa,ed25519", githubHost) + output, err := cmd.Output() + if err != nil { + return fmt.Errorf("ssh-keyscan failed: %w", err) + } + + f, err := os.OpenFile(knownHostsPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return fmt.Errorf("failed to open known_hosts: %w", err) + } + defer f.Close() + + if _, err := f.Write(output); err != nil { + return fmt.Errorf("failed to write known_hosts: %w", err) + } + + return nil +}