diff --git a/.gitignore b/.gitignore index 8f1541e0..1276cd3d 100644 --- a/.gitignore +++ b/.gitignore @@ -46,6 +46,7 @@ dist/ # binary brev-cli brev +brev-local # golang executable go1.* diff --git a/Makefile b/Makefile index f459f671..b072f053 100644 --- a/Makefile +++ b/Makefile @@ -8,24 +8,25 @@ fast-build: ## go build -o brev CGO_ENABLED=0 go build -o brev -ldflags "-X github.com/brevdev/brev-cli/pkg/cmd/version.Version=${VERSION}" .PHONY: local -local: ## build with env wrapper (use: make local env=dev0|dev1|dev2|stg, or make local for defaults) +local: ## build with env wrapper (use: make local env=dev0|dev1|dev2|stg arch=linux/amd64, or make local for defaults) $(call print-target) ifdef env @echo "Building with env=$(env) wrapper..." @echo ${VERSION} - CGO_ENABLED=0 go build -o brev -ldflags "-X github.com/brevdev/brev-cli/pkg/cmd/version.Version=${VERSION}" + $(if $(arch),GOOS=$(word 1,$(subst /, ,$(arch))) GOARCH=$(word 2,$(subst /, ,$(arch))),) CGO_ENABLED=0 go build -o brev-local -ldflags "-X github.com/brevdev/brev-cli/pkg/cmd/version.Version=${VERSION}" @echo '#!/bin/sh' > brev @echo '# Auto-generated wrapper with environment overrides' >> brev @echo 'export BREV_CONSOLE_URL="https://localhost.nvidia.com:3000"' >> brev @echo 'export BREV_AUTH_URL="https://api.stg.ngc.nvidia.com"' >> brev @echo 'export BREV_AUTH_ISSUER_URL="https://stg.login.nvidia.com"' >> brev @echo 'export BREV_API_URL="https://bd.$(env).brev.nvidia.com"' >> brev + @echo 'export BREV_PUBLIC_API_URL="https://api.$(env).brev.nvidia.com"' >> brev @echo 'export BREV_GRPC_URL="api.$(env).brev.nvidia.com:443"' >> brev - @echo 'exec "$$(cd "$$(dirname "$$0")" && pwd)/brev" "$$@"' >> brev + @echo 'exec "$$(cd "$$(dirname "$$0")" && pwd)/brev-local" "$$@"' >> brev @chmod +x brev else @echo "Building without environment overrides (using config.go defaults)..." - $(MAKE) fast-build + $(if $(arch),GOOS=$(word 1,$(subst /, ,$(arch))) GOARCH=$(word 2,$(subst /, ,$(arch))),) CGO_ENABLED=0 go build -o brev -ldflags "-X github.com/brevdev/brev-cli/pkg/cmd/version.Version=${VERSION}" endif .PHONY: install-dev diff --git a/go.mod b/go.mod index 8d197bd7..95b79584 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,9 @@ module github.com/brevdev/brev-cli go 1.24.0 require ( + buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260227212944-1f05724e97ab.2 + buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260227212944-1f05724e97ab.1 + connectrpc.com/connect v1.19.1 github.com/alessio/shellescape v1.4.1 github.com/brevdev/parse v0.0.11 github.com/briandowns/spinner v1.16.0 @@ -12,7 +15,7 @@ require ( github.com/go-git/go-git/v5 v5.13.2 github.com/go-resty/resty/v2 v2.17.0 github.com/golang-jwt/jwt/v5 v5.3.0 - github.com/google/go-cmp v0.6.0 + github.com/google/go-cmp v0.7.0 github.com/google/huproxy v0.0.0-20210816191033-a131ee126ce3 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.0 @@ -45,6 +48,7 @@ require ( ) require ( + buf.build/gen/go/brevdev/protoc-gen-gotag/protocolbuffers/go v1.36.11-20220906235457-8b4922735da5.1 // indirect dario.cat/mergo v1.0.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect @@ -148,7 +152,7 @@ require ( golang.org/x/sys v0.40.0 // indirect golang.org/x/term v0.39.0 // indirect golang.org/x/time v0.12.0 // indirect - google.golang.org/protobuf v1.34.2 + google.golang.org/protobuf v1.36.11 gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index d450015e..624f0dc4 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,9 @@ +buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260227212944-1f05724e97ab.2 h1:R8G6M5utUEZaurs/9o/fhGMCp/ZHsaYjjArMFBsrrr8= +buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260227212944-1f05724e97ab.2/go.mod h1:Jooemm4ArTV81Co3zkto/PgOQtelQf5j3fgAGsv73T4= +buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260227212944-1f05724e97ab.1 h1:lOMJE1EzhYelGdR5FQEHUfUl60/3E69Y8dM4O1C8nMc= +buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260227212944-1f05724e97ab.1/go.mod h1:V/y7Wxg0QvU4XPVwqErF5NHLobUT1QEyfgrGuQIxdPo= +buf.build/gen/go/brevdev/protoc-gen-gotag/protocolbuffers/go v1.36.11-20220906235457-8b4922735da5.1 h1:6amhprQmCKJ4wgJ6ngkh32d9V+dQcOLUZ/SfHdOnYgo= +buf.build/gen/go/brevdev/protoc-gen-gotag/protocolbuffers/go v1.36.11-20220906235457-8b4922735da5.1/go.mod h1:O+pnSHMru/naTMrm4tmpBoH3wz6PHa+R75HR7Mv8X2g= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= @@ -35,6 +41,8 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo= +connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14= +connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w= dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= @@ -208,8 +216,8 @@ github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -776,8 +784,8 @@ google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= -google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index 8521ff26..cd733deb 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -13,6 +13,8 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/connect" "github.com/brevdev/brev-cli/pkg/cmd/copy" "github.com/brevdev/brev-cli/pkg/cmd/delete" + "github.com/brevdev/brev-cli/pkg/cmd/deregister" + "github.com/brevdev/brev-cli/pkg/cmd/enablessh" "github.com/brevdev/brev-cli/pkg/cmd/envvars" "github.com/brevdev/brev-cli/pkg/cmd/exec" "github.com/brevdev/brev-cli/pkg/cmd/fu" @@ -305,7 +307,9 @@ func createCmdTree(cmd *cobra.Command, t *terminal.Terminal, loginCmdStore *stor cmd.AddCommand(reset.NewCmdReset(t, loginCmdStore, noLoginCmdStore)) cmd.AddCommand(profile.NewCmdProfile(t, loginCmdStore, noLoginCmdStore)) cmd.AddCommand(refresh.NewCmdRefresh(t, loginCmdStore)) - cmd.AddCommand(register.NewCmdRegister(t)) + cmd.AddCommand(register.NewCmdRegister(t, loginCmdStore)) + cmd.AddCommand(deregister.NewCmdDeregister(t, loginCmdStore)) + cmd.AddCommand(enablessh.NewCmdEnableSSH(t, loginCmdStore)) cmd.AddCommand(runtasks.NewCmdRunTasks(t, noLoginCmdStore)) cmd.AddCommand(proxy.NewCmdProxy(t, noLoginCmdStore)) cmd.AddCommand(healthcheck.NewCmdHealthcheck(t, noLoginCmdStore)) diff --git a/pkg/cmd/deregister/deregister.go b/pkg/cmd/deregister/deregister.go new file mode 100644 index 00000000..ddab96d5 --- /dev/null +++ b/pkg/cmd/deregister/deregister.go @@ -0,0 +1,152 @@ +// Package deregister provides the brev deregister command for device deregistration +package deregister + +import ( + "context" + "fmt" + "runtime" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + + "github.com/brevdev/brev-cli/pkg/cmd/register" + "github.com/brevdev/brev-cli/pkg/config" + "github.com/brevdev/brev-cli/pkg/entity" + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/terminal" + + "github.com/spf13/cobra" +) + +// DeregisterStore defines the store methods needed by the deregister command. +type DeregisterStore interface { + GetCurrentUser() (*entity.User, error) + GetBrevHomePath() (string, error) + GetAccessToken() (string, error) +} + +// deregisterDeps bundles the side-effecting dependencies of runDeregister so +// they can be replaced in tests. +type deregisterDeps struct { + goos string + promptSelect func(label string, items []string) string + uninstallNetbird func() error + newNodeClient func(provider register.TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient + registrationStore register.RegistrationStore +} + +func prodDeregisterDeps(brevHome string) deregisterDeps { + return deregisterDeps{ + goos: runtime.GOOS, + promptSelect: func(label string, items []string) string { + return terminal.PromptSelectInput(terminal.PromptSelectContent{ + Label: label, + Items: items, + }) + }, + uninstallNetbird: register.UninstallNetbird, + newNodeClient: register.NewNodeServiceClient, + registrationStore: register.NewFileRegistrationStore(brevHome), + } +} + +var ( + deregisterLong = `Deregister your device from NVIDIA Brev + +This command removes the local registration data and optionally uninstalls +NetBird (network agent).` + + deregisterExample = ` brev deregister` +) + +func NewCmdDeregister(t *terminal.Terminal, store DeregisterStore) *cobra.Command { + cmd := &cobra.Command{ + Annotations: map[string]string{"configuration": ""}, + Use: "deregister", + DisableFlagsInUseLine: true, + Short: "Deregister your device from Brev", + Long: deregisterLong, + Example: deregisterExample, + RunE: func(cmd *cobra.Command, args []string) error { + brevHome, err := store.GetBrevHomePath() + if err != nil { + return breverrors.WrapAndTrace(err) + } + return runDeregister(cmd.Context(), t, store, prodDeregisterDeps(brevHome)) + }, + } + + return cmd +} + +func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore, deps deregisterDeps) error { //nolint:funlen // deregistration flow + if deps.goos != "linux" { + return fmt.Errorf("brev deregister is only supported on Linux") + } + + registered, err := deps.registrationStore.Exists() + if err != nil { + return breverrors.WrapAndTrace(err) + } + if !registered { + return fmt.Errorf("no registration found; this machine does not appear to be registered\nRun 'brev register' to register your device") + } + + reg, err := deps.registrationStore.Load() + if err != nil { + return fmt.Errorf("failed to read registration file: %w", err) + } + + t.Vprint("") + t.Vprint(t.Green("Deregistering device")) + t.Vprint("") + t.Vprintf(" Node ID: %s\n", reg.ExternalNodeID) + t.Vprintf(" Name: %s\n", reg.DisplayName) + t.Vprint("") + + confirm := deps.promptSelect( + "Proceed with deregistration?", + []string{"Yes, proceed", "No, cancel"}, + ) + if confirm != "Yes, proceed" { + t.Vprint("Deregistration canceled.") + return nil + } + + t.Vprint("") + t.Vprint(t.Yellow("Removing node from Brev...")) + client := deps.newNodeClient(s, config.GlobalConfig.GetBrevPublicAPIURL()) + if _, err := client.RemoveNode(ctx, connect.NewRequest(&nodev1.RemoveNodeRequest{ + ExternalNodeId: reg.ExternalNodeID, + })); err != nil { + return fmt.Errorf("failed to deregister node: %w", err) + } + t.Vprint(t.Green(" Node removed from Brev.")) + t.Vprint("") + + removeNetbird := deps.promptSelect( + "Would you also like to uninstall NetBird?", + []string{"Yes, uninstall NetBird", "No, keep NetBird installed"}, + ) + if removeNetbird == "Yes, uninstall NetBird" { + t.Vprint("Removing NetBird...") + if err := deps.uninstallNetbird(); err != nil { + t.Vprintf(" Warning: failed to uninstall NetBird: %v\n", err) + } else { + t.Vprint(t.Green(" NetBird uninstalled.")) + } + t.Vprint("") + } + + t.Vprint("Removing registration data...") + if err := deps.registrationStore.Delete(); err != nil { + t.Vprintf(" Warning: failed to remove local registration file: %v\n", err) + t.Vprint(" You can manually remove it with: rm ~/.brev/device_registration.json") + } + + t.Vprint(t.Green("Deregistration complete.")) + t.Vprint("") + + return nil +} diff --git a/pkg/cmd/deregister/deregister_test.go b/pkg/cmd/deregister/deregister_test.go new file mode 100644 index 00000000..02d3c41f --- /dev/null +++ b/pkg/cmd/deregister/deregister_test.go @@ -0,0 +1,297 @@ +package deregister + +import ( + "context" + "fmt" + "net/http/httptest" + "testing" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + + "github.com/brevdev/brev-cli/pkg/cmd/register" + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/terminal" +) + +type mockDeregisterStore struct { + user *entity.User + home string + token string + err error +} + +func (m *mockDeregisterStore) GetCurrentUser() (*entity.User, error) { + if m.err != nil { + return nil, m.err + } + return m.user, nil +} + +func (m *mockDeregisterStore) GetBrevHomePath() (string, error) { return m.home, nil } +func (m *mockDeregisterStore) GetAccessToken() (string, error) { return m.token, nil } + +// fakeNodeService implements the server side of ExternalNodeService for testing. +type fakeNodeService struct { + nodev1connect.UnimplementedExternalNodeServiceHandler + removeNodeFn func(*nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) +} + +func (f *fakeNodeService) RemoveNode(_ context.Context, req *connect.Request[nodev1.RemoveNodeRequest]) (*connect.Response[nodev1.RemoveNodeResponse], error) { + resp, err := f.removeNodeFn(req.Msg) + if err != nil { + return nil, err + } + return connect.NewResponse(resp), nil +} + +// mockRegistrationStore satisfies register.RegistrationStore for deregister tests. +type mockRegistrationStore struct { + reg *register.DeviceRegistration +} + +func (m *mockRegistrationStore) Save(reg *register.DeviceRegistration) error { + m.reg = reg + return nil +} + +func (m *mockRegistrationStore) Load() (*register.DeviceRegistration, error) { + if m.reg == nil { + return nil, fmt.Errorf("no registration") + } + return m.reg, nil +} + +func (m *mockRegistrationStore) Delete() error { + m.reg = nil + return nil +} + +func (m *mockRegistrationStore) Exists() (bool, error) { + return m.reg != nil, nil +} + +// testDeregisterDeps returns deps with all side-effects stubbed. The +// promptSelect defaults to confirming all prompts. +func testDeregisterDeps(t *testing.T, svc *fakeNodeService, regStore register.RegistrationStore) (deregisterDeps, *httptest.Server) { + t.Helper() + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) + + return deregisterDeps{ + goos: "linux", + promptSelect: func(_ string, items []string) string { + // Default: pick first item (Yes, ...) + if len(items) > 0 { + return items[0] + } + return "" + }, + uninstallNetbird: func() error { return nil }, + newNodeClient: func(provider register.TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient { + return register.NewNodeServiceClient(provider, server.URL) + }, + registrationStore: regStore, + }, server +} + +func Test_runDeregister_HappyPath(t *testing.T) { + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + DeviceID: "dev-uuid", + }, + } + + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: "/home/testuser/.brev", + token: "tok", + } + + var gotNodeID string + svc := &fakeNodeService{ + removeNodeFn: func(req *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + gotNodeID = req.GetExternalNodeId() + return &nodev1.RemoveNodeResponse{}, nil + }, + } + + deps, server := testDeregisterDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err != nil { + t.Fatalf("runDeregister failed: %v", err) + } + + if gotNodeID != "unode_abc" { + t.Errorf("expected node ID unode_abc, got %s", gotNodeID) + } + + // Registration should be deleted + exists, err := regStore.Exists() + if err != nil { + t.Fatalf("Exists error: %v", err) + } + if exists { + t.Error("expected registration to be deleted after deregister") + } +} + +func Test_runDeregister_UserCancels(t *testing.T) { + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }, + } + + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testDeregisterDeps(t, svc, regStore) + defer server.Close() + + callCount := 0 + deps.promptSelect = func(_ string, _ []string) string { + callCount++ + if callCount == 2 { + // Second prompt is the confirmation — cancel it + return "No, cancel" + } + return "No, keep NetBird installed" + } + + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err != nil { + t.Fatalf("expected nil error on cancel, got: %v", err) + } + + // Registration should still exist + exists, err := regStore.Exists() + if err != nil { + t.Fatalf("Exists error: %v", err) + } + if !exists { + t.Error("registration should still exist after cancel") + } +} + +func Test_runDeregister_NotRegistered(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testDeregisterDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err == nil { + t.Fatal("expected error when not registered") + } +} + +func Test_runDeregister_RemoveNodeFails(t *testing.T) { + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }, + } + + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{ + removeNodeFn: func(_ *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + return nil, connect.NewError(connect.CodeInternal, nil) + }, + } + + deps, server := testDeregisterDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err == nil { + t.Fatal("expected error when RemoveNode fails") + } + + // Registration should still exist (server-side removal failed) + exists, err := regStore.Exists() + if err != nil { + t.Fatalf("Exists error: %v", err) + } + if !exists { + t.Error("registration should still exist when RemoveNode fails") + } +} + +func Test_runDeregister_SkipsNetbirdUninstall(t *testing.T) { + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }, + } + + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{ + removeNodeFn: func(_ *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + return &nodev1.RemoveNodeResponse{}, nil + }, + } + + uninstallCalled := false + deps, server := testDeregisterDeps(t, svc, regStore) + defer server.Close() + + deps.promptSelect = func(label string, items []string) string { + if label == "Would you also like to uninstall NetBird?" { + return "No, keep NetBird installed" + } + return "Yes, proceed" + } + deps.uninstallNetbird = func() error { + uninstallCalled = true + return nil + } + + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err != nil { + t.Fatalf("runDeregister failed: %v", err) + } + + if uninstallCalled { + t.Error("NetBird uninstall should not be called when user declines") + } +} diff --git a/pkg/cmd/enablessh/enablessh.go b/pkg/cmd/enablessh/enablessh.go new file mode 100644 index 00000000..2df12055 --- /dev/null +++ b/pkg/cmd/enablessh/enablessh.go @@ -0,0 +1,193 @@ +// Package enablessh provides the brev enableSSH command for enabling SSH access +// to a registered external node. +package enablessh + +import ( + "context" + "fmt" + "os" + "os/exec" + "os/user" + "path/filepath" + "runtime" + "strings" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + + "github.com/brevdev/brev-cli/pkg/cmd/register" + "github.com/brevdev/brev-cli/pkg/config" + "github.com/brevdev/brev-cli/pkg/entity" + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/terminal" + + "github.com/spf13/cobra" +) + +// EnableSSHStore defines the store methods needed by the enableSSH command. +type EnableSSHStore interface { + GetCurrentUser() (*entity.User, error) + GetBrevHomePath() (string, error) + GetAccessToken() (string, error) +} + +// enableSSHDeps bundles the side-effecting dependencies of runEnableSSH so they +// can be replaced in tests. +type enableSSHDeps struct { + goos string + newNodeClient func(provider register.TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient + registrationStore register.RegistrationStore +} + +func prodEnableSSHDeps(brevHome string) enableSSHDeps { + return enableSSHDeps{ + goos: runtime.GOOS, + newNodeClient: register.NewNodeServiceClient, + registrationStore: register.NewFileRegistrationStore(brevHome), + } +} + +func NewCmdEnableSSH(t *terminal.Terminal, store EnableSSHStore) *cobra.Command { + cmd := &cobra.Command{ + Annotations: map[string]string{"configuration": ""}, + Use: "enable-ssh", + DisableFlagsInUseLine: true, + Short: "Enable SSH access to this registered device", + Long: "Enable SSH access to this registered device for the current Brev user.", + Example: " brev enable-ssh", + RunE: func(cmd *cobra.Command, args []string) error { + brevHome, err := store.GetBrevHomePath() + if err != nil { + return breverrors.WrapAndTrace(err) + } + return runEnableSSH(cmd.Context(), t, store, prodEnableSSHDeps(brevHome)) + }, + } + + return cmd +} + +func runEnableSSH(ctx context.Context, t *terminal.Terminal, s EnableSSHStore, deps enableSSHDeps) error { + if deps.goos != "linux" { + return fmt.Errorf("brev enable-ssh is only supported on Linux") + } + + registered, err := deps.registrationStore.Exists() + if err != nil { + return breverrors.WrapAndTrace(err) + } + if !registered { + return fmt.Errorf("no registration found; this machine does not appear to be registered\nRun 'brev register' to register your device first") + } + + reg, err := deps.registrationStore.Load() + if err != nil { + return fmt.Errorf("failed to read registration file: %w", err) + } + + brevUser, err := s.GetCurrentUser() + if err != nil { + return breverrors.WrapAndTrace(err) + } + + return EnableSSH(ctx, t, deps.newNodeClient, s, reg, brevUser) +} + +// EnableSSH grants SSH access to the given node for the specified Brev user. +// It is exported so that the register command can reuse it after registration. +func EnableSSH( + ctx context.Context, + t *terminal.Terminal, + newClient func(register.TokenProvider, string) nodev1connect.ExternalNodeServiceClient, + tokenProvider register.TokenProvider, + reg *register.DeviceRegistration, + brevUser *entity.User, +) error { + u, err := user.Current() + if err != nil { + return fmt.Errorf("failed to determine current Linux user: %w", err) + } + linuxUser := u.Username + + checkSSHDaemon(t) + + t.Vprint("") + t.Vprint(t.Green("Enabling SSH access on this device")) + t.Vprint("") + t.Vprintf(" Node: %s (%s)\n", reg.DisplayName, reg.ExternalNodeID) + t.Vprintf(" Brev user: %s\n", brevUser.ID) + t.Vprintf(" Linux user: %s\n", linuxUser) + t.Vprint("") + + client := newClient(tokenProvider, config.GlobalConfig.GetBrevPublicAPIURL()) + if _, err := client.GrantNodeSSHAccess(ctx, connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{ + ExternalNodeId: reg.ExternalNodeID, + UserId: brevUser.ID, + LinuxUser: linuxUser, + OrganizationId: reg.OrgID, + })); err != nil { + return fmt.Errorf("failed to enable SSH access: %w", err) + } + + if brevUser.PublicKey != "" { + if err := installAuthorizedKey(u, brevUser.PublicKey); err != nil { + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) + } else { + t.Vprint(" Brev public key added to authorized_keys.") + } + } + + t.Vprint(t.Green(fmt.Sprintf("SSH access enabled. You can now SSH to this device via: brev shell %s", reg.DisplayName))) + return nil +} + +// installAuthorizedKey appends the given public key to the user's +// ~/.ssh/authorized_keys if it isn't already present. +func installAuthorizedKey(u *user.User, pubKey string) error { + pubKey = strings.TrimSpace(pubKey) + if pubKey == "" { + return nil + } + + sshDir := filepath.Join(u.HomeDir, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + return fmt.Errorf("creating .ssh directory: %w", err) + } + + authKeysPath := filepath.Join(sshDir, "authorized_keys") + + existing, err := os.ReadFile(authKeysPath) // #nosec G304 + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("reading authorized_keys: %w", err) + } + + if strings.Contains(string(existing), pubKey) { + return nil // already present + } + + // Ensure existing content ends with a newline before appending. + content := string(existing) + if len(content) > 0 && !strings.HasSuffix(content, "\n") { + content += "\n" + } + content += pubKey + "\n" + + if err := os.WriteFile(authKeysPath, []byte(content), 0o600); err != nil { + return fmt.Errorf("writing authorized_keys: %w", err) + } + + return nil +} + +// checkSSHDaemon prints a warning if neither "ssh" nor "sshd" systemd services +// appear to be active. It never returns an error — it is best-effort. +func checkSSHDaemon(t *terminal.Terminal) { + for _, svc := range []string{"ssh", "sshd"} { + out, err := exec.Command("systemctl", "is-active", svc).Output() //nolint:gosec // fixed service names + if err == nil && len(out) > 0 && string(out[:len(out)-1]) == "active" { + return + } + } + t.Vprintf(" %s\n", t.Yellow("Warning: SSH daemon does not appear to be running. SSH access may not work until sshd is started.")) +} diff --git a/pkg/cmd/ls/ls.go b/pkg/cmd/ls/ls.go index da95aa89..47c4e743 100644 --- a/pkg/cmd/ls/ls.go +++ b/pkg/cmd/ls/ls.go @@ -2,14 +2,19 @@ package ls import ( + "context" "encoding/json" "fmt" "os" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + "github.com/brevdev/brev-cli/pkg/analytics" "github.com/brevdev/brev-cli/pkg/cmd/cmderrors" "github.com/brevdev/brev-cli/pkg/cmd/completions" "github.com/brevdev/brev-cli/pkg/cmd/hello" + "github.com/brevdev/brev-cli/pkg/cmd/register" cmdutil "github.com/brevdev/brev-cli/pkg/cmd/util" "github.com/brevdev/brev-cli/pkg/cmdcontext" "github.com/brevdev/brev-cli/pkg/config" @@ -32,6 +37,7 @@ type LsStore interface { GetUsers(queryParams map[string]string) ([]entity.User, error) GetWorkspace(workspaceID string) (*entity.Workspace, error) GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error) + GetAccessToken() (string, error) hello.HelloStore } @@ -99,7 +105,7 @@ with other commands like stop, start, or delete.`, return nil }, Args: cmderrors.TransformToValidationError(cobra.MinimumNArgs(0)), - ValidArgs: []string{"orgs", "workspaces"}, + ValidArgs: []string{"orgs", "workspaces", "nodes"}, RunE: func(cmd *cobra.Command, args []string) error { err := RunLs(t, loginLsStore, args, org, showAll, jsonOutput) if err != nil { @@ -226,6 +232,12 @@ func handleLsArg(ls *Ls, arg string, user *entity.User, org *entity.Organization return breverrors.WrapAndTrace(err) } return nil + } else if util.IsSingularOrPlural(arg, "node") { + err := ls.RunNodes(org) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil } return nil } @@ -234,13 +246,19 @@ type Ls struct { lsStore LsStore terminal *terminal.Terminal jsonOutput bool + piped bool } func NewLs(lsStore LsStore, terminal *terminal.Terminal, jsonOutput bool) *Ls { + piped := false + if fi, err := os.Stdout.Stat(); err == nil { + piped = fi.Mode()&os.ModeCharDevice == 0 + } return &Ls{ lsStore: lsStore, terminal: terminal, jsonOutput: jsonOutput, + piped: piped, } } @@ -422,6 +440,10 @@ func (ls Ls) RunWorkspaces(org *entity.Organization, user *entity.User, showAll } else { ls.ShowUserWorkspaces(org, orgs, user, allWorkspaces) } + + // Also show external nodes in the default listing + ls.showNodesSection(org) + return nil } @@ -624,3 +646,123 @@ func getStatusColoredText(t *terminal.Terminal, status string) string { return status } } + +// NodeInfo represents external node data for JSON output. +type NodeInfo struct { + Name string `json:"name"` + ExternalNodeID string `json:"external_node_id"` + DeviceID string `json:"device_id"` + OrgID string `json:"org_id"` + Status string `json:"status"` +} + +func (ls Ls) listNodes(org *entity.Organization) ([]*nodev1.ExternalNode, error) { + client := register.NewNodeServiceClient(ls.lsStore, config.GlobalConfig.GetBrevPublicAPIURL()) + resp, err := client.ListNodes(context.Background(), connect.NewRequest(&nodev1.ListNodesRequest{ + OrganizationId: org.ID, + })) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return resp.Msg.GetItems(), nil +} + +// RunNodes lists external nodes for the given org. +func (ls Ls) RunNodes(org *entity.Organization) error { + nodes, err := ls.listNodes(org) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + if len(nodes) == 0 { + if ls.jsonOutput { + fmt.Println("[]") + return nil + } + if ls.piped { + return nil + } + ls.terminal.Vprint(ls.terminal.Yellow("No external nodes in this org.")) + return nil + } + + if ls.jsonOutput { + return ls.outputNodesJSON(nodes) + } + if ls.piped { + displayNodesTablePlain(nodes) + return nil + } + + ls.terminal.Vprintf("\nYou have %d external node(s) in Org %s\n", len(nodes), ls.terminal.Yellow(org.Name)) + displayNodesTable(ls.terminal, nodes) + return nil +} + +// showNodesSection appends external nodes to the default `brev ls` output. +// Errors are silently ignored so that a ListNodes failure doesn't break the +// workspace listing. +func (ls Ls) showNodesSection(org *entity.Organization) { + nodes, err := ls.listNodes(org) + if err != nil || len(nodes) == 0 { + return + } + + if ls.jsonOutput || ls.piped { + // JSON and piped modes are already handled per-section; skip here to + // avoid duplicating output when the user runs `brev ls nodes` explicitly. + return + } + + ls.terminal.Vprintf("\nExternal Nodes (%d):\n", len(nodes)) + displayNodesTable(ls.terminal, nodes) +} + +func (ls Ls) outputNodesJSON(nodes []*nodev1.ExternalNode) error { + var infos []NodeInfo + for _, n := range nodes { + infos = append(infos, NodeInfo{ + Name: n.GetName(), + ExternalNodeID: n.GetExternalNodeId(), + DeviceID: n.GetDeviceId(), + OrgID: n.GetOrganizationId(), + Status: nodeConnectionStatus(n), + }) + } + output, err := json.MarshalIndent(infos, "", " ") + if err != nil { + return breverrors.WrapAndTrace(err) + } + fmt.Println(string(output)) + return nil +} + +func displayNodesTable(t *terminal.Terminal, nodes []*nodev1.ExternalNode) { + ta := table.NewWriter() + ta.SetOutputMirror(os.Stdout) + ta.Style().Options = getBrevTableOptions() + ta.AppendHeader(table.Row{"NAME", "NODE ID", "DEVICE ID", "STATUS"}) + for _, n := range nodes { + status := nodeConnectionStatus(n) + ta.AppendRows([]table.Row{{n.GetName(), n.GetExternalNodeId(), n.GetDeviceId(), getStatusColoredText(t, status)}}) + } + ta.Render() +} + +func displayNodesTablePlain(nodes []*nodev1.ExternalNode) { + ta := table.NewWriter() + ta.SetOutputMirror(os.Stdout) + ta.Style().Options = getBrevTableOptions() + ta.AppendHeader(table.Row{"NAME", "NODE ID", "DEVICE ID", "STATUS"}) + for _, n := range nodes { + ta.AppendRows([]table.Row{{n.GetName(), n.GetExternalNodeId(), n.GetDeviceId(), nodeConnectionStatus(n)}}) + } + ta.Render() +} + +func nodeConnectionStatus(n *nodev1.ExternalNode) string { + if ci := n.GetConnectivityInfo(); ci != nil && ci.GetRegistrationCommand() != "" { + return "REGISTERED" + } + return "UNKNOWN" +} diff --git a/pkg/cmd/register/hardware.go b/pkg/cmd/register/hardware.go new file mode 100644 index 00000000..866231bf --- /dev/null +++ b/pkg/cmd/register/hardware.go @@ -0,0 +1,306 @@ +package register + +import ( + "bufio" + "fmt" + "os/exec" + "runtime" + "strconv" + "strings" + + breverrors "github.com/brevdev/brev-cli/pkg/errors" +) + +// CommandRunner abstracts command execution for testability. +type CommandRunner interface { + Run(name string, args ...string) ([]byte, error) +} + +// ExecCommandRunner is the real implementation that runs OS commands. +type ExecCommandRunner struct{} + +func (r ExecCommandRunner) Run(name string, args ...string) ([]byte, error) { + out, err := exec.Command(name, args...).Output() // #nosec G204 + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return out, nil +} + +// NodeSpec matches the proto NodeSpec message from dev-plane. +// All fields are best-effort. +type NodeSpec struct { + GPUs []NodeGPU `json:"gpus"` + RAMBytes *int64 `json:"ram_bytes,omitempty"` + CPUCount *int32 `json:"cpu_count,omitempty"` + Architecture string `json:"architecture,omitempty"` + Storage []NodeStorage `json:"storage,omitempty"` + OS string `json:"os,omitempty"` + OSVersion string `json:"os_version,omitempty"` +} + +// NodeStorage represents a single storage device with its size and type. +type NodeStorage struct { + StorageBytes int64 `json:"storage_bytes"` + StorageType string `json:"storage_type,omitempty"` // "SSD" or "HDD" +} + +// NodeGPU matches the proto NodeGPU message. +type NodeGPU struct { + Model string `json:"model"` + Count int32 `json:"count"` + MemoryBytes *int64 `json:"memory_bytes,omitempty"` +} + +// FileReader abstracts file reading for testability. +type FileReader interface { + ReadFile(path string) ([]byte, error) +} + +// CollectHardwareProfile gathers system hardware information. +// All fields are best-effort; failures are silently ignored. +func CollectHardwareProfile(runner CommandRunner, reader FileReader) (*NodeSpec, error) { + spec := &NodeSpec{ + Architecture: runtime.GOARCH, + } + + if gpus, err := parseNvidiaSMI(runner); err == nil { + spec.GPUs = gpus + } + + if cpuCount, err := parseCPUCount(reader); err == nil { + count32 := int32(cpuCount) + spec.CPUCount = &count32 + } + + if ramBytes, err := parseMemInfo(reader); err == nil { + spec.RAMBytes = &ramBytes + } + + osName, osVersion := parseOSRelease(reader) + spec.OS = osName + spec.OSVersion = osVersion + + spec.Storage = collectStorage(runner) + + return spec, nil +} + +// parseCPUCount reads /proc/cpuinfo and returns the number of logical processors. +func parseCPUCount(reader FileReader) (int, error) { + data, err := reader.ReadFile("/proc/cpuinfo") + if err != nil { + return 0, breverrors.WrapAndTrace(err) + } + return parseCPUCountContent(string(data)) +} + +// parseCPUCountContent parses the content of /proc/cpuinfo for processor count. +func parseCPUCountContent(content string) (int, error) { + count := 0 + scanner := bufio.NewScanner(strings.NewReader(content)) + for scanner.Scan() { + if strings.HasPrefix(scanner.Text(), "processor") { + count++ + } + } + if count == 0 { + return 0, fmt.Errorf("no processors found in /proc/cpuinfo") + } + return count, nil +} + +// parseMemInfo reads /proc/meminfo and returns total RAM in bytes. +func parseMemInfo(reader FileReader) (int64, error) { + data, err := reader.ReadFile("/proc/meminfo") + if err != nil { + return 0, breverrors.WrapAndTrace(err) + } + return parseMemInfoContent(string(data)) +} + +// parseMemInfoContent parses the content of /proc/meminfo. +func parseMemInfoContent(content string) (int64, error) { + scanner := bufio.NewScanner(strings.NewReader(content)) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "MemTotal:") { + fields := strings.Fields(line) + if len(fields) < 2 { + return 0, fmt.Errorf("unexpected MemTotal format: %s", line) + } + kb, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, fmt.Errorf("failed to parse MemTotal value: %w", err) + } + return kb * 1024, nil // convert kB to bytes + } + } + return 0, fmt.Errorf("MemTotal not found in /proc/meminfo") +} + +// parseOSRelease reads /etc/os-release and returns (name, version). +func parseOSRelease(reader FileReader) (string, string) { + data, err := reader.ReadFile("/etc/os-release") + if err != nil { + return "", "" + } + return parseOSReleaseContent(string(data)) +} + +// parseOSReleaseContent parses the content of /etc/os-release. +func parseOSReleaseContent(content string) (string, string) { + name := "" + version := "" + scanner := bufio.NewScanner(strings.NewReader(content)) + for scanner.Scan() { + line := scanner.Text() + if val, ok := strings.CutPrefix(line, "NAME="); ok { + name = unquote(val) + } + if val, ok := strings.CutPrefix(line, "VERSION_ID="); ok { + version = unquote(val) + } + } + return name, version +} + +// unquote removes surrounding double quotes from a string. +func unquote(s string) string { + s = strings.TrimSpace(s) + if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' { + return s[1 : len(s)-1] + } + return s +} + +// parseNvidiaSMI queries nvidia-smi for GPU information. +// Returns an error if nvidia-smi fails or no GPUs are found. +func parseNvidiaSMI(runner CommandRunner) ([]NodeGPU, error) { + out, err := runner.Run("nvidia-smi", + "--query-gpu=name,memory.total", + "--format=csv,noheader,nounits", + ) + if err != nil { + return nil, fmt.Errorf("nvidia-smi not available: %w", err) + } + gpus := parseNvidiaSMIOutput(string(out)) + if len(gpus) == 0 { + return nil, fmt.Errorf("nvidia-smi returned no GPUs") + } + return gpus, nil +} + +// parseNvidiaSMIOutput parses nvidia-smi CSV output, grouping identical GPU +// models into a single NodeGPU with a count. +func parseNvidiaSMIOutput(output string) []NodeGPU { + type gpuKey struct { + model string + memoryBytes int64 + } + + counts := make(map[gpuKey]int32) + var order []gpuKey + + scanner := bufio.NewScanner(strings.NewReader(output)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + parts := strings.Split(line, ", ") + if len(parts) < 2 { + continue + } + model := strings.TrimSpace(parts[0]) + memMB, err := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64) + if err != nil { + continue + } + key := gpuKey{model: model, memoryBytes: memMB * 1024 * 1024} + if counts[key] == 0 { + order = append(order, key) + } + counts[key]++ + } + + gpus := make([]NodeGPU, 0, len(order)) + for _, key := range order { + mem := key.memoryBytes + gpus = append(gpus, NodeGPU{ + Model: key.model, + Count: counts[key], + MemoryBytes: &mem, + }) + } + return gpus +} + +// collectStorage returns per-device storage entries from lsblk, +// using the ROTA column to determine device type. +func collectStorage(runner CommandRunner) []NodeStorage { + out, err := runner.Run("lsblk", "-b", "-d", "-n", "-o", "NAME,SIZE,TYPE,ROTA") + if err != nil { + return nil + } + return parseStorageOutput(string(out)) +} + +// parseStorageOutput parses lsblk output (NAME,SIZE,TYPE,ROTA columns), +// returning one NodeStorage entry per disk device. ROTA=0 → SSD, ROTA=1 → HDD. +func parseStorageOutput(output string) []NodeStorage { + var devices []NodeStorage + scanner := bufio.NewScanner(strings.NewReader(output)) + for scanner.Scan() { + fields := strings.Fields(scanner.Text()) + if len(fields) < 4 || fields[2] != "disk" { + continue + } + size, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + continue + } + entry := NodeStorage{StorageBytes: size} + rota, err := strconv.Atoi(fields[3]) + if err == nil { + if rota == 0 { + entry.StorageType = "SSD" + } else { + entry.StorageType = "HDD" + } + } + devices = append(devices, entry) + } + return devices +} + +// FormatNodeSpec returns a human-readable summary of the hardware profile. +func FormatNodeSpec(s *NodeSpec) string { + var b strings.Builder + if s.CPUCount != nil { + _, _ = fmt.Fprintf(&b, " CPU: %d cores\n", *s.CPUCount) + } + if s.RAMBytes != nil { + _, _ = fmt.Fprintf(&b, " RAM: %.1f GB\n", float64(*s.RAMBytes)/(1024*1024*1024)) + } + for _, gpu := range s.GPUs { + if gpu.MemoryBytes != nil { + memGB := float64(*gpu.MemoryBytes) / (1024 * 1024 * 1024) + _, _ = fmt.Fprintf(&b, " GPUs: %d x %s (%.1f GB)\n", gpu.Count, gpu.Model, memGB) + } else { + _, _ = fmt.Fprintf(&b, " GPUs: %d x %s\n", gpu.Count, gpu.Model) + } + } + _, _ = fmt.Fprintf(&b, " Arch: %s\n", s.Architecture) + if s.OS != "" || s.OSVersion != "" { + _, _ = fmt.Fprintf(&b, " OS: %s %s\n", s.OS, s.OSVersion) + } + for _, st := range s.Storage { + _, _ = fmt.Fprintf(&b, " Storage: %.1f GB", float64(st.StorageBytes)/(1024*1024*1024)) + if st.StorageType != "" { + _, _ = fmt.Fprintf(&b, " (%s)", st.StorageType) + } + b.WriteString("\n") + } + return b.String() +} diff --git a/pkg/cmd/register/hardware_test.go b/pkg/cmd/register/hardware_test.go new file mode 100644 index 00000000..929c79d3 --- /dev/null +++ b/pkg/cmd/register/hardware_test.go @@ -0,0 +1,425 @@ +package register + +import ( + "strings" + "testing" +) + +func Test_parseCPUCountContent_ValidInput(t *testing.T) { + content := `processor : 0 +vendor_id : AuthenticAMD +model name : AMD EPYC 7763 64-Core Processor +cpu MHz : 2450.000 + +processor : 1 +vendor_id : AuthenticAMD +model name : AMD EPYC 7763 64-Core Processor +cpu MHz : 2450.000 + +processor : 2 +vendor_id : AuthenticAMD +model name : AMD EPYC 7763 64-Core Processor +cpu MHz : 2450.000 +` + count, err := parseCPUCountContent(content) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if count != 3 { + t.Errorf("expected 3 CPUs, got %d", count) + } +} + +func Test_parseCPUCountContent_EmptyInput(t *testing.T) { + _, err := parseCPUCountContent("") + if err == nil { + t.Fatal("expected error for empty input") + } +} + +func Test_parseMemInfoContent_ValidInput(t *testing.T) { + content := `MemTotal: 131886028 kB +MemFree: 1234567 kB +MemAvailable: 98765432 kB +` + bytes, err := parseMemInfoContent(content) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + expected := int64(131886028) * 1024 + if bytes != expected { + t.Errorf("expected %d bytes, got %d", expected, bytes) + } +} + +func Test_parseMemInfoContent_MissingMemTotal(t *testing.T) { + content := `MemFree: 1234567 kB +MemAvailable: 98765432 kB +` + _, err := parseMemInfoContent(content) + if err == nil { + t.Fatal("expected error for missing MemTotal") + } +} + +func Test_parseOSReleaseContent(t *testing.T) { + content := `NAME="Ubuntu" +VERSION="24.04 LTS (Noble Numbat)" +ID=ubuntu +VERSION_ID="24.04" +PRETTY_NAME="Ubuntu 24.04 LTS" +` + name, version := parseOSReleaseContent(content) + if name != "Ubuntu" { + t.Errorf("expected Ubuntu, got %s", name) + } + if version != "24.04" { + t.Errorf("expected 24.04, got %s", version) + } +} + +func Test_parseOSReleaseContent_Unquoted(t *testing.T) { + content := `NAME=Fedora +VERSION_ID=39 +` + name, version := parseOSReleaseContent(content) + if name != "Fedora" { + t.Errorf("expected Fedora, got %s", name) + } + if version != "39" { + t.Errorf("expected 39, got %s", version) + } +} + +func Test_parseNvidiaSMIOutput_GroupsByModel(t *testing.T) { + output := `NVIDIA GB10, 131072 +NVIDIA GB10, 131072 +` + gpus := parseNvidiaSMIOutput(output) + if len(gpus) != 1 { + t.Fatalf("expected 1 GPU group, got %d", len(gpus)) + } + if gpus[0].Model != "NVIDIA GB10" { + t.Errorf("unexpected GPU model: %s", gpus[0].Model) + } + if gpus[0].Count != 2 { + t.Errorf("expected count 2, got %d", gpus[0].Count) + } + expectedMem := int64(131072) * 1024 * 1024 + if gpus[0].MemoryBytes == nil || *gpus[0].MemoryBytes != expectedMem { + t.Errorf("expected %d bytes, got %v", expectedMem, gpus[0].MemoryBytes) + } +} + +func Test_parseNvidiaSMIOutput_MultipleModels(t *testing.T) { + output := `NVIDIA A100, 81920 +NVIDIA GB10, 131072 +NVIDIA A100, 81920 +` + gpus := parseNvidiaSMIOutput(output) + if len(gpus) != 2 { + t.Fatalf("expected 2 GPU groups, got %d", len(gpus)) + } + if gpus[0].Model != "NVIDIA A100" || gpus[0].Count != 2 { + t.Errorf("expected 2x NVIDIA A100, got %dx %s", gpus[0].Count, gpus[0].Model) + } + if gpus[1].Model != "NVIDIA GB10" || gpus[1].Count != 1 { + t.Errorf("expected 1x NVIDIA GB10, got %dx %s", gpus[1].Count, gpus[1].Model) + } +} + +func Test_parseNvidiaSMIOutput_Empty(t *testing.T) { + gpus := parseNvidiaSMIOutput("") + if len(gpus) != 0 { + t.Errorf("expected 0 GPUs, got %d", len(gpus)) + } +} + +func Test_parseStorageOutput(t *testing.T) { + output := `nvme0n1 500107862016 disk 0 +nvme1n1 1000204886016 disk 0 +sda 2048 rom 1 +` + devices := parseStorageOutput(output) + if len(devices) != 2 { + t.Fatalf("expected 2 devices, got %d", len(devices)) + } + if devices[0].StorageBytes != 500107862016 { + t.Errorf("expected 500107862016, got %d", devices[0].StorageBytes) + } + if devices[0].StorageType != "SSD" { + t.Errorf("expected SSD, got %s", devices[0].StorageType) + } + if devices[1].StorageBytes != 1000204886016 { + t.Errorf("expected 1000204886016, got %d", devices[1].StorageBytes) + } + if devices[1].StorageType != "SSD" { + t.Errorf("expected SSD, got %s", devices[1].StorageType) + } +} + +func Test_parseStorageOutput_SDA(t *testing.T) { + output := `sda 500107862016 disk 1 +` + devices := parseStorageOutput(output) + if len(devices) != 1 { + t.Fatalf("expected 1 device, got %d", len(devices)) + } + if devices[0].StorageBytes != 500107862016 { + t.Errorf("expected 500107862016 bytes, got %d", devices[0].StorageBytes) + } + if devices[0].StorageType != "HDD" { + t.Errorf("expected HDD, got %s", devices[0].StorageType) + } +} + +func Test_unquote(t *testing.T) { + tests := []struct { + input string + want string + }{ + {`"Ubuntu"`, "Ubuntu"}, + {`Ubuntu`, "Ubuntu"}, + {`""`, ""}, + {`"a"`, "a"}, + {``, ""}, + } + for _, tt := range tests { + got := unquote(tt.input) + if got != tt.want { + t.Errorf("unquote(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func Test_FormatNodeSpec(t *testing.T) { + cpuCount := int32(12) + ramBytes := int64(137438953472) // 128 GB + memBytes := int64(137438953472) // 128 GB + s := &NodeSpec{ + CPUCount: &cpuCount, + RAMBytes: &ramBytes, + Architecture: "arm64", + OS: "Ubuntu", + OSVersion: "24.04", + GPUs: []NodeGPU{ + {Model: "NVIDIA GB10", Count: 1, MemoryBytes: &memBytes}, + }, + } + output := FormatNodeSpec(s) + if output == "" { + t.Fatal("expected non-empty output") + } + if !strings.Contains(output, "12 cores") { + t.Errorf("expected CPU info in output: %s", output) + } + if !strings.Contains(output, "128.0 GB") { + t.Errorf("expected RAM info in output: %s", output) + } + if !strings.Contains(output, "NVIDIA GB10") { + t.Errorf("expected GPU info in output: %s", output) + } +} + +func Test_FormatNodeSpec_MinimalFields(t *testing.T) { + s := &NodeSpec{ + GPUs: []NodeGPU{ + {Model: "NVIDIA GB10", Count: 1}, + }, + Architecture: "arm64", + } + output := FormatNodeSpec(s) + if strings.Contains(output, "CPU:") { + t.Errorf("should not contain CPU when nil: %s", output) + } + if strings.Contains(output, "RAM:") { + t.Errorf("should not contain RAM when nil: %s", output) + } + if !strings.Contains(output, "NVIDIA GB10") { + t.Errorf("expected GPU info: %s", output) + } + if !strings.Contains(output, "arm64") { + t.Errorf("expected arch info: %s", output) + } +} + +func Test_FormatNodeSpec_WithStorage(t *testing.T) { + s := &NodeSpec{ + Architecture: "amd64", + Storage: []NodeStorage{ + {StorageBytes: 500107862016, StorageType: "SSD"}, + {StorageBytes: 1000204886016, StorageType: "HDD"}, + }, + } + output := FormatNodeSpec(s) + if !strings.Contains(output, "Storage:") { + t.Errorf("expected storage in output: %s", output) + } + if !strings.Contains(output, "SSD") { + t.Errorf("expected SSD in output: %s", output) + } + if !strings.Contains(output, "HDD") { + t.Errorf("expected HDD in output: %s", output) + } +} + +func Test_parseNvidiaSMIOutput_MalformedLines(t *testing.T) { + output := ` +malformed line +NVIDIA GB10, 131072 +, , +just-a-name +NVIDIA A100, not-a-number +` + gpus := parseNvidiaSMIOutput(output) + if len(gpus) != 1 { + t.Fatalf("expected 1 valid GPU, got %d", len(gpus)) + } + if gpus[0].Model != "NVIDIA GB10" { + t.Errorf("unexpected model: %s", gpus[0].Model) + } +} + +func Test_parseStorageOutput_Empty(t *testing.T) { + devices := parseStorageOutput("") + if len(devices) != 0 { + t.Errorf("expected 0 devices, got %d", len(devices)) + } +} + +func Test_parseStorageOutput_NoDiskDevices(t *testing.T) { + output := `sr0 1073741312 rom 1 +loop0 123456 loop 0 +` + devices := parseStorageOutput(output) + if len(devices) != 0 { + t.Errorf("expected 0 devices for non-disk entries, got %d", len(devices)) + } +} + +// mockCommandRunner for testing CollectHardwareProfile +type mockCommandRunner struct { + outputs map[string][]byte + errors map[string]error +} + +func (m *mockCommandRunner) Run(name string, args ...string) ([]byte, error) { + key := name + if err, ok := m.errors[key]; ok { + return nil, err + } + if out, ok := m.outputs[key]; ok { + return out, nil + } + return nil, nil +} + +type mockFileReader struct { + files map[string][]byte +} + +func (m *mockFileReader) ReadFile(path string) ([]byte, error) { + if data, ok := m.files[path]; ok { + return data, nil + } + return nil, &mockFileNotFoundError{path: path} +} + +type mockFileNotFoundError struct{ path string } + +func (e *mockFileNotFoundError) Error() string { return "file not found: " + e.path } + +func Test_CollectHardwareProfile_WithMocks(t *testing.T) { + runner := &mockCommandRunner{ + outputs: map[string][]byte{ + "nvidia-smi": []byte("NVIDIA GB10, 131072\nNVIDIA GB10, 131072\n"), + "lsblk": []byte("nvme0n1 500107862016 disk 0\n"), + }, + } + reader := &mockFileReader{ + files: map[string][]byte{ + "/proc/cpuinfo": []byte("processor\t: 0\nprocessor\t: 1\n"), + "/proc/meminfo": []byte("MemTotal: 131886028 kB\n"), + "/etc/os-release": []byte("NAME=\"Ubuntu\"\nVERSION_ID=\"24.04\"\n"), + }, + } + + spec, err := CollectHardwareProfile(runner, reader) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(spec.GPUs) != 1 || spec.GPUs[0].Count != 2 { + t.Errorf("expected 1 GPU group with count 2, got %v", spec.GPUs) + } + if spec.CPUCount == nil || *spec.CPUCount != 2 { + t.Errorf("expected 2 CPUs, got %v", spec.CPUCount) + } + if spec.RAMBytes == nil || *spec.RAMBytes != 131886028*1024 { + t.Errorf("unexpected RAM: %v", spec.RAMBytes) + } + if spec.OS != "Ubuntu" || spec.OSVersion != "24.04" { + t.Errorf("unexpected OS: %s %s", spec.OS, spec.OSVersion) + } + if len(spec.Storage) != 1 || spec.Storage[0].StorageBytes != 500107862016 { + t.Errorf("unexpected storage: %v", spec.Storage) + } + if spec.Storage[0].StorageType != "SSD" { + t.Errorf("expected SSD, got %s", spec.Storage[0].StorageType) + } +} + +func Test_CollectHardwareProfile_GPUBestEffort(t *testing.T) { + runner := &mockCommandRunner{ + errors: map[string]error{ + "nvidia-smi": &mockFileNotFoundError{path: "nvidia-smi"}, + }, + } + reader := &mockFileReader{ + files: map[string][]byte{ + "/proc/cpuinfo": []byte("processor\t: 0\n"), + "/proc/meminfo": []byte("MemTotal: 131886028 kB\n"), + }, + } + + spec, err := CollectHardwareProfile(runner, reader) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(spec.GPUs) != 0 { + t.Errorf("expected 0 GPUs when nvidia-smi fails, got %d", len(spec.GPUs)) + } + if spec.CPUCount == nil || *spec.CPUCount != 1 { + t.Errorf("expected 1 CPU, got %v", spec.CPUCount) + } +} + +func Test_CollectHardwareProfile_OptionalFieldsMissing(t *testing.T) { + runner := &mockCommandRunner{ + outputs: map[string][]byte{ + "nvidia-smi": []byte("NVIDIA GB10, 131072\n"), + }, + errors: map[string]error{ + "lsblk": &mockFileNotFoundError{path: "lsblk"}, + }, + } + reader := &mockFileReader{ + files: map[string][]byte{}, + } + + spec, err := CollectHardwareProfile(runner, reader) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if spec.CPUCount != nil { + t.Errorf("expected nil CPUCount when /proc/cpuinfo missing") + } + if spec.RAMBytes != nil { + t.Errorf("expected nil RAMBytes when /proc/meminfo missing") + } + if len(spec.Storage) != 0 { + t.Errorf("expected empty Storage when lsblk fails, got %v", spec.Storage) + } + if len(spec.GPUs) != 1 { + t.Errorf("expected 1 GPU, got %d", len(spec.GPUs)) + } +} diff --git a/pkg/cmd/register/netbird.go b/pkg/cmd/register/netbird.go new file mode 100644 index 00000000..ad7652d3 --- /dev/null +++ b/pkg/cmd/register/netbird.go @@ -0,0 +1,72 @@ +package register + +import ( + "fmt" + "os" + "os/exec" +) + +// InstallNetbird installs NetBird if it is not already present. +func InstallNetbird() error { + if _, err := exec.LookPath("netbird"); err == nil { + return nil + } + + script := `(curl -fsSL https://pkgs.netbird.io/install.sh | sh) || (curl -fsSL https://pkgs.netbird.io/install.sh | sh -s -- --update)` + + cmd := exec.Command("bash", "-c", script) // #nosec G204 + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to install NetBird: %w", err) + } + return nil +} + +// runSetupCommand executes the setup command returned by the AddNode RPC. +func runSetupCommand(script string) error { + cmd := exec.Command("bash", "-c", script) // #nosec G204 + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("setup command failed: %w", err) + } + return nil +} + +// UninstallNetbird stops, uninstalls the service, and removes the NetBird +// package or binary. It reads /etc/netbird/install.conf (written by the +// install script) to determine the original installation method. +// The down/stop steps are best-effort since the service may already be +// disconnected or stopped after deregistration. +func UninstallNetbird() error { + script := ` +sudo netbird down 2>/dev/null +sudo netbird service stop 2>/dev/null +sudo netbird service uninstall 2>/dev/null + +PKG_MGR="bin" +if [ -f /etc/netbird/install.conf ]; then + PKG_MGR=$(grep -oP '(?<=package_manager=)\S+' /etc/netbird/install.conf 2>/dev/null || echo "bin") +fi + +case "$PKG_MGR" in + apt) sudo apt-get remove -y netbird ;; + dnf) sudo dnf remove -y netbird ;; + yum) sudo yum remove -y netbird ;; + *) sudo rm -f /usr/bin/netbird /usr/local/bin/netbird ;; +esac + +sudo rm -rf /etc/netbird +` + + cmd := exec.Command("bash", "-c", script) // #nosec G204 + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to uninstall NetBird: %w", err) + } + return nil +} diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index 1c16586b..293e15f5 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -1,44 +1,297 @@ -// Package register provides the brev register command for DGX Spark registration +// Package register provides the brev register command for device registration package register import ( + "context" + "fmt" + "os" + "os/user" + "path/filepath" + "runtime" + "strings" + "time" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + "github.com/google/uuid" + + "github.com/brevdev/brev-cli/pkg/config" + "github.com/brevdev/brev-cli/pkg/entity" + breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/terminal" "github.com/spf13/cobra" ) +// RegisterStore defines the store methods needed by the register command. +type RegisterStore interface { + GetCurrentUser() (*entity.User, error) + GetActiveOrganizationOrDefault() (*entity.Organization, error) + GetBrevHomePath() (string, error) + GetAccessToken() (string, error) +} + +// OSFileReader reads files from the real OS filesystem. +type OSFileReader struct{} + +func (r OSFileReader) ReadFile(path string) ([]byte, error) { + data, err := os.ReadFile(path) // #nosec G304 + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return data, nil +} + +// registerDeps bundles the side-effecting dependencies of runRegister so they +// can be replaced in tests. +type registerDeps struct { + goos string + promptYesNo func(label string) bool + installNetbird func() error + runSetupCommand func(script string) error + newNodeClient func(provider TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient + commandRunner CommandRunner + fileReader FileReader + registrationStore RegistrationStore +} + +func prodRegisterDeps(brevHome string) registerDeps { + return registerDeps{ + goos: runtime.GOOS, + promptYesNo: func(label string) bool { + result := terminal.PromptSelectInput(terminal.PromptSelectContent{ + Label: label, + Items: []string{"Yes, proceed", "No, cancel"}, + }) + return result == "Yes, proceed" + }, + installNetbird: InstallNetbird, + runSetupCommand: runSetupCommand, + newNodeClient: NewNodeServiceClient, + commandRunner: ExecCommandRunner{}, + fileReader: OSFileReader{}, + registrationStore: NewFileRegistrationStore(brevHome), + } +} + var ( - registerLong = `Register your DGX Spark with NVIDIA Brev + registerLong = `Register your device with NVIDIA Brev -Join the waitlist to be among the first to register your DGX Spark -for early access integration with Brev.` +This command installs NetBird (network agent), and registers this machine with Brev.` - registerExample = ` brev register` + registerExample = ` brev register "My DGX Spark"` ) -func NewCmdRegister(t *terminal.Terminal) *cobra.Command { +func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command { cmd := &cobra.Command{ Annotations: map[string]string{"configuration": ""}, - Use: "register", - Aliases: []string{"spark"}, + Use: "register ", DisableFlagsInUseLine: true, - Short: "Register your DGX Spark with Brev", + Short: "Register this device with Brev", Long: registerLong, Example: registerExample, + Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { - runRegister(t) - return nil + brevHome, err := store.GetBrevHomePath() + if err != nil { + return breverrors.WrapAndTrace(err) + } + return runRegister(cmd.Context(), t, store, args[0], prodRegisterDeps(brevHome)) }, } return cmd } -func runRegister(t *terminal.Terminal) { - t.Vprint("\n") - t.Vprint(t.Green("Thanks so much for your interest in registering your DGX Spark with Brev!\n\n")) - t.Vprint("To be on the waitlist for early access to this feature, please fill out this form:\n\n") - t.Vprint(t.Yellow(" 👉 https://forms.gle/RHCHGmZuiMQQ2faA6\n\n")) - t.Vprint("We will reach out to the provided email with updates and instructions on how to register soon (:\n") - t.Vprint("\n") +func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string, deps registerDeps) error { //nolint:funlen // registration flow + org, err := getOrgToRegisterFor(deps, s) + if err != nil { + return err + } + + brevUser, err := s.GetCurrentUser() + if err != nil { + return breverrors.WrapAndTrace(err) + } + + u, _ := user.Current() + linuxUser := u.Username + + t.Vprint("") + t.Vprint(t.Green("Registering your device with Brev")) + t.Vprint("") + t.Vprintf(" Name: %s\n", t.Yellow(name)) + t.Vprintf(" Organization: %s\n", org.Name) + t.Vprintf(" Registering for Linux user: %s\n", linuxUser) + t.Vprint("") + t.Vprint("This will perform the following steps:") + t.Vprint(" 1. Install NetBird") + t.Vprint(" 2. Collect hardware profile") + t.Vprint(" 3. Register this machine with Brev") + t.Vprint("") + + if !deps.promptYesNo("Proceed with registration?") { + t.Vprint("Registration canceled.") + return nil + } + + t.Vprint("") + t.Vprint(t.Yellow("[Step 1/3] Installing NetBird...")) + if err := deps.installNetbird(); err != nil { + return fmt.Errorf("NetBird installation failed: %w", err) + } + t.Vprint(t.Green(" NetBird installed successfully.")) + + t.Vprint("") + t.Vprint(t.Yellow("[Step 2/3] Collecting hardware profile...")) + t.Vprint("") + + nodeSpec, err := CollectHardwareProfile(deps.commandRunner, deps.fileReader) + if err != nil { + return fmt.Errorf("failed to collect hardware profile: %w", err) + } + + t.Vprint(" Hardware profile:") + t.Vprint(FormatNodeSpec(nodeSpec)) + + t.Vprint("") + t.Vprint(t.Yellow("[Step 3/3] Registering with Brev...")) + + deviceID := uuid.New().String() + client := deps.newNodeClient(s, config.GlobalConfig.GetBrevPublicAPIURL()) + addResp, err := client.AddNode(ctx, connect.NewRequest(&nodev1.AddNodeRequest{ + OrganizationId: org.ID, + Name: name, + DeviceId: deviceID, + NodeSpec: toProtoNodeSpec(nodeSpec), + })) + if err != nil { + return fmt.Errorf("failed to register node: %w", err) + } + + node := addResp.Msg.GetExternalNode() + reg := &DeviceRegistration{ + ExternalNodeID: node.GetExternalNodeId(), + DisplayName: name, + OrgID: org.ID, + DeviceID: deviceID, + RegisteredAt: time.Now().UTC().Format(time.RFC3339), + NodeSpec: *nodeSpec, + } + if err := deps.registrationStore.Save(reg); err != nil { + return fmt.Errorf("node registered but failed to save locally: %w", err) + } + + t.Vprint(t.Green(" Registration complete.")) + + if ci := node.GetConnectivityInfo(); ci != nil { + if cmd := ci.GetRegistrationCommand(); cmd != "" { + if err := deps.runSetupCommand(cmd); err != nil { + t.Vprintf(" Warning: setup command failed: %v\n", err) + } + } + } + + if deps.promptYesNo("Would you like to enable SSH access to this device?") { + grantSSHAccess(ctx, t, deps, s, reg, brevUser, u) + } + + return nil +} + +func grantSSHAccess(ctx context.Context, t *terminal.Terminal, deps registerDeps, tokenProvider TokenProvider, reg *DeviceRegistration, brevUser *entity.User, u *user.User) { + t.Vprint("") + t.Vprint(t.Green("Enabling SSH access on this device")) + t.Vprint("") + t.Vprintf(" Node: %s (%s)\n", reg.DisplayName, reg.ExternalNodeID) + t.Vprintf(" Brev user: %s\n", brevUser.ID) + t.Vprintf(" Linux user: %s\n", u.Username) + t.Vprint("") + + client := deps.newNodeClient(tokenProvider, config.GlobalConfig.GetBrevPublicAPIURL()) + if _, err := client.GrantNodeSSHAccess(ctx, connect.NewRequest(&nodev1.GrantNodeSSHAccessRequest{ + ExternalNodeId: reg.ExternalNodeID, + UserId: brevUser.ID, + LinuxUser: u.Username, + OrganizationId: reg.OrgID, + })); err != nil { + t.Vprintf(" Warning: failed to enable SSH: %v\n", err) + return + } + + if brevUser.PublicKey != "" { + if err := installAuthorizedKey(u, brevUser.PublicKey); err != nil { + t.Vprintf(" %s\n", t.Yellow(fmt.Sprintf("Warning: failed to install SSH public key: %v", err))) + } else { + t.Vprint(" Brev public key added to authorized_keys.") + } + } + + t.Vprint(t.Green(fmt.Sprintf("SSH access enabled. You can now SSH to this device via: brev shell %s", reg.DisplayName))) +} + +// installAuthorizedKey appends the given public key to the user's +// ~/.ssh/authorized_keys if it isn't already present. +func installAuthorizedKey(u *user.User, pubKey string) error { + pubKey = strings.TrimSpace(pubKey) + if pubKey == "" { + return nil + } + + sshDir := filepath.Join(u.HomeDir, ".ssh") + if err := os.MkdirAll(sshDir, 0o700); err != nil { + return fmt.Errorf("creating .ssh directory: %w", err) + } + + authKeysPath := filepath.Join(sshDir, "authorized_keys") + + existing, err := os.ReadFile(authKeysPath) // #nosec G304 + if err != nil && !os.IsNotExist(err) { + return fmt.Errorf("reading authorized_keys: %w", err) + } + + if strings.Contains(string(existing), pubKey) { + return nil // already present + } + + content := string(existing) + if len(content) > 0 && !strings.HasSuffix(content, "\n") { + content += "\n" + } + content += pubKey + "\n" + + if err := os.WriteFile(authKeysPath, []byte(content), 0o600); err != nil { + return fmt.Errorf("writing authorized_keys: %w", err) + } + + return nil +} + +func getOrgToRegisterFor(deps registerDeps, s RegisterStore) (*entity.Organization, error) { + if deps.goos != "linux" { + return nil, fmt.Errorf("brev register is only supported on Linux") + } + + _, err := s.GetCurrentUser() // ensure active token + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + + org, err := s.GetActiveOrganizationOrDefault() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if org == nil { + return nil, fmt.Errorf("no organization found; please create or join an organization first") + } + + alreadyRegistered, err := deps.registrationStore.Exists() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if alreadyRegistered { + return nil, fmt.Errorf("this machine is already registered; run 'brev deregister' first to re-register") + } + return org, nil } diff --git a/pkg/cmd/register/register_test.go b/pkg/cmd/register/register_test.go new file mode 100644 index 00000000..95c5e513 --- /dev/null +++ b/pkg/cmd/register/register_test.go @@ -0,0 +1,331 @@ +package register + +import ( + "context" + "fmt" + "net/http/httptest" + "testing" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/terminal" +) + +// mockRegisterStore satisfies RegisterStore for orchestration tests. +type mockRegisterStore struct { + user *entity.User + org *entity.Organization + home string + token string + err error +} + +func (m *mockRegisterStore) GetCurrentUser() (*entity.User, error) { + if m.err != nil { + return nil, m.err + } + return m.user, nil +} + +func (m *mockRegisterStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) { + return m.org, nil +} + +func (m *mockRegisterStore) GetBrevHomePath() (string, error) { return m.home, nil } +func (m *mockRegisterStore) GetAccessToken() (string, error) { return m.token, nil } + +// mockRegistrationStore satisfies RegistrationStore for orchestration tests. +type mockRegistrationStore struct { + reg *DeviceRegistration +} + +func (m *mockRegistrationStore) Save(reg *DeviceRegistration) error { + m.reg = reg + return nil +} + +func (m *mockRegistrationStore) Load() (*DeviceRegistration, error) { + if m.reg == nil { + return nil, fmt.Errorf("no registration") + } + return m.reg, nil +} + +func (m *mockRegistrationStore) Delete() error { + m.reg = nil + return nil +} + +func (m *mockRegistrationStore) Exists() (bool, error) { + return m.reg != nil, nil +} + +// testRegisterDeps returns deps with all side effects stubbed out, and a fake +// ConnectRPC server backed by the provided fakeNodeService. +func testRegisterDeps(t *testing.T, svc *fakeNodeService, regStore RegistrationStore) (registerDeps, *httptest.Server) { + t.Helper() + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) + + return registerDeps{ + goos: "linux", + promptYesNo: func(_ string) bool { return true }, + installNetbird: func() error { return nil }, + runSetupCommand: func(_ string) error { return nil }, + newNodeClient: func(provider TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient { + return NewNodeServiceClient(provider, server.URL) + }, + commandRunner: &mockCommandRunner{ + outputs: map[string][]byte{ + "nvidia-smi": []byte("NVIDIA GB10, 131072\n"), + "lsblk": []byte("nvme0n1 500107862016 disk 0\n"), + }, + }, + fileReader: &mockFileReader{ + files: map[string][]byte{ + "/proc/cpuinfo": []byte("processor\t: 0\nprocessor\t: 1\n"), + "/proc/meminfo": []byte("MemTotal: 131886028 kB\n"), + "/etc/os-release": []byte("NAME=\"Ubuntu\"\nVERSION_ID=\"24.04\"\n"), + }, + }, + registrationStore: regStore, + }, server +} + +func Test_runRegister_HappyPath(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + } + + var gotSetupCmd string + svc := &fakeNodeService{ + addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + if req.GetOrganizationId() != "org_123" { + t.Errorf("unexpected org: %s", req.GetOrganizationId()) + } + if req.GetName() != "My Spark" { + t.Errorf("unexpected name: %s", req.GetName()) + } + return &nodev1.AddNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: "unode_abc", + OrganizationId: "org_123", + Name: req.GetName(), + DeviceId: req.GetDeviceId(), + ConnectivityInfo: &nodev1.ConnectivityInfo{ + RegistrationCommand: "netbird up --key abc", + }, + }, + }, nil + }, + } + + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + deps.runSetupCommand = func(cmd string) error { + gotSetupCmd = cmd + return nil + } + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err != nil { + t.Fatalf("runRegister failed: %v", err) + } + + // Verify registration was persisted + exists, err := regStore.Exists() + if err != nil { + t.Fatalf("Exists error: %v", err) + } + if !exists { + t.Fatal("expected registration to exist after successful register") + } + + reg, err := regStore.Load() + if err != nil { + t.Fatalf("Load failed: %v", err) + } + if reg.ExternalNodeID != "unode_abc" { + t.Errorf("expected ExternalNodeID unode_abc, got %s", reg.ExternalNodeID) + } + if reg.DisplayName != "My Spark" { + t.Errorf("expected display name 'My Spark', got %s", reg.DisplayName) + } + if reg.OrgID != "org_123" { + t.Errorf("expected org org_123, got %s", reg.OrgID) + } + + // Verify setup command was executed + if gotSetupCmd != "netbird up --key abc" { + t.Errorf("expected setup command 'netbird up --key abc', got %q", gotSetupCmd) + } +} + +func Test_runRegister_UserCancels(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + deps.promptYesNo = func(_ string) bool { return false } + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err != nil { + t.Fatalf("expected nil error on cancel, got: %v", err) + } + + // Registration should not exist + exists, err := regStore.Exists() + if err != nil { + t.Fatalf("Exists error: %v", err) + } + if exists { + t.Error("registration should not exist after cancel") + } +} + +func Test_runRegister_AlreadyRegistered(t *testing.T) { + regStore := &mockRegistrationStore{ + reg: &DeviceRegistration{ + ExternalNodeID: "unode_existing", + DisplayName: "Existing", + }, + } + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err == nil { + t.Fatal("expected error for already-registered machine") + } +} + +func Test_runRegister_NoOrganization(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: nil, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err == nil { + t.Fatal("expected error when no org exists") + } +} + +func Test_runRegister_AddNodeFails(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + } + + svc := &fakeNodeService{ + addNodeFn: func(_ *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + return nil, connect.NewError(connect.CodeInternal, nil) + }, + } + + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err == nil { + t.Fatal("expected error when AddNode fails") + } + + // Registration should not exist on failure + exists, err := regStore.Exists() + if err != nil { + t.Fatalf("Exists error: %v", err) + } + if exists { + t.Error("registration should not exist after AddNode failure") + } +} + +func Test_runRegister_NoSetupCommand(t *testing.T) { + regStore := &mockRegistrationStore{} + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: "/home/testuser/.brev", + token: "tok", + } + + setupCalled := false + svc := &fakeNodeService{ + addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + return &nodev1.AddNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: "unode_abc", + OrganizationId: "org_123", + Name: req.GetName(), + DeviceId: req.GetDeviceId(), + }, + // No ConnectivityInfo / RegistrationCommand + }, nil + }, + } + + deps, server := testRegisterDeps(t, svc, regStore) + defer server.Close() + + deps.runSetupCommand = func(_ string) error { + setupCalled = true + return nil + } + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err != nil { + t.Fatalf("runRegister failed: %v", err) + } + + if setupCalled { + t.Error("setup command should not be called when empty") + } +} diff --git a/pkg/cmd/register/registration.go b/pkg/cmd/register/registration.go new file mode 100644 index 00000000..0161ce76 --- /dev/null +++ b/pkg/cmd/register/registration.go @@ -0,0 +1,92 @@ +package register + +import ( + "encoding/json" + "os" + "path/filepath" + + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/files" + "github.com/spf13/afero" +) + +const registrationFileName = "device_registration.json" + +// DeviceRegistration is the persistent identity file for a registered device. +// Fields align with the AddNodeResponse from dev-plane. +type DeviceRegistration struct { + ExternalNodeID string `json:"external_node_id"` + DisplayName string `json:"display_name"` + OrgID string `json:"org_id"` + DeviceID string `json:"device_id"` + RegisteredAt string `json:"registered_at"` + NodeSpec NodeSpec `json:"node_spec"` +} + +// RegistrationStore defines the contract for persisting device registration data. +type RegistrationStore interface { + Save(reg *DeviceRegistration) error + Load() (*DeviceRegistration, error) + Delete() error + Exists() (bool, error) +} + +// FileRegistrationStore implements RegistrationStore using the local filesystem. +type FileRegistrationStore struct { + brevHome string +} + +// NewFileRegistrationStore returns a FileRegistrationStore rooted at brevHome. +func NewFileRegistrationStore(brevHome string) *FileRegistrationStore { + return &FileRegistrationStore{brevHome: brevHome} +} + +func (s *FileRegistrationStore) path() string { + return filepath.Join(s.brevHome, registrationFileName) +} + +func (s *FileRegistrationStore) Save(reg *DeviceRegistration) error { + path := s.path() + data, err := json.MarshalIndent(reg, "", " ") + if err != nil { + return breverrors.WrapAndTrace(err) + } + if err := files.AppFs.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return breverrors.WrapAndTrace(err) + } + if err := afero.WriteFile(files.AppFs, path, data, 0o600); err != nil { + return breverrors.WrapAndTrace(err) + } + return nil +} + +func (s *FileRegistrationStore) Load() (*DeviceRegistration, error) { + path := s.path() + var reg DeviceRegistration + err := files.ReadJSON(files.AppFs, path, ®) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return ®, nil +} + +func (s *FileRegistrationStore) Delete() error { + path := s.path() + err := files.DeleteFile(files.AppFs, path) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil +} + +func (s *FileRegistrationStore) Exists() (bool, error) { + path := s.path() + _, err := files.AppFs.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, breverrors.WrapAndTrace(err) +} diff --git a/pkg/cmd/register/registration_test.go b/pkg/cmd/register/registration_test.go new file mode 100644 index 00000000..6e14514d --- /dev/null +++ b/pkg/cmd/register/registration_test.go @@ -0,0 +1,158 @@ +package register + +import ( + "testing" + + "github.com/brevdev/brev-cli/pkg/files" + "github.com/spf13/afero" +) + +func setupTestFs(t *testing.T) (string, func()) { + t.Helper() + origFs := files.AppFs + files.AppFs = afero.NewMemMapFs() + brevHome := "/home/testuser/.brev" + if err := files.AppFs.MkdirAll(brevHome, 0o770); err != nil { + t.Fatalf("failed to create test dir: %v", err) + } + return brevHome, func() { files.AppFs = origFs } +} + +func Test_SaveAndLoadRegistration_RoundTrip(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + store := NewFileRegistrationStore(brevHome) + + cpuCount := int32(12) + ramBytes := int64(137438953472) + reg := &DeviceRegistration{ + ExternalNodeID: "unode_abc123", + DisplayName: "My Spark", + OrgID: "org_xyz", + DeviceID: "device-uuid-123", + RegisteredAt: "2026-02-25T00:00:00Z", + NodeSpec: NodeSpec{ + CPUCount: &cpuCount, + RAMBytes: &ramBytes, + Architecture: "arm64", + }, + } + + if err := store.Save(reg); err != nil { + t.Fatalf("Save failed: %v", err) + } + + loaded, err := store.Load() + if err != nil { + t.Fatalf("Load failed: %v", err) + } + + if loaded.ExternalNodeID != reg.ExternalNodeID { + t.Errorf("ExternalNodeID mismatch: got %s, want %s", loaded.ExternalNodeID, reg.ExternalNodeID) + } + if loaded.DisplayName != reg.DisplayName { + t.Errorf("DisplayName mismatch: got %s, want %s", loaded.DisplayName, reg.DisplayName) + } + if loaded.OrgID != reg.OrgID { + t.Errorf("OrgID mismatch: got %s, want %s", loaded.OrgID, reg.OrgID) + } + if loaded.DeviceID != reg.DeviceID { + t.Errorf("DeviceID mismatch: got %s, want %s", loaded.DeviceID, reg.DeviceID) + } + if loaded.NodeSpec.Architecture != "arm64" { + t.Errorf("Architecture mismatch: got %s", loaded.NodeSpec.Architecture) + } + if loaded.NodeSpec.CPUCount == nil || *loaded.NodeSpec.CPUCount != 12 { + t.Errorf("CPUCount mismatch: got %v", loaded.NodeSpec.CPUCount) + } +} + +func Test_RegistrationExists_ReturnsFalseWhenMissing(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + store := NewFileRegistrationStore(brevHome) + + exists, err := store.Exists() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exists { + t.Error("expected Exists to return false") + } +} + +func Test_RegistrationExists_ReturnsTrueAfterSave(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + store := NewFileRegistrationStore(brevHome) + + reg := &DeviceRegistration{ + ExternalNodeID: "unode_abc123", + DisplayName: "Test", + } + if err := store.Save(reg); err != nil { + t.Fatalf("Save failed: %v", err) + } + + exists, err := store.Exists() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !exists { + t.Error("expected Exists to return true") + } +} + +func Test_DeleteRegistration_RemovesFile(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + store := NewFileRegistrationStore(brevHome) + + reg := &DeviceRegistration{ + ExternalNodeID: "unode_abc123", + DisplayName: "Test", + } + if err := store.Save(reg); err != nil { + t.Fatalf("Save failed: %v", err) + } + + if err := store.Delete(); err != nil { + t.Fatalf("Delete failed: %v", err) + } + + exists, err := store.Exists() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exists { + t.Error("expected Exists to return false after delete") + } +} + +func Test_LoadRegistration_FailsWhenMissing(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + store := NewFileRegistrationStore(brevHome) + + _, err := store.Load() + if err == nil { + t.Error("expected error loading missing registration") + } +} + +func Test_DeleteRegistration_FailsWhenMissing(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + store := NewFileRegistrationStore(brevHome) + + err := store.Delete() + if err == nil { + t.Error("expected error deleting missing registration") + } +} diff --git a/pkg/cmd/register/rpcclient.go b/pkg/cmd/register/rpcclient.go new file mode 100644 index 00000000..501d4aeb --- /dev/null +++ b/pkg/cmd/register/rpcclient.go @@ -0,0 +1,96 @@ +package register + +import ( + "net/http" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + + breverrors "github.com/brevdev/brev-cli/pkg/errors" +) + +// TokenProvider abstracts access token retrieval for the HTTP transport. +type TokenProvider interface { + GetAccessToken() (string, error) +} + +// bearerTokenTransport injects a Bearer token into every request. +type bearerTokenTransport struct { + provider TokenProvider + base http.RoundTripper +} + +func (t *bearerTokenTransport) RoundTrip(req *http.Request) (*http.Response, error) { + token, err := t.provider.GetAccessToken() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + req = req.Clone(req.Context()) + req.Header.Set("Authorization", "Bearer "+token) + resp, err := t.base.RoundTrip(req) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return resp, nil +} + +// newAuthenticatedHTTPClient creates an http.Client that injects the bearer token +// from the given provider on every request. +func newAuthenticatedHTTPClient(provider TokenProvider) *http.Client { + return &http.Client{ + Transport: &bearerTokenTransport{ + provider: provider, + base: http.DefaultTransport, + }, + } +} + +// NewNodeServiceClient creates a ConnectRPC ExternalNodeServiceClient using the +// given token provider for authentication. +func NewNodeServiceClient(provider TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient { + return nodev1connect.NewExternalNodeServiceClient( + newAuthenticatedHTTPClient(provider), + baseURL, + ) +} + +// toProtoNodeSpec converts the local NodeSpec (used for collection, display, persistence) +// to the generated proto NodeSpec for RPC calls. +func toProtoNodeSpec(s *NodeSpec) *nodev1.NodeSpec { + if s == nil { + return nil + } + + proto := &nodev1.NodeSpec{ + RamBytes: s.RAMBytes, + CpuCount: s.CPUCount, + } + + for _, st := range s.Storage { + proto.Storage = append(proto.Storage, &nodev1.StorageSpec{ + StorageBytes: st.StorageBytes, + StorageType: st.StorageType, + }) + } + + if s.Architecture != "" { + proto.Architecture = &s.Architecture + } + if s.OS != "" { + proto.Os = &s.OS + } + if s.OSVersion != "" { + proto.OsVersion = &s.OSVersion + } + + for _, g := range s.GPUs { + pg := &nodev1.GPUSpec{ + Model: g.Model, + Count: g.Count, + MemoryBytes: g.MemoryBytes, + } + proto.Gpus = append(proto.Gpus, pg) + } + + return proto +} diff --git a/pkg/cmd/register/rpcclient_test.go b/pkg/cmd/register/rpcclient_test.go new file mode 100644 index 00000000..56123094 --- /dev/null +++ b/pkg/cmd/register/rpcclient_test.go @@ -0,0 +1,274 @@ +package register + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" +) + +type mockTokenProvider struct { + token string + err error +} + +func (m *mockTokenProvider) GetAccessToken() (string, error) { + return m.token, m.err +} + +func Test_bearerTokenTransport_InjectsHeader(t *testing.T) { + var gotAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + provider := &mockTokenProvider{token: "test-token-123"} + client := newAuthenticatedHTTPClient(provider) + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() //nolint:errcheck // test + + if gotAuth != "Bearer test-token-123" { + t.Errorf("expected 'Bearer test-token-123', got %q", gotAuth) + } +} + +func Test_bearerTokenTransport_PropagatesTokenError(t *testing.T) { + provider := &mockTokenProvider{err: http.ErrAbortHandler} + client := newAuthenticatedHTTPClient(provider) + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://localhost", nil) + resp, err := client.Do(req) + if err == nil { + resp.Body.Close() //nolint:errcheck // test + t.Fatal("expected error from token provider") + } +} + +func Test_toProtoNodeSpec(t *testing.T) { + cpuCount := int32(12) + ramBytes := int64(137438953472) + memBytes := int64(137438953472) + + local := &NodeSpec{ + GPUs: []NodeGPU{ + {Model: "NVIDIA GB10", Count: 2, MemoryBytes: &memBytes}, + }, + RAMBytes: &ramBytes, + CPUCount: &cpuCount, + Architecture: "arm64", + Storage: []NodeStorage{ + {StorageBytes: 500107862016, StorageType: "SSD"}, + }, + OS: "Ubuntu", + OSVersion: "24.04", + } + + proto := toProtoNodeSpec(local) + + if proto.GetCpuCount() != 12 { + t.Errorf("expected CpuCount 12, got %d", proto.GetCpuCount()) + } + if proto.GetRamBytes() != 137438953472 { + t.Errorf("expected RamBytes, got %d", proto.GetRamBytes()) + } + if proto.GetArchitecture() != "arm64" { + t.Errorf("expected arm64, got %s", proto.GetArchitecture()) + } + if proto.GetOs() != "Ubuntu" { + t.Errorf("expected Ubuntu, got %s", proto.GetOs()) + } + if proto.GetOsVersion() != "24.04" { + t.Errorf("expected 24.04, got %s", proto.GetOsVersion()) + } + if len(proto.GetStorage()) != 1 { + t.Fatalf("expected 1 storage entry, got %d", len(proto.GetStorage())) + } + if proto.GetStorage()[0].GetStorageBytes() != 500107862016 { + t.Errorf("expected StorageBytes 500107862016, got %d", proto.GetStorage()[0].GetStorageBytes()) + } + if proto.GetStorage()[0].GetStorageType() != "SSD" { + t.Errorf("expected SSD, got %s", proto.GetStorage()[0].GetStorageType()) + } + if len(proto.GetGpus()) != 1 { + t.Fatalf("expected 1 GPU, got %d", len(proto.GetGpus())) + } + gpu := proto.GetGpus()[0] + if gpu.GetModel() != "NVIDIA GB10" { + t.Errorf("expected NVIDIA GB10, got %s", gpu.GetModel()) + } + if gpu.GetCount() != 2 { + t.Errorf("expected count 2, got %d", gpu.GetCount()) + } + if gpu.GetMemoryBytes() != 137438953472 { + t.Errorf("expected memory bytes, got %d", gpu.GetMemoryBytes()) + } +} + +func Test_toProtoNodeSpec_Nil(t *testing.T) { + if toProtoNodeSpec(nil) != nil { + t.Error("expected nil for nil input") + } +} + +func Test_toProtoNodeSpec_MinimalFields(t *testing.T) { + local := &NodeSpec{ + Architecture: "amd64", + } + proto := toProtoNodeSpec(local) + if proto.GetArchitecture() != "amd64" { + t.Errorf("expected amd64, got %s", proto.GetArchitecture()) + } + if proto.RamBytes != nil { + t.Error("expected nil RamBytes") + } + if proto.CpuCount != nil { + t.Error("expected nil CpuCount") + } + if len(proto.GetGpus()) != 0 { + t.Error("expected no GPUs") + } +} + +// fakeNodeService implements the server side of ExternalNodeService for testing. +type fakeNodeService struct { + nodev1connect.UnimplementedExternalNodeServiceHandler + addNodeFn func(*nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) + removeNodeFn func(*nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) +} + +func (f *fakeNodeService) AddNode(_ context.Context, req *connect.Request[nodev1.AddNodeRequest]) (*connect.Response[nodev1.AddNodeResponse], error) { + resp, err := f.addNodeFn(req.Msg) + if err != nil { + return nil, err + } + return connect.NewResponse(resp), nil +} + +func (f *fakeNodeService) RemoveNode(_ context.Context, req *connect.Request[nodev1.RemoveNodeRequest]) (*connect.Response[nodev1.RemoveNodeResponse], error) { + resp, err := f.removeNodeFn(req.Msg) + if err != nil { + return nil, err + } + return connect.NewResponse(resp), nil +} + +func Test_NewNodeServiceClient_AddNode(t *testing.T) { + svc := &fakeNodeService{ + addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + if req.GetOrganizationId() != "org_123" { + t.Errorf("unexpected org ID: %s", req.GetOrganizationId()) + } + if req.GetName() != "My Spark" { + t.Errorf("unexpected name: %s", req.GetName()) + } + return &nodev1.AddNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: "unode_abc", + OrganizationId: "org_123", + Name: req.GetName(), + DeviceId: req.GetDeviceId(), + }, + }, nil + }, + } + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) + defer server.Close() + + client := NewNodeServiceClient(&mockTokenProvider{token: "tok"}, server.URL) + + resp, err := client.AddNode(context.Background(), connect.NewRequest(&nodev1.AddNodeRequest{ + OrganizationId: "org_123", + Name: "My Spark", + DeviceId: "dev-uuid", + NodeSpec: &nodev1.NodeSpec{Architecture: strPtr("arm64")}, + })) + if err != nil { + t.Fatalf("AddNode failed: %v", err) + } + if resp.Msg.GetExternalNode().GetExternalNodeId() != "unode_abc" { + t.Errorf("unexpected node ID: %s", resp.Msg.GetExternalNode().GetExternalNodeId()) + } +} + +func Test_NewNodeServiceClient_AddNode_ServerError(t *testing.T) { + svc := &fakeNodeService{ + addNodeFn: func(_ *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + return nil, connect.NewError(connect.CodeInternal, nil) + }, + } + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) + defer server.Close() + + client := NewNodeServiceClient(&mockTokenProvider{token: "tok"}, server.URL) + + _, err := client.AddNode(context.Background(), connect.NewRequest(&nodev1.AddNodeRequest{ + OrganizationId: "org_123", + Name: "Test", + DeviceId: "dev", + })) + if err == nil { + t.Fatal("expected error for server error response") + } +} + +func Test_NewNodeServiceClient_RemoveNode(t *testing.T) { + svc := &fakeNodeService{ + removeNodeFn: func(req *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + if req.GetExternalNodeId() != "unode_abc" { + t.Errorf("unexpected node ID: %s", req.GetExternalNodeId()) + } + return &nodev1.RemoveNodeResponse{}, nil + }, + } + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) + defer server.Close() + + client := NewNodeServiceClient(&mockTokenProvider{token: "tok"}, server.URL) + + _, err := client.RemoveNode(context.Background(), connect.NewRequest(&nodev1.RemoveNodeRequest{ + ExternalNodeId: "unode_abc", + })) + if err != nil { + t.Fatalf("RemoveNode failed: %v", err) + } +} + +func Test_NewNodeServiceClient_RemoveNode_ServerError(t *testing.T) { + svc := &fakeNodeService{ + removeNodeFn: func(_ *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + return nil, connect.NewError(connect.CodeNotFound, nil) + }, + } + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) + defer server.Close() + + client := NewNodeServiceClient(&mockTokenProvider{token: "tok"}, server.URL) + + _, err := client.RemoveNode(context.Background(), connect.NewRequest(&nodev1.RemoveNodeRequest{ + ExternalNodeId: "unode_missing", + })) + if err == nil { + t.Fatal("expected error for not found response") + } +} + +func strPtr(s string) *string { return &s } diff --git a/pkg/store/http.go b/pkg/store/http.go index c64d9217..60884f81 100644 --- a/pkg/store/http.go +++ b/pkg/store/http.go @@ -61,6 +61,15 @@ func (s *AuthHTTPStore) GetWindowsDir() (string, error) { return s.GetWSLHostHomeDir() } +// GetAccessToken returns a fresh access token, refreshing if needed. +func (s *AuthHTTPStore) GetAccessToken() (string, error) { + token, err := s.authHTTPClient.auth.GetAccessToken() + if err != nil { + return "", breverrors.WrapAndTrace(err) + } + return token, nil +} + func (f *FileStore) WithAuthHTTPClient(c *AuthHTTPClient) *AuthHTTPStore { // err never returned from GetCurrentWorkspaceID id, _ := f.GetCurrentWorkspaceID()