Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions internal/server/handlers_instances.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package server

import (
"encoding/json"
"fmt"
"net/http"
"os"
"os/exec"
"path/filepath"
"strings"

"github.com/nbitslabs/flock/internal/db/sqlc"
"github.com/nbitslabs/flock/internal/ssh"
)

func expandHome(path string) string {
Expand Down Expand Up @@ -68,6 +70,10 @@ func cloneOrGetRepo(basePath, org, repo string) (string, error) {
return "", err
}

if err := ssh.EnsureGitHubHostKey(); err != nil {
return "", fmt.Errorf("failed to ensure GitHub SSH host key: %w", err)
}

orgPath := filepath.Join(basePath, "github.com", org)
if err := os.MkdirAll(orgPath, 0755); err != nil {
return "", err
Expand Down
109 changes: 109 additions & 0 deletions internal/ssh/known_hosts.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package ssh

import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"time"
)

var (
ensureHostKeyOnce sync.Once
ensureHostKeyError error
)

const (
knownHostsFileName = "known_hosts"
sshDirName = "ssh"
githubHost = "github.com"
keyScanTimeout = 10 * time.Second
)

func EnsureGitHubHostKey() error {
ensureHostKeyOnce.Do(func() {
ensureHostKeyError = ensureGitHubHostKey()
})
return ensureHostKeyError
}

func ensureGitHubHostKey() error {
knownHostsPath, err := getKnownHostsPath()
if err != nil {
return fmt.Errorf("failed to get known_hosts path: %w", err)
}

if _, err := os.Stat(knownHostsPath); err == nil {
if containsHost(knownHostsPath, githubHost) {
return nil
}
}

if err := addGitHubHostKey(knownHostsPath); err != nil {
return fmt.Errorf("failed to add GitHub host key: %w", err)
}

return nil
}

func getKnownHostsPath() (string, error) {
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("failed to get home directory: %w", err)
}

sshDir := filepath.Join(home, ".flock", sshDirName)
if err := os.MkdirAll(sshDir, 0700); err != nil {
return "", fmt.Errorf("failed to create SSH directory: %w", err)
}

return filepath.Join(sshDir, knownHostsFileName), nil
}

func containsHost(knownHostsPath, host string) bool {
data, err := os.ReadFile(knownHostsPath)
if err != nil {
return false
}

for _, line := range strings.Split(string(data), "\n") {
line = strings.TrimSpace(line)
if line == "" || strings.HasPrefix(line, "#") {
continue
}
if strings.HasPrefix(line, host+" ") || strings.HasPrefix(line, host+",") {
return true
}
}
return false
}

func addGitHubHostKey(knownHostsPath string) error {
if _, err := exec.LookPath("ssh-keyscan"); err != nil {
return fmt.Errorf("ssh-keyscan not found: %w", err)
}

ctx, cancel := context.WithTimeout(context.Background(), keyScanTimeout)
defer cancel()

cmd := exec.CommandContext(ctx, "ssh-keyscan", "-t", "rsa,dsa,ecdsa,ed25519", githubHost)
output, err := cmd.Output()
if err != nil {
return fmt.Errorf("ssh-keyscan failed: %w", err)
}

f, err := os.OpenFile(knownHostsPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
return fmt.Errorf("failed to open known_hosts: %w", err)
}
defer f.Close()

if _, err := f.Write(output); err != nil {
return fmt.Errorf("failed to write known_hosts: %w", err)
}

return nil
}
Loading