Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions pkg/cmd/refresh/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type RefreshStore interface {
ssh.SSHConfigurerV2Store
GetCurrentUser() (*entity.User, error)
GetCurrentUserKeys() (*entity.UserKeys, error)
WriteAuthorizedKey(publicKey string) error
Chmod(string, fs.FileMode) error
MkdirAll(string, fs.FileMode) error
GetBrevCloudflaredBinaryPath() (string, error)
Expand Down Expand Up @@ -62,7 +63,7 @@ func RunRefreshBetter(store RefreshStore) error {
return breverrors.WrapAndTrace(err)
}

cu, err := GetConfigUpdater(store)
cu, keys, err := GetConfigUpdater(store)
if err != nil {
return breverrors.WrapAndTrace(err)
}
Expand All @@ -72,6 +73,11 @@ func RunRefreshBetter(store RefreshStore) error {
return breverrors.WrapAndTrace(err)
}

err = store.WriteAuthorizedKey(keys.PublicKey)
if err != nil {
return breverrors.WrapAndTrace(err)
}

privateKeyPath, err := store.GetPrivateKeyPath()
if err != nil {
return breverrors.WrapAndTrace(err)
Expand All @@ -91,7 +97,7 @@ func RunRefresh(store RefreshStore) error {
return breverrors.WrapAndTrace(err)
}

cu, err := GetConfigUpdater(store)
cu, keys, err := GetConfigUpdater(store)
if err != nil {
return breverrors.WrapAndTrace(err)
}
Expand All @@ -101,6 +107,11 @@ func RunRefresh(store RefreshStore) error {
return breverrors.WrapAndTrace(err)
}

err = store.WriteAuthorizedKey(keys.PublicKey)
if err != nil {
return breverrors.WrapAndTrace(err)
}

privateKeyPath, err := store.GetPrivateKeyPath()
if err != nil {
return breverrors.WrapAndTrace(err)
Expand Down Expand Up @@ -139,20 +150,20 @@ func RunRefreshAsync(rstore RefreshStore) *RefreshRes {
return &res
}

func GetConfigUpdater(store RefreshStore) (*ssh.ConfigUpdater, error) {
func GetConfigUpdater(store RefreshStore) (*ssh.ConfigUpdater, *entity.UserKeys, error) {
configs, err := ssh.GetSSHConfigs(store)
if err != nil {
return nil, breverrors.WrapAndTrace(err)
return nil, nil, breverrors.WrapAndTrace(err)
}

keys, err := store.GetCurrentUserKeys()
if err != nil {
return nil, breverrors.WrapAndTrace(err)
return nil, nil, breverrors.WrapAndTrace(err)
}

cu := ssh.NewConfigUpdater(store, configs, keys.PrivateKey)

return cu, nil
return cu, keys, nil
}

func GetCloudflare(refreshStore RefreshStore) store.Cloudflared {
Expand Down
38 changes: 38 additions & 0 deletions pkg/files/files.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"os"
"os/exec"
"path/filepath"
"strings"

breverrors "github.com/brevdev/brev-cli/pkg/errors"
"golang.org/x/text/encoding/charmap"
Expand Down Expand Up @@ -75,6 +76,10 @@ func GetSSHPrivateKeyPath(home string) string {
return fpath
}

func GetAuthorizedKeysPath(home string) string {
return filepath.Join(home, ".ssh", "authorized_keys")
}

func GetUserSSHConfigPath(home string) (string, error) {
sshConfigPath := filepath.Join(home, ".ssh", "config")
return sshConfigPath, nil
Expand Down Expand Up @@ -210,6 +215,39 @@ func OverwriteJSON(fs afero.Fs, filepath string, v interface{}) error {

// write

// WriteAuthorizedKey ensures the given public key is present in ~/.ssh/authorized_keys.
// It appends the key only if it's not already there.
func WriteAuthorizedKey(fs afero.Fs, publicKey string, home string) error {
authorizedKeysPath := GetAuthorizedKeysPath(home)
err := fs.MkdirAll(filepath.Dir(authorizedKeysPath), 0o700)
if err != nil {
return breverrors.WrapAndTrace(err)
}

publicKey = strings.TrimSpace(publicKey)

existing, err := afero.ReadFile(fs, authorizedKeysPath)
if err != nil && !os.IsNotExist(err) {
return breverrors.WrapAndTrace(err)
}

if strings.Contains(string(existing), publicKey) {
return nil
}

content := string(existing)
if len(content) > 0 && !strings.HasSuffix(content, "\n") {
content += "\n"
}
content += publicKey + "\n"

err = afero.WriteFile(fs, authorizedKeysPath, []byte(content), 0o600)
if err != nil {
return breverrors.WrapAndTrace(err)
}
return nil
}

func WriteSSHPrivateKey(fs afero.Fs, data string, home string) error {
pkPath := GetSSHPrivateKeyPath(home)
err := fs.MkdirAll(filepath.Dir(pkPath), defaultFilePermission)
Expand Down
12 changes: 12 additions & 0 deletions pkg/store/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,18 @@ func (f FileStore) WritePrivateKey(pem string) error {
return nil
}

func (f FileStore) WriteAuthorizedKey(publicKey string) error {
home, err := f.UserHomeDir()
if err != nil {
return breverrors.WrapAndTrace(err)
}
err = files.WriteAuthorizedKey(f.fs, publicKey, home)
if err != nil {
return breverrors.WrapAndTrace(err)
}
return nil
}

func (f FileStore) GetPrivateKeyPath() (string, error) {
home, err := f.UserHomeDir()
if err != nil {
Expand Down
Loading