From 3c2548ac03b1344ac10696bb2645fccfcc32fcc8 Mon Sep 17 00:00:00 2001 From: Pratik Patel Date: Tue, 24 Feb 2026 18:26:02 -0800 Subject: [PATCH 01/11] registration of Spark --- Makefile | 4 +- pkg/cmd/cmd.go | 4 +- pkg/cmd/deregister/deregister.go | 110 +++++++++++ pkg/cmd/register/hardware.go | 305 ++++++++++++++++++++++++++++++ pkg/cmd/register/hardware_test.go | 201 ++++++++++++++++++++ pkg/cmd/register/identity.go | 136 +++++++++++++ pkg/cmd/register/identity_test.go | 210 ++++++++++++++++++++ pkg/cmd/register/netbird.go | 37 ++++ pkg/cmd/register/osfile.go | 8 + pkg/cmd/register/register.go | 135 +++++++++++-- 10 files changed, 1134 insertions(+), 16 deletions(-) create mode 100644 pkg/cmd/deregister/deregister.go create mode 100644 pkg/cmd/register/hardware.go create mode 100644 pkg/cmd/register/hardware_test.go create mode 100644 pkg/cmd/register/identity.go create mode 100644 pkg/cmd/register/identity_test.go create mode 100644 pkg/cmd/register/netbird.go create mode 100644 pkg/cmd/register/osfile.go diff --git a/Makefile b/Makefile index f459f671..e9ddc9a6 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ local: ## build with env wrapper (use: make local env=dev0|dev1|dev2|stg, or mak ifdef env @echo "Building with env=$(env) wrapper..." @echo ${VERSION} - CGO_ENABLED=0 go build -o brev -ldflags "-X github.com/brevdev/brev-cli/pkg/cmd/version.Version=${VERSION}" + CGO_ENABLED=0 go build -o brev-local -ldflags "-X github.com/brevdev/brev-cli/pkg/cmd/version.Version=${VERSION}" @echo '#!/bin/sh' > brev @echo '# Auto-generated wrapper with environment overrides' >> brev @echo 'export BREV_CONSOLE_URL="https://localhost.nvidia.com:3000"' >> brev @@ -21,7 +21,7 @@ ifdef env @echo 'export BREV_AUTH_ISSUER_URL="https://stg.login.nvidia.com"' >> brev @echo 'export BREV_API_URL="https://bd.$(env).brev.nvidia.com"' >> brev @echo 'export BREV_GRPC_URL="api.$(env).brev.nvidia.com:443"' >> brev - @echo 'exec "$$(cd "$$(dirname "$$0")" && pwd)/brev" "$$@"' >> brev + @echo 'exec "$$(cd "$$(dirname "$$0")" && pwd)/brev-local" "$$@"' >> brev @chmod +x brev else @echo "Building without environment overrides (using config.go defaults)..." diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index c6a7edfe..572d9741 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -35,6 +35,7 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/redeem" "github.com/brevdev/brev-cli/pkg/cmd/refresh" "github.com/brevdev/brev-cli/pkg/cmd/register" + "github.com/brevdev/brev-cli/pkg/cmd/deregister" "github.com/brevdev/brev-cli/pkg/cmd/reset" "github.com/brevdev/brev-cli/pkg/cmd/runtasks" "github.com/brevdev/brev-cli/pkg/cmd/scale" @@ -291,7 +292,8 @@ func createCmdTree(cmd *cobra.Command, t *terminal.Terminal, loginCmdStore *stor cmd.AddCommand(reset.NewCmdReset(t, loginCmdStore, noLoginCmdStore)) cmd.AddCommand(profile.NewCmdProfile(t, loginCmdStore, noLoginCmdStore)) cmd.AddCommand(refresh.NewCmdRefresh(t, loginCmdStore)) - cmd.AddCommand(register.NewCmdRegister(t)) + cmd.AddCommand(register.NewCmdRegister(t, loginCmdStore)) + cmd.AddCommand(deregister.NewCmdDeregister(t, loginCmdStore)) cmd.AddCommand(runtasks.NewCmdRunTasks(t, noLoginCmdStore)) cmd.AddCommand(proxy.NewCmdProxy(t, noLoginCmdStore)) cmd.AddCommand(healthcheck.NewCmdHealthcheck(t, noLoginCmdStore)) diff --git a/pkg/cmd/deregister/deregister.go b/pkg/cmd/deregister/deregister.go new file mode 100644 index 00000000..4dff4868 --- /dev/null +++ b/pkg/cmd/deregister/deregister.go @@ -0,0 +1,110 @@ +// Package deregister provides the brev deregister command for DGX Spark deregistration +package deregister + +import ( + "fmt" + "runtime" + + "github.com/brevdev/brev-cli/pkg/cmd/register" + "github.com/brevdev/brev-cli/pkg/entity" + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/terminal" + + "github.com/spf13/cobra" +) + +// DeregisterStore defines the store methods needed by the deregister command. +type DeregisterStore interface { + GetCurrentUser() (*entity.User, error) + GetBrevHomePath() (string, error) +} + +var ( + deregisterLong = `Deregister your DGX Spark from NVIDIA Brev + +This command removes the local registration data and optionally uninstalls +netbird (network agent).` + + deregisterExample = ` brev deregister` +) + +func NewCmdDeregister(t *terminal.Terminal, store DeregisterStore) *cobra.Command { + cmd := &cobra.Command{ + Annotations: map[string]string{"configuration": ""}, + Use: "deregister", + DisableFlagsInUseLine: true, + Short: "Deregister your DGX Spark from Brev", + Long: deregisterLong, + Example: deregisterExample, + RunE: func(cmd *cobra.Command, args []string) error { + return runDeregister(t, store) + }, + } + + return cmd +} + +func runDeregister(t *terminal.Terminal, s DeregisterStore) error { + if runtime.GOOS != "linux" { + return fmt.Errorf("brev deregister is only supported on Linux (DGX Spark)") + } + + brevHome, err := s.GetBrevHomePath() + if err != nil { + return breverrors.WrapAndTrace(err) + } + + if !register.RegistrationExists(brevHome) { + return fmt.Errorf("no registration found; this machine does not appear to be registered\nRun 'brev register' to register your DGX Spark") + } + + reg, err := register.LoadRegistration(brevHome) + if err != nil { + return fmt.Errorf("failed to read registration file: %w", err) + } + + t.Vprint("") + t.Vprint(t.Green("Deregistering DGX Spark")) + t.Vprint("") + t.Vprintf(" Node ID: %s\n", reg.BrevCloudNodeID) + t.Vprintf(" Name: %s\n", reg.DisplayName) + t.Vprint("") + + removeNetbird := terminal.PromptSelectInput(terminal.PromptSelectContent{ + Label: "Would you also like to uninstall netbird?", + Items: []string{"Yes, uninstall netbird", "No, keep netbird installed"}, + }) + + confirm := terminal.PromptSelectInput(terminal.PromptSelectContent{ + Label: "Proceed with deregistration?", + Items: []string{"Yes, proceed", "No, cancel"}, + }) + if confirm != "Yes, proceed" { + t.Vprint("Deregistration cancelled.") + return nil + } + + t.Vprint("") + t.Vprint(t.Yellow("[TODO] Deregistration API call not yet implemented.")) + t.Vprint("") + + if removeNetbird == "Yes, uninstall netbird" { + t.Vprint("Removing netbird...") + if err := register.UninstallNetbird(t); err != nil { + t.Vprintf(" Warning: failed to uninstall netbird: %v\n", err) + } else { + t.Vprint(t.Green(" Netbird uninstalled.")) + } + t.Vprint("") + } + + t.Vprint("Removing local registration data...") + if err := register.DeleteRegistration(brevHome); err != nil { + return fmt.Errorf("failed to remove registration data: %w", err) + } + + t.Vprint(t.Green("Deregistration complete.")) + t.Vprint("") + + return nil +} diff --git a/pkg/cmd/register/hardware.go b/pkg/cmd/register/hardware.go new file mode 100644 index 00000000..c6c4053f --- /dev/null +++ b/pkg/cmd/register/hardware.go @@ -0,0 +1,305 @@ +package register + +import ( + "bufio" + "fmt" + "os/exec" + "runtime" + "strconv" + "strings" + + breverrors "github.com/brevdev/brev-cli/pkg/errors" +) + +// CommandRunner abstracts command execution for testability. +type CommandRunner interface { + Run(name string, args ...string) ([]byte, error) +} + +// ExecCommandRunner is the real implementation that runs OS commands. +type ExecCommandRunner struct{} + +func (r ExecCommandRunner) Run(name string, args ...string) ([]byte, error) { + out, err := exec.Command(name, args...).Output() // #nosec G204 + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return out, nil +} + +// HardwareProfile mirrors the proto HardwareInfo message. +type HardwareProfile struct { + GPUs []GPUInfo `json:"gpus"` + RAMBytes int64 `json:"ram_bytes"` + CPUCount int `json:"cpu_count"` + CPUModel string `json:"cpu_model,omitempty"` + SystemVendor string `json:"system_vendor,omitempty"` + SystemModel string `json:"system_model,omitempty"` + Architecture string `json:"architecture,omitempty"` + OSName string `json:"os_name,omitempty"` + OSVersion string `json:"os_version,omitempty"` + Storage []StorageInfo `json:"storage,omitempty"` +} + +// GPUInfo describes a single GPU. +type GPUInfo struct { + Name string `json:"name"` + MemoryMB int64 `json:"memory_mb"` + DriverVersion string `json:"driver_version,omitempty"` + PCIBusID string `json:"pci_bus_id,omitempty"` +} + +// StorageInfo describes a block storage device. +type StorageInfo struct { + Name string `json:"name"` + Bytes int64 `json:"bytes"` + Type string `json:"type"` +} + +// FileReader abstracts file reading for testability. +type FileReader interface { + ReadFile(path string) ([]byte, error) +} + +// CollectHardwareProfile gathers system hardware information. +// CPU count/model and RAM are required; everything else is best-effort. +func CollectHardwareProfile(runner CommandRunner, reader FileReader) (*HardwareProfile, error) { + profile := &HardwareProfile{ + Architecture: runtime.GOARCH, + } + + cpuCount, cpuModel, err := parseCPUInfo(reader) + if err != nil { + return nil, fmt.Errorf("failed to read CPU info: %w", err) + } + profile.CPUCount = cpuCount + profile.CPUModel = cpuModel + + ramBytes, err := parseMemInfo(reader) + if err != nil { + return nil, fmt.Errorf("failed to read memory info: %w", err) + } + profile.RAMBytes = ramBytes + + osName, osVersion := parseOSRelease(reader) + profile.OSName = osName + profile.OSVersion = osVersion + + profile.SystemVendor = readSysFile(reader, "/sys/class/dmi/id/sys_vendor") + profile.SystemModel = readSysFile(reader, "/sys/class/dmi/id/product_name") + + profile.GPUs = parseNvidiaSMI(runner) + profile.Storage = parseLsblk(runner) + + return profile, nil +} + +// parseCPUInfo reads /proc/cpuinfo and returns (count, model). +func parseCPUInfo(reader FileReader) (int, string, error) { + data, err := reader.ReadFile("/proc/cpuinfo") + if err != nil { + return 0, "", breverrors.WrapAndTrace(err) + } + return parseCPUInfoContent(string(data)) +} + +// parseCPUInfoContent parses the content of /proc/cpuinfo. +func parseCPUInfoContent(content string) (int, string, error) { + count := 0 + model := "" + scanner := bufio.NewScanner(strings.NewReader(content)) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "processor") { + count++ + } + if strings.HasPrefix(line, "model name") && model == "" { + parts := strings.SplitN(line, ":", 2) + if len(parts) == 2 { + model = strings.TrimSpace(parts[1]) + } + } + } + if count == 0 { + return 0, "", fmt.Errorf("no processors found in /proc/cpuinfo") + } + return count, model, nil +} + +// parseMemInfo reads /proc/meminfo and returns total RAM in bytes. +func parseMemInfo(reader FileReader) (int64, error) { + data, err := reader.ReadFile("/proc/meminfo") + if err != nil { + return 0, breverrors.WrapAndTrace(err) + } + return parseMemInfoContent(string(data)) +} + +// parseMemInfoContent parses the content of /proc/meminfo. +func parseMemInfoContent(content string) (int64, error) { + scanner := bufio.NewScanner(strings.NewReader(content)) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "MemTotal:") { + fields := strings.Fields(line) + if len(fields) < 2 { + return 0, fmt.Errorf("unexpected MemTotal format: %s", line) + } + kb, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + return 0, fmt.Errorf("failed to parse MemTotal value: %w", err) + } + return kb * 1024, nil // convert kB to bytes + } + } + return 0, fmt.Errorf("MemTotal not found in /proc/meminfo") +} + +// parseOSRelease reads /etc/os-release and returns (name, version). +func parseOSRelease(reader FileReader) (string, string) { + data, err := reader.ReadFile("/etc/os-release") + if err != nil { + return "", "" + } + return parseOSReleaseContent(string(data)) +} + +// parseOSReleaseContent parses the content of /etc/os-release. +func parseOSReleaseContent(content string) (string, string) { + name := "" + version := "" + scanner := bufio.NewScanner(strings.NewReader(content)) + for scanner.Scan() { + line := scanner.Text() + if val, ok := strings.CutPrefix(line, "NAME="); ok { + name = unquote(val) + } + if val, ok := strings.CutPrefix(line, "VERSION_ID="); ok { + version = unquote(val) + } + } + return name, version +} + +// unquote removes surrounding double quotes from a string. +func unquote(s string) string { + s = strings.TrimSpace(s) + if len(s) >= 2 && s[0] == '"' && s[len(s)-1] == '"' { + return s[1 : len(s)-1] + } + return s +} + +// readSysFile reads a single-line sysfs file, returning empty string on failure. +func readSysFile(reader FileReader, path string) string { + data, err := reader.ReadFile(path) + if err != nil { + return "" + } + return strings.TrimSpace(string(data)) +} + +// parseNvidiaSMI queries nvidia-smi for GPU information. +func parseNvidiaSMI(runner CommandRunner) []GPUInfo { + out, err := runner.Run("nvidia-smi", + "--query-gpu=name,memory.total,driver_version,pci.bus_id", + "--format=csv,noheader,nounits", + ) + if err != nil { + return nil + } + return parseNvidiaSMIOutput(string(out)) +} + +// parseNvidiaSMIOutput parses nvidia-smi CSV output into GPUInfo slices. +func parseNvidiaSMIOutput(output string) []GPUInfo { + var gpus []GPUInfo + scanner := bufio.NewScanner(strings.NewReader(output)) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + parts := strings.Split(line, ", ") + if len(parts) < 2 { + continue + } + gpu := GPUInfo{ + Name: strings.TrimSpace(parts[0]), + } + memMB, err := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64) + if err == nil { + gpu.MemoryMB = memMB + } + if len(parts) >= 3 { + gpu.DriverVersion = strings.TrimSpace(parts[2]) + } + if len(parts) >= 4 { + gpu.PCIBusID = strings.TrimSpace(parts[3]) + } + gpus = append(gpus, gpu) + } + return gpus +} + +// parseLsblk queries lsblk for block device information. +func parseLsblk(runner CommandRunner) []StorageInfo { + out, err := runner.Run("lsblk", "-b", "-d", "-n", "-o", "NAME,SIZE,TYPE") + if err != nil { + return nil + } + return parseLsblkOutput(string(out)) +} + +// parseLsblkOutput parses lsblk output into StorageInfo slices. +func parseLsblkOutput(output string) []StorageInfo { + var devices []StorageInfo + scanner := bufio.NewScanner(strings.NewReader(output)) + for scanner.Scan() { + fields := strings.Fields(scanner.Text()) + if len(fields) < 3 { + continue + } + size, err := strconv.ParseInt(fields[1], 10, 64) + if err != nil { + continue + } + devices = append(devices, StorageInfo{ + Name: fields[0], + Bytes: size, + Type: fields[2], + }) + } + return devices +} + +// FormatHardwareProfile returns a human-readable summary of the hardware profile. +func FormatHardwareProfile(p *HardwareProfile) string { + var b strings.Builder + fmt.Fprintf(&b, " CPU: %d x %s\n", p.CPUCount, p.CPUModel) + fmt.Fprintf(&b, " RAM: %d GB\n", p.RAMBytes/(1024*1024*1024)) + if len(p.GPUs) > 0 { + // Group GPUs by name + gpuCounts := make(map[string]int) + gpuMemory := make(map[string]int64) + var gpuOrder []string + for _, gpu := range p.GPUs { + if gpuCounts[gpu.Name] == 0 { + gpuOrder = append(gpuOrder, gpu.Name) + } + gpuCounts[gpu.Name]++ + gpuMemory[gpu.Name] = gpu.MemoryMB + } + for _, name := range gpuOrder { + memGB := gpuMemory[name] / 1024 + fmt.Fprintf(&b, " GPUs: %d x %s (%d GB)\n", gpuCounts[name], name, memGB) + } + } else { + b.WriteString(" GPUs: none detected\n") + } + fmt.Fprintf(&b, " Arch: %s\n", p.Architecture) + if p.OSName != "" || p.OSVersion != "" { + fmt.Fprintf(&b, " OS: %s %s\n", p.OSName, p.OSVersion) + } + return b.String() +} diff --git a/pkg/cmd/register/hardware_test.go b/pkg/cmd/register/hardware_test.go new file mode 100644 index 00000000..27a948c6 --- /dev/null +++ b/pkg/cmd/register/hardware_test.go @@ -0,0 +1,201 @@ +package register + +import ( + "testing" +) + +func Test_parseCPUInfoContent_ValidInput(t *testing.T) { + content := `processor : 0 +vendor_id : AuthenticAMD +model name : AMD EPYC 7763 64-Core Processor +cpu MHz : 2450.000 + +processor : 1 +vendor_id : AuthenticAMD +model name : AMD EPYC 7763 64-Core Processor +cpu MHz : 2450.000 + +processor : 2 +vendor_id : AuthenticAMD +model name : AMD EPYC 7763 64-Core Processor +cpu MHz : 2450.000 +` + count, model, err := parseCPUInfoContent(content) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if count != 3 { + t.Errorf("expected 3 CPUs, got %d", count) + } + if model != "AMD EPYC 7763 64-Core Processor" { + t.Errorf("unexpected CPU model: %s", model) + } +} + +func Test_parseCPUInfoContent_EmptyInput(t *testing.T) { + _, _, err := parseCPUInfoContent("") + if err == nil { + t.Fatal("expected error for empty input") + } +} + +func Test_parseMemInfoContent_ValidInput(t *testing.T) { + content := `MemTotal: 131886028 kB +MemFree: 1234567 kB +MemAvailable: 98765432 kB +` + bytes, err := parseMemInfoContent(content) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + expected := int64(131886028) * 1024 + if bytes != expected { + t.Errorf("expected %d bytes, got %d", expected, bytes) + } +} + +func Test_parseMemInfoContent_MissingMemTotal(t *testing.T) { + content := `MemFree: 1234567 kB +MemAvailable: 98765432 kB +` + _, err := parseMemInfoContent(content) + if err == nil { + t.Fatal("expected error for missing MemTotal") + } +} + +func Test_parseOSReleaseContent(t *testing.T) { + content := `NAME="Ubuntu" +VERSION="24.04 LTS (Noble Numbat)" +ID=ubuntu +VERSION_ID="24.04" +PRETTY_NAME="Ubuntu 24.04 LTS" +` + name, version := parseOSReleaseContent(content) + if name != "Ubuntu" { + t.Errorf("expected Ubuntu, got %s", name) + } + if version != "24.04" { + t.Errorf("expected 24.04, got %s", version) + } +} + +func Test_parseOSReleaseContent_Unquoted(t *testing.T) { + content := `NAME=Fedora +VERSION_ID=39 +` + name, version := parseOSReleaseContent(content) + if name != "Fedora" { + t.Errorf("expected Fedora, got %s", name) + } + if version != "39" { + t.Errorf("expected 39, got %s", version) + } +} + +func Test_parseNvidiaSMIOutput(t *testing.T) { + output := `NVIDIA GB10, 131072, 570.86.15, 00000000:01:00.0 +NVIDIA GB10, 131072, 570.86.15, 00000000:02:00.0 +` + gpus := parseNvidiaSMIOutput(output) + if len(gpus) != 2 { + t.Fatalf("expected 2 GPUs, got %d", len(gpus)) + } + if gpus[0].Name != "NVIDIA GB10" { + t.Errorf("unexpected GPU name: %s", gpus[0].Name) + } + if gpus[0].MemoryMB != 131072 { + t.Errorf("expected 131072 MB, got %d", gpus[0].MemoryMB) + } + if gpus[0].DriverVersion != "570.86.15" { + t.Errorf("unexpected driver version: %s", gpus[0].DriverVersion) + } + if gpus[0].PCIBusID != "00000000:01:00.0" { + t.Errorf("unexpected PCI bus ID: %s", gpus[0].PCIBusID) + } +} + +func Test_parseNvidiaSMIOutput_Empty(t *testing.T) { + gpus := parseNvidiaSMIOutput("") + if len(gpus) != 0 { + t.Errorf("expected 0 GPUs, got %d", len(gpus)) + } +} + +func Test_parseLsblkOutput(t *testing.T) { + output := `sda 500107862016 disk +nvme0n1 1000204886016 disk +` + devices := parseLsblkOutput(output) + if len(devices) != 2 { + t.Fatalf("expected 2 devices, got %d", len(devices)) + } + if devices[0].Name != "sda" { + t.Errorf("unexpected device name: %s", devices[0].Name) + } + if devices[0].Bytes != 500107862016 { + t.Errorf("unexpected device size: %d", devices[0].Bytes) + } + if devices[0].Type != "disk" { + t.Errorf("unexpected device type: %s", devices[0].Type) + } +} + +func Test_unquote(t *testing.T) { + tests := []struct { + input string + want string + }{ + {`"Ubuntu"`, "Ubuntu"}, + {`Ubuntu`, "Ubuntu"}, + {`""`, ""}, + {`"a"`, "a"}, + {``, ""}, + } + for _, tt := range tests { + got := unquote(tt.input) + if got != tt.want { + t.Errorf("unquote(%q) = %q, want %q", tt.input, got, tt.want) + } + } +} + +func Test_FormatHardwareProfile(t *testing.T) { + p := &HardwareProfile{ + CPUCount: 12, + CPUModel: "AMD EPYC 7763", + RAMBytes: 137438953472, // 128 GB + Architecture: "arm64", + OSName: "Ubuntu", + OSVersion: "24.04", + GPUs: []GPUInfo{ + {Name: "NVIDIA GB10", MemoryMB: 131072}, + }, + } + output := FormatHardwareProfile(p) + if output == "" { + t.Fatal("expected non-empty output") + } + if !contains(output, "12 x AMD EPYC 7763") { + t.Errorf("expected CPU info in output: %s", output) + } + if !contains(output, "128 GB") { + t.Errorf("expected RAM info in output: %s", output) + } + if !contains(output, "NVIDIA GB10") { + t.Errorf("expected GPU info in output: %s", output) + } +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && searchString(s, substr) +} + +func searchString(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/pkg/cmd/register/identity.go b/pkg/cmd/register/identity.go new file mode 100644 index 00000000..4dee6efc --- /dev/null +++ b/pkg/cmd/register/identity.go @@ -0,0 +1,136 @@ +package register + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "path/filepath" + "sort" + + breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/files" +) + +const registrationFileName = "spark_registration.json" + +// SparkRegistration is the persistent identity file for a registered DGX Spark. +type SparkRegistration struct { + BrevCloudNodeID string `json:"brev_cloud_node_id"` + DisplayName string `json:"display_name"` + OrgID string `json:"org_id"` + HardwareFingerprint string `json:"hardware_fingerprint"` + DeviceFingerprintHash string `json:"device_fingerprint_hash"` + RegisteredAt string `json:"registered_at"` + HardwareProfile HardwareProfile `json:"hardware_profile"` +} + +// HardwareDescriptor captures high-level hardware traits for fingerprinting. +// Must produce byte-identical JSON to dev-plane's HardwareDescriptor. +type HardwareDescriptor struct { + GPUs []GPUDescriptor `json:"gpus"` + RAM int64 `json:"ram_bytes"` + CPUs int `json:"cpus"` +} + +// GPUDescriptor describes a single GPU for fingerprinting. +type GPUDescriptor struct { + Model string `json:"model"` + Memory int64 `json:"memory_bytes"` +} + +// ComputeHardwareFingerprint returns a deterministic SHA-256 fingerprint. +// This must produce identical output to dev-plane's ComputeHardwareFingerprint +// for the same input. +func ComputeHardwareFingerprint(desc HardwareDescriptor) (string, error) { + if desc.CPUs < 1 { + return "", fmt.Errorf("CPUs must be at least 1") + } + if desc.RAM < 1 { + return "", fmt.Errorf("RAM must be at least 1") + } + for _, gpu := range desc.GPUs { + if gpu.Model == "" { + return "", fmt.Errorf("GPU model must not be empty") + } + if gpu.Memory < 1 { + return "", fmt.Errorf("GPU memory must be at least 1") + } + } + + // Sort GPUs by model then memory for stable ordering. + gpus := make([]GPUDescriptor, len(desc.GPUs)) + copy(gpus, desc.GPUs) + sort.Slice(gpus, func(i, j int) bool { + if gpus[i].Model == gpus[j].Model { + return gpus[i].Memory < gpus[j].Memory + } + return gpus[i].Model < gpus[j].Model + }) + desc.GPUs = gpus + + payload, err := json.Marshal(desc) + if err != nil { + return "", breverrors.WrapAndTrace(err) + } + sum := sha256.Sum256(payload) + return hex.EncodeToString(sum[:]), nil +} + +// HardwareProfileToDescriptor converts a HardwareProfile to a HardwareDescriptor +// for fingerprinting. +func HardwareProfileToDescriptor(p *HardwareProfile) HardwareDescriptor { + desc := HardwareDescriptor{ + RAM: p.RAMBytes, + CPUs: p.CPUCount, + } + for _, gpu := range p.GPUs { + desc.GPUs = append(desc.GPUs, GPUDescriptor{ + Model: gpu.Name, + Memory: gpu.MemoryMB * 1024 * 1024, // convert MB to bytes + }) + } + return desc +} + +func registrationPath(brevHome string) string { + return filepath.Join(brevHome, registrationFileName) +} + +// SaveRegistration writes the registration to ~/.brev/spark_registration.json. +func SaveRegistration(brevHome string, reg *SparkRegistration) error { + path := registrationPath(brevHome) + err := files.OverwriteJSON(files.AppFs, path, reg) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil +} + +// LoadRegistration reads the registration from ~/.brev/spark_registration.json. +func LoadRegistration(brevHome string) (*SparkRegistration, error) { + path := registrationPath(brevHome) + var reg SparkRegistration + err := files.ReadJSON(files.AppFs, path, ®) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return ®, nil +} + +// DeleteRegistration removes ~/.brev/spark_registration.json. +func DeleteRegistration(brevHome string) error { + path := registrationPath(brevHome) + err := files.DeleteFile(files.AppFs, path) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil +} + +// RegistrationExists checks if a registration file exists. +func RegistrationExists(brevHome string) bool { + path := registrationPath(brevHome) + exists, _ := files.AppFs.Stat(path) + return exists != nil +} diff --git a/pkg/cmd/register/identity_test.go b/pkg/cmd/register/identity_test.go new file mode 100644 index 00000000..93009dba --- /dev/null +++ b/pkg/cmd/register/identity_test.go @@ -0,0 +1,210 @@ +package register + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "sort" + "testing" +) + +// Test_ComputeHardwareFingerprint_Deterministic verifies that the same input +// always produces the same fingerprint. +func Test_ComputeHardwareFingerprint_Deterministic(t *testing.T) { + desc := HardwareDescriptor{ + GPUs: []GPUDescriptor{ + {Model: "NVIDIA GB10", Memory: 137438953472}, + }, + RAM: 137438953472, + CPUs: 12, + } + + fp1, err := ComputeHardwareFingerprint(desc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + fp2, err := ComputeHardwareFingerprint(desc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if fp1 != fp2 { + t.Errorf("fingerprints differ: %s != %s", fp1, fp2) + } +} + +// Test_ComputeHardwareFingerprint_GPUOrderIndependent verifies that GPU order +// does not affect the fingerprint. +func Test_ComputeHardwareFingerprint_GPUOrderIndependent(t *testing.T) { + desc1 := HardwareDescriptor{ + GPUs: []GPUDescriptor{ + {Model: "NVIDIA A100", Memory: 85899345920}, + {Model: "NVIDIA GB10", Memory: 137438953472}, + }, + RAM: 274877906944, + CPUs: 64, + } + desc2 := HardwareDescriptor{ + GPUs: []GPUDescriptor{ + {Model: "NVIDIA GB10", Memory: 137438953472}, + {Model: "NVIDIA A100", Memory: 85899345920}, + }, + RAM: 274877906944, + CPUs: 64, + } + + fp1, err := ComputeHardwareFingerprint(desc1) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + fp2, err := ComputeHardwareFingerprint(desc2) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if fp1 != fp2 { + t.Errorf("fingerprints should be identical regardless of GPU order: %s != %s", fp1, fp2) + } +} + +// Test_ComputeHardwareFingerprint_ByteIdenticalToDevPlane verifies that our +// fingerprint is byte-identical to what dev-plane would produce. We replicate +// the dev-plane logic inline to prove equivalence. +func Test_ComputeHardwareFingerprint_ByteIdenticalToDevPlane(t *testing.T) { + desc := HardwareDescriptor{ + GPUs: []GPUDescriptor{ + {Model: "NVIDIA GB10", Memory: 137438953472}, + {Model: "NVIDIA A100", Memory: 85899345920}, + }, + RAM: 274877906944, + CPUs: 64, + } + + // Compute using our function + got, err := ComputeHardwareFingerprint(desc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Replicate dev-plane logic exactly + gpus := make([]GPUDescriptor, len(desc.GPUs)) + copy(gpus, desc.GPUs) + sort.Slice(gpus, func(i, j int) bool { + if gpus[i].Model == gpus[j].Model { + return gpus[i].Memory < gpus[j].Memory + } + return gpus[i].Model < gpus[j].Model + }) + // Build the same struct shape dev-plane uses + type devPlaneGPU struct { + Model string `json:"model"` + Memory int64 `json:"memory_bytes"` + } + type devPlaneDesc struct { + GPUs []devPlaneGPU `json:"gpus"` + RAM int64 `json:"ram_bytes"` + CPUs int `json:"cpus"` + } + dpGPUs := make([]devPlaneGPU, len(gpus)) + for i, g := range gpus { + dpGPUs[i] = devPlaneGPU{Model: g.Model, Memory: g.Memory} + } + dpDesc := devPlaneDesc{GPUs: dpGPUs, RAM: desc.RAM, CPUs: desc.CPUs} + payload, err := json.Marshal(dpDesc) + if err != nil { + t.Fatalf("json.Marshal failed: %v", err) + } + sum := sha256.Sum256(payload) + want := hex.EncodeToString(sum[:]) + + if got != want { + t.Errorf("fingerprint mismatch with dev-plane logic:\ngot: %s\nwant: %s", got, want) + } +} + +// Test_ComputeHardwareFingerprint_NoGPUs verifies fingerprinting works with +// no GPUs present. +func Test_ComputeHardwareFingerprint_NoGPUs(t *testing.T) { + desc := HardwareDescriptor{ + GPUs: nil, + RAM: 8589934592, + CPUs: 4, + } + fp, err := ComputeHardwareFingerprint(desc) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if fp == "" { + t.Fatal("expected non-empty fingerprint") + } +} + +// Test_ComputeHardwareFingerprint_ValidationErrors verifies validation. +func Test_ComputeHardwareFingerprint_ValidationErrors(t *testing.T) { + tests := []struct { + name string + desc HardwareDescriptor + }{ + { + name: "zero CPUs", + desc: HardwareDescriptor{RAM: 1024, CPUs: 0}, + }, + { + name: "zero RAM", + desc: HardwareDescriptor{RAM: 0, CPUs: 1}, + }, + { + name: "GPU with empty model", + desc: HardwareDescriptor{ + RAM: 1024, + CPUs: 1, + GPUs: []GPUDescriptor{{Model: "", Memory: 1024}}, + }, + }, + { + name: "GPU with zero memory", + desc: HardwareDescriptor{ + RAM: 1024, + CPUs: 1, + GPUs: []GPUDescriptor{{Model: "test", Memory: 0}}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ComputeHardwareFingerprint(tt.desc) + if err == nil { + t.Error("expected validation error") + } + }) + } +} + +// Test_HardwareProfileToDescriptor verifies the conversion. +func Test_HardwareProfileToDescriptor(t *testing.T) { + profile := &HardwareProfile{ + CPUCount: 12, + RAMBytes: 137438953472, + GPUs: []GPUInfo{ + {Name: "NVIDIA GB10", MemoryMB: 131072}, + }, + } + + desc := HardwareProfileToDescriptor(profile) + if desc.CPUs != 12 { + t.Errorf("expected 12 CPUs, got %d", desc.CPUs) + } + if desc.RAM != 137438953472 { + t.Errorf("expected RAM 137438953472, got %d", desc.RAM) + } + if len(desc.GPUs) != 1 { + t.Fatalf("expected 1 GPU, got %d", len(desc.GPUs)) + } + if desc.GPUs[0].Model != "NVIDIA GB10" { + t.Errorf("unexpected GPU model: %s", desc.GPUs[0].Model) + } + // 131072 MB = 131072 * 1024 * 1024 bytes + expectedMem := int64(131072) * 1024 * 1024 + if desc.GPUs[0].Memory != expectedMem { + t.Errorf("expected GPU memory %d, got %d", expectedMem, desc.GPUs[0].Memory) + } +} diff --git a/pkg/cmd/register/netbird.go b/pkg/cmd/register/netbird.go new file mode 100644 index 00000000..6eb1cff6 --- /dev/null +++ b/pkg/cmd/register/netbird.go @@ -0,0 +1,37 @@ +package register + +import ( + "fmt" + "os" + "os/exec" + + "github.com/brevdev/brev-cli/pkg/terminal" +) + +// InstallNetbird downloads and installs netbird using the official install script. +func InstallNetbird(t *terminal.Terminal) error { + script := `(curl -fsSL https://pkgs.netbird.io/install.sh | sh) || (curl -fsSL https://pkgs.netbird.io/install.sh | sh -s -- --update)` + + cmd := exec.Command("bash", "-c", script) // #nosec G204 + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to install netbird: %w", err) + } + return nil +} + +// UninstallNetbird stops, uninstalls, and removes netbird. +func UninstallNetbird(t *terminal.Terminal) error { + script := `netbird service stop && netbird service uninstall && sudo apt-get remove -y netbird` + + cmd := exec.Command("bash", "-c", script) // #nosec G204 + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to uninstall netbird: %w", err) + } + return nil +} diff --git a/pkg/cmd/register/osfile.go b/pkg/cmd/register/osfile.go new file mode 100644 index 00000000..b4906e31 --- /dev/null +++ b/pkg/cmd/register/osfile.go @@ -0,0 +1,8 @@ +package register + +import "os" + +// readOSFile reads a file from the real filesystem. +func readOSFile(path string) ([]byte, error) { + return os.ReadFile(path) // #nosec G304 +} diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index 1c16586b..aa430cd9 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -2,21 +2,43 @@ package register import ( + "fmt" + "os/user" + "runtime" + + "github.com/brevdev/brev-cli/pkg/entity" + breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/terminal" "github.com/spf13/cobra" ) +// RegisterStore defines the store methods needed by the register command. +type RegisterStore interface { + GetCurrentUser() (*entity.User, error) + GetActiveOrganizationOrDefault() (*entity.Organization, error) + GetBrevHomePath() (string, error) +} + +// OSFileReader reads files from the real OS filesystem. +type OSFileReader struct{} + +func (r OSFileReader) ReadFile(path string) ([]byte, error) { + return readOSFile(path) +} + var ( registerLong = `Register your DGX Spark with NVIDIA Brev -Join the waitlist to be among the first to register your DGX Spark -for early access integration with Brev.` +This command installs netbird (network agent), collects a hardware profile, +and registers this machine with Brev.` - registerExample = ` brev register` + registerExample = ` brev register --name "My DGX Spark"` ) -func NewCmdRegister(t *terminal.Terminal) *cobra.Command { +func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command { + var name string + cmd := &cobra.Command{ Annotations: map[string]string{"configuration": ""}, Use: "register", @@ -26,19 +48,106 @@ func NewCmdRegister(t *terminal.Terminal) *cobra.Command { Long: registerLong, Example: registerExample, RunE: func(cmd *cobra.Command, args []string) error { - runRegister(t) - return nil + return runRegister(t, store, name) }, } + cmd.Flags().StringVarP(&name, "name", "n", "", "Display name for this DGX Spark (required)") + _ = cmd.MarkFlagRequired("name") + return cmd } -func runRegister(t *terminal.Terminal) { - t.Vprint("\n") - t.Vprint(t.Green("Thanks so much for your interest in registering your DGX Spark with Brev!\n\n")) - t.Vprint("To be on the waitlist for early access to this feature, please fill out this form:\n\n") - t.Vprint(t.Yellow(" 👉 https://forms.gle/RHCHGmZuiMQQ2faA6\n\n")) - t.Vprint("We will reach out to the provided email with updates and instructions on how to register soon (:\n") - t.Vprint("\n") +func runRegister(t *terminal.Terminal, s RegisterStore, name string) error { //nolint:funlen // registration flow + if runtime.GOOS != "linux" { + return fmt.Errorf("brev register is only supported on Linux (DGX Spark)") + } + + currentUser, err := s.GetCurrentUser() + if err != nil { + return breverrors.WrapAndTrace(err) + } + + org, err := s.GetActiveOrganizationOrDefault() + if err != nil { + return breverrors.WrapAndTrace(err) + } + if org == nil { + return fmt.Errorf("no organization found; please create or join an organization first") + } + + brevHome, err := s.GetBrevHomePath() + if err != nil { + return breverrors.WrapAndTrace(err) + } + + if RegistrationExists(brevHome) { + return fmt.Errorf("this machine is already registered; run 'brev deregister' first to re-register") + } + + linuxUser := currentUser.Username + if u, err := user.Current(); err == nil { + linuxUser = u.Username + } + + t.Vprint("") + t.Vprint(t.Green("Registering your DGX Spark with Brev")) + t.Vprint("") + t.Vprintf(" Name: %s\n", t.Yellow(name)) + t.Vprintf(" Organization: %s\n", org.Name) + t.Vprintf(" Linux user: %s\n", linuxUser) + t.Vprint("") + t.Vprint("This will perform the following steps:") + t.Vprint(" 1. Install netbird (network agent)") + t.Vprint(" 2. Register this machine with Brev") + t.Vprint("") + + result := terminal.PromptSelectInput(terminal.PromptSelectContent{ + Label: "Proceed with registration?", + Items: []string{"Yes, proceed", "No, cancel"}, + }) + if result != "Yes, proceed" { + t.Vprint("Registration cancelled.") + return nil + } + + t.Vprint("") + t.Vprint(t.Yellow("[Step 1/2] Installing netbird...")) + if err := InstallNetbird(t); err != nil { + return fmt.Errorf("netbird installation failed: %w", err) + } + t.Vprint(t.Green(" Netbird installed successfully.")) + + t.Vprint("") + t.Vprint(t.Yellow("[Step 2/2] Collecting hardware profile...")) + t.Vprint("") + + runner := ExecCommandRunner{} + reader := OSFileReader{} + profile, err := CollectHardwareProfile(runner, reader) + if err != nil { + return fmt.Errorf("failed to collect hardware profile: %w", err) + } + + t.Vprint(" Hardware profile:") + t.Vprint(FormatHardwareProfile(profile)) + + desc := HardwareProfileToDescriptor(profile) + fingerprint, err := ComputeHardwareFingerprint(desc) + if err != nil { + t.Vprintf(" Warning: could not compute hardware fingerprint: %v\n", err) + } else { + t.Vprintf(" Fingerprint: %s\n", fingerprint) + } + + t.Vprint("") + t.Vprint(t.Yellow("[TODO] Registration API call not yet implemented.")) + t.Vprint(" Once implemented, the backend will return a node ID") + t.Vprintf(" that will be persisted to %s/spark_registration.json.\n", brevHome) + t.Vprint("") + + _ = org.ID // will be used in the registration API call + _ = name // will be sent as display_name + + return nil } From 9fbe87d2f10cb12e3453f26d925386958c2fbe45 Mon Sep 17 00:00:00 2001 From: Pratik Patel Date: Wed, 25 Feb 2026 15:22:09 -0800 Subject: [PATCH 02/11] removed fingerprinting. upated rpc --- pkg/cmd/cmd.go | 2 +- pkg/cmd/deregister/deregister.go | 4 +- pkg/cmd/register/hardware.go | 253 +++++++++++++++--------------- pkg/cmd/register/hardware_test.go | 123 ++++++++------- pkg/cmd/register/identity.go | 87 +--------- pkg/cmd/register/identity_test.go | 210 ------------------------- pkg/cmd/register/osfile.go | 8 - pkg/cmd/register/register.go | 27 ++-- 8 files changed, 213 insertions(+), 501 deletions(-) delete mode 100644 pkg/cmd/register/identity_test.go delete mode 100644 pkg/cmd/register/osfile.go diff --git a/pkg/cmd/cmd.go b/pkg/cmd/cmd.go index 572d9741..f9522bfb 100644 --- a/pkg/cmd/cmd.go +++ b/pkg/cmd/cmd.go @@ -11,6 +11,7 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/connect" "github.com/brevdev/brev-cli/pkg/cmd/copy" "github.com/brevdev/brev-cli/pkg/cmd/delete" + "github.com/brevdev/brev-cli/pkg/cmd/deregister" "github.com/brevdev/brev-cli/pkg/cmd/envvars" "github.com/brevdev/brev-cli/pkg/cmd/exec" "github.com/brevdev/brev-cli/pkg/cmd/fu" @@ -35,7 +36,6 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/redeem" "github.com/brevdev/brev-cli/pkg/cmd/refresh" "github.com/brevdev/brev-cli/pkg/cmd/register" - "github.com/brevdev/brev-cli/pkg/cmd/deregister" "github.com/brevdev/brev-cli/pkg/cmd/reset" "github.com/brevdev/brev-cli/pkg/cmd/runtasks" "github.com/brevdev/brev-cli/pkg/cmd/scale" diff --git a/pkg/cmd/deregister/deregister.go b/pkg/cmd/deregister/deregister.go index 4dff4868..48ca16e2 100644 --- a/pkg/cmd/deregister/deregister.go +++ b/pkg/cmd/deregister/deregister.go @@ -66,7 +66,7 @@ func runDeregister(t *terminal.Terminal, s DeregisterStore) error { t.Vprint("") t.Vprint(t.Green("Deregistering DGX Spark")) t.Vprint("") - t.Vprintf(" Node ID: %s\n", reg.BrevCloudNodeID) + t.Vprintf(" Node ID: %s\n", reg.ExternalNodeID) t.Vprintf(" Name: %s\n", reg.DisplayName) t.Vprint("") @@ -80,7 +80,7 @@ func runDeregister(t *terminal.Terminal, s DeregisterStore) error { Items: []string{"Yes, proceed", "No, cancel"}, }) if confirm != "Yes, proceed" { - t.Vprint("Deregistration cancelled.") + t.Vprint("Deregistration canceled.") return nil } diff --git a/pkg/cmd/register/hardware.go b/pkg/cmd/register/hardware.go index c6c4053f..31aa7f02 100644 --- a/pkg/cmd/register/hardware.go +++ b/pkg/cmd/register/hardware.go @@ -27,33 +27,24 @@ func (r ExecCommandRunner) Run(name string, args ...string) ([]byte, error) { return out, nil } -// HardwareProfile mirrors the proto HardwareInfo message. -type HardwareProfile struct { - GPUs []GPUInfo `json:"gpus"` - RAMBytes int64 `json:"ram_bytes"` - CPUCount int `json:"cpu_count"` - CPUModel string `json:"cpu_model,omitempty"` - SystemVendor string `json:"system_vendor,omitempty"` - SystemModel string `json:"system_model,omitempty"` - Architecture string `json:"architecture,omitempty"` - OSName string `json:"os_name,omitempty"` - OSVersion string `json:"os_version,omitempty"` - Storage []StorageInfo `json:"storage,omitempty"` +// NodeSpec matches the proto NodeSpec message from dev-plane. +// All fields are best-effort. +type NodeSpec struct { + GPUs []NodeGPU `json:"gpus"` + RAMBytes *int64 `json:"ram_bytes,omitempty"` + CPUCount *int32 `json:"cpu_count,omitempty"` + Architecture string `json:"architecture,omitempty"` + StorageBytes *int64 `json:"storage_bytes,omitempty"` + StorageType string `json:"storage_type,omitempty"` + OS string `json:"os,omitempty"` + OSVersion string `json:"os_version,omitempty"` } -// GPUInfo describes a single GPU. -type GPUInfo struct { - Name string `json:"name"` - MemoryMB int64 `json:"memory_mb"` - DriverVersion string `json:"driver_version,omitempty"` - PCIBusID string `json:"pci_bus_id,omitempty"` -} - -// StorageInfo describes a block storage device. -type StorageInfo struct { - Name string `json:"name"` - Bytes int64 `json:"bytes"` - Type string `json:"type"` +// NodeGPU matches the proto NodeGPU message. +type NodeGPU struct { + Model string `json:"model"` + Count int32 `json:"count"` + MemoryBytes *int64 `json:"memory_bytes,omitempty"` } // FileReader abstracts file reading for testability. @@ -62,68 +53,60 @@ type FileReader interface { } // CollectHardwareProfile gathers system hardware information. -// CPU count/model and RAM are required; everything else is best-effort. -func CollectHardwareProfile(runner CommandRunner, reader FileReader) (*HardwareProfile, error) { - profile := &HardwareProfile{ +// All fields are best-effort; failures are silently ignored. +func CollectHardwareProfile(runner CommandRunner, reader FileReader) (*NodeSpec, error) { + spec := &NodeSpec{ Architecture: runtime.GOARCH, } - cpuCount, cpuModel, err := parseCPUInfo(reader) - if err != nil { - return nil, fmt.Errorf("failed to read CPU info: %w", err) + if gpus, err := parseNvidiaSMI(runner); err == nil { + spec.GPUs = gpus } - profile.CPUCount = cpuCount - profile.CPUModel = cpuModel - ramBytes, err := parseMemInfo(reader) - if err != nil { - return nil, fmt.Errorf("failed to read memory info: %w", err) + if cpuCount, err := parseCPUCount(reader); err == nil { + count32 := int32(cpuCount) + spec.CPUCount = &count32 } - profile.RAMBytes = ramBytes - osName, osVersion := parseOSRelease(reader) - profile.OSName = osName - profile.OSVersion = osVersion + if ramBytes, err := parseMemInfo(reader); err == nil { + spec.RAMBytes = &ramBytes + } - profile.SystemVendor = readSysFile(reader, "/sys/class/dmi/id/sys_vendor") - profile.SystemModel = readSysFile(reader, "/sys/class/dmi/id/product_name") + osName, osVersion := parseOSRelease(reader) + spec.OS = osName + spec.OSVersion = osVersion - profile.GPUs = parseNvidiaSMI(runner) - profile.Storage = parseLsblk(runner) + storageBytes, storageType := collectStorage(runner) + if storageBytes > 0 { + spec.StorageBytes = &storageBytes + spec.StorageType = storageType + } - return profile, nil + return spec, nil } -// parseCPUInfo reads /proc/cpuinfo and returns (count, model). -func parseCPUInfo(reader FileReader) (int, string, error) { +// parseCPUCount reads /proc/cpuinfo and returns the number of logical processors. +func parseCPUCount(reader FileReader) (int, error) { data, err := reader.ReadFile("/proc/cpuinfo") if err != nil { - return 0, "", breverrors.WrapAndTrace(err) + return 0, breverrors.WrapAndTrace(err) } - return parseCPUInfoContent(string(data)) + return parseCPUCountContent(string(data)) } -// parseCPUInfoContent parses the content of /proc/cpuinfo. -func parseCPUInfoContent(content string) (int, string, error) { +// parseCPUCountContent parses the content of /proc/cpuinfo for processor count. +func parseCPUCountContent(content string) (int, error) { count := 0 - model := "" scanner := bufio.NewScanner(strings.NewReader(content)) for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, "processor") { + if strings.HasPrefix(scanner.Text(), "processor") { count++ } - if strings.HasPrefix(line, "model name") && model == "" { - parts := strings.SplitN(line, ":", 2) - if len(parts) == 2 { - model = strings.TrimSpace(parts[1]) - } - } } if count == 0 { - return 0, "", fmt.Errorf("no processors found in /proc/cpuinfo") + return 0, fmt.Errorf("no processors found in /proc/cpuinfo") } - return count, model, nil + return count, nil } // parseMemInfo reads /proc/meminfo and returns total RAM in bytes. @@ -190,30 +173,34 @@ func unquote(s string) string { return s } -// readSysFile reads a single-line sysfs file, returning empty string on failure. -func readSysFile(reader FileReader, path string) string { - data, err := reader.ReadFile(path) - if err != nil { - return "" - } - return strings.TrimSpace(string(data)) -} - // parseNvidiaSMI queries nvidia-smi for GPU information. -func parseNvidiaSMI(runner CommandRunner) []GPUInfo { +// Returns an error if nvidia-smi fails or no GPUs are found. +func parseNvidiaSMI(runner CommandRunner) ([]NodeGPU, error) { out, err := runner.Run("nvidia-smi", - "--query-gpu=name,memory.total,driver_version,pci.bus_id", + "--query-gpu=name,memory.total", "--format=csv,noheader,nounits", ) if err != nil { - return nil + return nil, fmt.Errorf("nvidia-smi not available: %w", err) } - return parseNvidiaSMIOutput(string(out)) + gpus := parseNvidiaSMIOutput(string(out)) + if len(gpus) == 0 { + return nil, fmt.Errorf("nvidia-smi returned no GPUs") + } + return gpus, nil } -// parseNvidiaSMIOutput parses nvidia-smi CSV output into GPUInfo slices. -func parseNvidiaSMIOutput(output string) []GPUInfo { - var gpus []GPUInfo +// parseNvidiaSMIOutput parses nvidia-smi CSV output, grouping identical GPU +// models into a single NodeGPU with a count. +func parseNvidiaSMIOutput(output string) []NodeGPU { + type gpuKey struct { + model string + memoryBytes int64 + } + + counts := make(map[gpuKey]int32) + var order []gpuKey + scanner := bufio.NewScanner(strings.NewReader(output)) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) @@ -224,82 +211,94 @@ func parseNvidiaSMIOutput(output string) []GPUInfo { if len(parts) < 2 { continue } - gpu := GPUInfo{ - Name: strings.TrimSpace(parts[0]), - } + model := strings.TrimSpace(parts[0]) memMB, err := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64) - if err == nil { - gpu.MemoryMB = memMB - } - if len(parts) >= 3 { - gpu.DriverVersion = strings.TrimSpace(parts[2]) + if err != nil { + continue } - if len(parts) >= 4 { - gpu.PCIBusID = strings.TrimSpace(parts[3]) + key := gpuKey{model: model, memoryBytes: memMB * 1024 * 1024} + if counts[key] == 0 { + order = append(order, key) } - gpus = append(gpus, gpu) + counts[key]++ + } + + gpus := make([]NodeGPU, 0, len(order)) + for _, key := range order { + mem := key.memoryBytes + gpus = append(gpus, NodeGPU{ + Model: key.model, + Count: counts[key], + MemoryBytes: &mem, + }) } return gpus } -// parseLsblk queries lsblk for block device information. -func parseLsblk(runner CommandRunner) []StorageInfo { +// collectStorage sums disk devices from lsblk to get total storage bytes +// and infers a storage type from the device names. +func collectStorage(runner CommandRunner) (int64, string) { out, err := runner.Run("lsblk", "-b", "-d", "-n", "-o", "NAME,SIZE,TYPE") if err != nil { - return nil + return 0, "" } - return parseLsblkOutput(string(out)) + return parseStorageOutput(string(out)) } -// parseLsblkOutput parses lsblk output into StorageInfo slices. -func parseLsblkOutput(output string) []StorageInfo { - var devices []StorageInfo +// parseStorageOutput parses lsblk output, summing disk device sizes and +// inferring storage type. +func parseStorageOutput(output string) (int64, string) { + var totalBytes int64 + storageType := "" scanner := bufio.NewScanner(strings.NewReader(output)) for scanner.Scan() { fields := strings.Fields(scanner.Text()) - if len(fields) < 3 { + if len(fields) < 3 || fields[2] != "disk" { continue } size, err := strconv.ParseInt(fields[1], 10, 64) if err != nil { continue } - devices = append(devices, StorageInfo{ - Name: fields[0], - Bytes: size, - Type: fields[2], - }) + totalBytes += size + if storageType == "" { + if strings.HasPrefix(fields[0], "nvme") { + storageType = "NVMe" + } else { + storageType = "SSD" + } + } } - return devices + return totalBytes, storageType } -// FormatHardwareProfile returns a human-readable summary of the hardware profile. -func FormatHardwareProfile(p *HardwareProfile) string { +// FormatNodeSpec returns a human-readable summary of the hardware profile. +func FormatNodeSpec(s *NodeSpec) string { var b strings.Builder - fmt.Fprintf(&b, " CPU: %d x %s\n", p.CPUCount, p.CPUModel) - fmt.Fprintf(&b, " RAM: %d GB\n", p.RAMBytes/(1024*1024*1024)) - if len(p.GPUs) > 0 { - // Group GPUs by name - gpuCounts := make(map[string]int) - gpuMemory := make(map[string]int64) - var gpuOrder []string - for _, gpu := range p.GPUs { - if gpuCounts[gpu.Name] == 0 { - gpuOrder = append(gpuOrder, gpu.Name) - } - gpuCounts[gpu.Name]++ - gpuMemory[gpu.Name] = gpu.MemoryMB - } - for _, name := range gpuOrder { - memGB := gpuMemory[name] / 1024 - fmt.Fprintf(&b, " GPUs: %d x %s (%d GB)\n", gpuCounts[name], name, memGB) + if s.CPUCount != nil { + fmt.Fprintf(&b, " CPU: %d cores\n", *s.CPUCount) + } + if s.RAMBytes != nil { + fmt.Fprintf(&b, " RAM: %d GB\n", *s.RAMBytes/(1024*1024*1024)) + } + for _, gpu := range s.GPUs { + if gpu.MemoryBytes != nil { + memGB := *gpu.MemoryBytes / (1024 * 1024 * 1024) + fmt.Fprintf(&b, " GPUs: %d x %s (%d GB)\n", gpu.Count, gpu.Model, memGB) + } else { + fmt.Fprintf(&b, " GPUs: %d x %s\n", gpu.Count, gpu.Model) } - } else { - b.WriteString(" GPUs: none detected\n") } - fmt.Fprintf(&b, " Arch: %s\n", p.Architecture) - if p.OSName != "" || p.OSVersion != "" { - fmt.Fprintf(&b, " OS: %s %s\n", p.OSName, p.OSVersion) + fmt.Fprintf(&b, " Arch: %s\n", s.Architecture) + if s.OS != "" || s.OSVersion != "" { + fmt.Fprintf(&b, " OS: %s %s\n", s.OS, s.OSVersion) + } + if s.StorageBytes != nil { + fmt.Fprintf(&b, " Storage: %d GB", *s.StorageBytes/(1024*1024*1024)) + if s.StorageType != "" { + fmt.Fprintf(&b, " (%s)", s.StorageType) + } + b.WriteString("\n") } return b.String() } diff --git a/pkg/cmd/register/hardware_test.go b/pkg/cmd/register/hardware_test.go index 27a948c6..edf0e4ef 100644 --- a/pkg/cmd/register/hardware_test.go +++ b/pkg/cmd/register/hardware_test.go @@ -1,10 +1,11 @@ package register import ( + "strings" "testing" ) -func Test_parseCPUInfoContent_ValidInput(t *testing.T) { +func Test_parseCPUCountContent_ValidInput(t *testing.T) { content := `processor : 0 vendor_id : AuthenticAMD model name : AMD EPYC 7763 64-Core Processor @@ -20,20 +21,17 @@ vendor_id : AuthenticAMD model name : AMD EPYC 7763 64-Core Processor cpu MHz : 2450.000 ` - count, model, err := parseCPUInfoContent(content) + count, err := parseCPUCountContent(content) if err != nil { t.Fatalf("unexpected error: %v", err) } if count != 3 { t.Errorf("expected 3 CPUs, got %d", count) } - if model != "AMD EPYC 7763 64-Core Processor" { - t.Errorf("unexpected CPU model: %s", model) - } } -func Test_parseCPUInfoContent_EmptyInput(t *testing.T) { - _, _, err := parseCPUInfoContent("") +func Test_parseCPUCountContent_EmptyInput(t *testing.T) { + _, err := parseCPUCountContent("") if err == nil { t.Fatal("expected error for empty input") } @@ -93,25 +91,40 @@ VERSION_ID=39 } } -func Test_parseNvidiaSMIOutput(t *testing.T) { - output := `NVIDIA GB10, 131072, 570.86.15, 00000000:01:00.0 -NVIDIA GB10, 131072, 570.86.15, 00000000:02:00.0 +func Test_parseNvidiaSMIOutput_GroupsByModel(t *testing.T) { + output := `NVIDIA GB10, 131072 +NVIDIA GB10, 131072 ` gpus := parseNvidiaSMIOutput(output) - if len(gpus) != 2 { - t.Fatalf("expected 2 GPUs, got %d", len(gpus)) + if len(gpus) != 1 { + t.Fatalf("expected 1 GPU group, got %d", len(gpus)) } - if gpus[0].Name != "NVIDIA GB10" { - t.Errorf("unexpected GPU name: %s", gpus[0].Name) + if gpus[0].Model != "NVIDIA GB10" { + t.Errorf("unexpected GPU model: %s", gpus[0].Model) } - if gpus[0].MemoryMB != 131072 { - t.Errorf("expected 131072 MB, got %d", gpus[0].MemoryMB) + if gpus[0].Count != 2 { + t.Errorf("expected count 2, got %d", gpus[0].Count) } - if gpus[0].DriverVersion != "570.86.15" { - t.Errorf("unexpected driver version: %s", gpus[0].DriverVersion) + expectedMem := int64(131072) * 1024 * 1024 + if gpus[0].MemoryBytes == nil || *gpus[0].MemoryBytes != expectedMem { + t.Errorf("expected %d bytes, got %v", expectedMem, gpus[0].MemoryBytes) } - if gpus[0].PCIBusID != "00000000:01:00.0" { - t.Errorf("unexpected PCI bus ID: %s", gpus[0].PCIBusID) +} + +func Test_parseNvidiaSMIOutput_MultipleModels(t *testing.T) { + output := `NVIDIA A100, 81920 +NVIDIA GB10, 131072 +NVIDIA A100, 81920 +` + gpus := parseNvidiaSMIOutput(output) + if len(gpus) != 2 { + t.Fatalf("expected 2 GPU groups, got %d", len(gpus)) + } + if gpus[0].Model != "NVIDIA A100" || gpus[0].Count != 2 { + t.Errorf("expected 2x NVIDIA A100, got %dx %s", gpus[0].Count, gpus[0].Model) + } + if gpus[1].Model != "NVIDIA GB10" || gpus[1].Count != 1 { + t.Errorf("expected 1x NVIDIA GB10, got %dx %s", gpus[1].Count, gpus[1].Model) } } @@ -122,22 +135,27 @@ func Test_parseNvidiaSMIOutput_Empty(t *testing.T) { } } -func Test_parseLsblkOutput(t *testing.T) { - output := `sda 500107862016 disk -nvme0n1 1000204886016 disk +func Test_parseStorageOutput(t *testing.T) { + output := `nvme0n1 500107862016 disk +nvme1n1 1000204886016 disk +sda 2048 rom ` - devices := parseLsblkOutput(output) - if len(devices) != 2 { - t.Fatalf("expected 2 devices, got %d", len(devices)) - } - if devices[0].Name != "sda" { - t.Errorf("unexpected device name: %s", devices[0].Name) + totalBytes, storageType := parseStorageOutput(output) + expected := int64(500107862016 + 1000204886016) + if totalBytes != expected { + t.Errorf("expected %d bytes, got %d", expected, totalBytes) } - if devices[0].Bytes != 500107862016 { - t.Errorf("unexpected device size: %d", devices[0].Bytes) + if storageType != "NVMe" { + t.Errorf("expected NVMe, got %s", storageType) } - if devices[0].Type != "disk" { - t.Errorf("unexpected device type: %s", devices[0].Type) +} + +func Test_parseStorageOutput_SDA(t *testing.T) { + output := `sda 500107862016 disk +` + _, storageType := parseStorageOutput(output) + if storageType != "SSD" { + t.Errorf("expected SSD, got %s", storageType) } } @@ -160,42 +178,31 @@ func Test_unquote(t *testing.T) { } } -func Test_FormatHardwareProfile(t *testing.T) { - p := &HardwareProfile{ - CPUCount: 12, - CPUModel: "AMD EPYC 7763", - RAMBytes: 137438953472, // 128 GB +func Test_FormatNodeSpec(t *testing.T) { + cpuCount := int32(12) + ramBytes := int64(137438953472) // 128 GB + memBytes := int64(137438953472) // 128 GB + s := &NodeSpec{ + CPUCount: &cpuCount, + RAMBytes: &ramBytes, Architecture: "arm64", - OSName: "Ubuntu", + OS: "Ubuntu", OSVersion: "24.04", - GPUs: []GPUInfo{ - {Name: "NVIDIA GB10", MemoryMB: 131072}, + GPUs: []NodeGPU{ + {Model: "NVIDIA GB10", Count: 1, MemoryBytes: &memBytes}, }, } - output := FormatHardwareProfile(p) + output := FormatNodeSpec(s) if output == "" { t.Fatal("expected non-empty output") } - if !contains(output, "12 x AMD EPYC 7763") { + if !strings.Contains(output, "12 cores") { t.Errorf("expected CPU info in output: %s", output) } - if !contains(output, "128 GB") { + if !strings.Contains(output, "128 GB") { t.Errorf("expected RAM info in output: %s", output) } - if !contains(output, "NVIDIA GB10") { + if !strings.Contains(output, "NVIDIA GB10") { t.Errorf("expected GPU info in output: %s", output) } } - -func contains(s, substr string) bool { - return len(s) >= len(substr) && searchString(s, substr) -} - -func searchString(s, substr string) bool { - for i := 0; i <= len(s)-len(substr); i++ { - if s[i:i+len(substr)] == substr { - return true - } - } - return false -} diff --git a/pkg/cmd/register/identity.go b/pkg/cmd/register/identity.go index 4dee6efc..1835190b 100644 --- a/pkg/cmd/register/identity.go +++ b/pkg/cmd/register/identity.go @@ -1,12 +1,7 @@ package register import ( - "crypto/sha256" - "encoding/hex" - "encoding/json" - "fmt" "path/filepath" - "sort" breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/files" @@ -15,82 +10,14 @@ import ( const registrationFileName = "spark_registration.json" // SparkRegistration is the persistent identity file for a registered DGX Spark. +// Fields align with the AddNodeResponse from dev-plane. type SparkRegistration struct { - BrevCloudNodeID string `json:"brev_cloud_node_id"` - DisplayName string `json:"display_name"` - OrgID string `json:"org_id"` - HardwareFingerprint string `json:"hardware_fingerprint"` - DeviceFingerprintHash string `json:"device_fingerprint_hash"` - RegisteredAt string `json:"registered_at"` - HardwareProfile HardwareProfile `json:"hardware_profile"` -} - -// HardwareDescriptor captures high-level hardware traits for fingerprinting. -// Must produce byte-identical JSON to dev-plane's HardwareDescriptor. -type HardwareDescriptor struct { - GPUs []GPUDescriptor `json:"gpus"` - RAM int64 `json:"ram_bytes"` - CPUs int `json:"cpus"` -} - -// GPUDescriptor describes a single GPU for fingerprinting. -type GPUDescriptor struct { - Model string `json:"model"` - Memory int64 `json:"memory_bytes"` -} - -// ComputeHardwareFingerprint returns a deterministic SHA-256 fingerprint. -// This must produce identical output to dev-plane's ComputeHardwareFingerprint -// for the same input. -func ComputeHardwareFingerprint(desc HardwareDescriptor) (string, error) { - if desc.CPUs < 1 { - return "", fmt.Errorf("CPUs must be at least 1") - } - if desc.RAM < 1 { - return "", fmt.Errorf("RAM must be at least 1") - } - for _, gpu := range desc.GPUs { - if gpu.Model == "" { - return "", fmt.Errorf("GPU model must not be empty") - } - if gpu.Memory < 1 { - return "", fmt.Errorf("GPU memory must be at least 1") - } - } - - // Sort GPUs by model then memory for stable ordering. - gpus := make([]GPUDescriptor, len(desc.GPUs)) - copy(gpus, desc.GPUs) - sort.Slice(gpus, func(i, j int) bool { - if gpus[i].Model == gpus[j].Model { - return gpus[i].Memory < gpus[j].Memory - } - return gpus[i].Model < gpus[j].Model - }) - desc.GPUs = gpus - - payload, err := json.Marshal(desc) - if err != nil { - return "", breverrors.WrapAndTrace(err) - } - sum := sha256.Sum256(payload) - return hex.EncodeToString(sum[:]), nil -} - -// HardwareProfileToDescriptor converts a HardwareProfile to a HardwareDescriptor -// for fingerprinting. -func HardwareProfileToDescriptor(p *HardwareProfile) HardwareDescriptor { - desc := HardwareDescriptor{ - RAM: p.RAMBytes, - CPUs: p.CPUCount, - } - for _, gpu := range p.GPUs { - desc.GPUs = append(desc.GPUs, GPUDescriptor{ - Model: gpu.Name, - Memory: gpu.MemoryMB * 1024 * 1024, // convert MB to bytes - }) - } - return desc + ExternalNodeID string `json:"external_node_id"` + DisplayName string `json:"display_name"` + OrgID string `json:"org_id"` + DeviceID string `json:"device_id"` + RegisteredAt string `json:"registered_at"` + NodeSpec NodeSpec `json:"node_spec"` } func registrationPath(brevHome string) string { diff --git a/pkg/cmd/register/identity_test.go b/pkg/cmd/register/identity_test.go deleted file mode 100644 index 93009dba..00000000 --- a/pkg/cmd/register/identity_test.go +++ /dev/null @@ -1,210 +0,0 @@ -package register - -import ( - "crypto/sha256" - "encoding/hex" - "encoding/json" - "sort" - "testing" -) - -// Test_ComputeHardwareFingerprint_Deterministic verifies that the same input -// always produces the same fingerprint. -func Test_ComputeHardwareFingerprint_Deterministic(t *testing.T) { - desc := HardwareDescriptor{ - GPUs: []GPUDescriptor{ - {Model: "NVIDIA GB10", Memory: 137438953472}, - }, - RAM: 137438953472, - CPUs: 12, - } - - fp1, err := ComputeHardwareFingerprint(desc) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - fp2, err := ComputeHardwareFingerprint(desc) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if fp1 != fp2 { - t.Errorf("fingerprints differ: %s != %s", fp1, fp2) - } -} - -// Test_ComputeHardwareFingerprint_GPUOrderIndependent verifies that GPU order -// does not affect the fingerprint. -func Test_ComputeHardwareFingerprint_GPUOrderIndependent(t *testing.T) { - desc1 := HardwareDescriptor{ - GPUs: []GPUDescriptor{ - {Model: "NVIDIA A100", Memory: 85899345920}, - {Model: "NVIDIA GB10", Memory: 137438953472}, - }, - RAM: 274877906944, - CPUs: 64, - } - desc2 := HardwareDescriptor{ - GPUs: []GPUDescriptor{ - {Model: "NVIDIA GB10", Memory: 137438953472}, - {Model: "NVIDIA A100", Memory: 85899345920}, - }, - RAM: 274877906944, - CPUs: 64, - } - - fp1, err := ComputeHardwareFingerprint(desc1) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - fp2, err := ComputeHardwareFingerprint(desc2) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if fp1 != fp2 { - t.Errorf("fingerprints should be identical regardless of GPU order: %s != %s", fp1, fp2) - } -} - -// Test_ComputeHardwareFingerprint_ByteIdenticalToDevPlane verifies that our -// fingerprint is byte-identical to what dev-plane would produce. We replicate -// the dev-plane logic inline to prove equivalence. -func Test_ComputeHardwareFingerprint_ByteIdenticalToDevPlane(t *testing.T) { - desc := HardwareDescriptor{ - GPUs: []GPUDescriptor{ - {Model: "NVIDIA GB10", Memory: 137438953472}, - {Model: "NVIDIA A100", Memory: 85899345920}, - }, - RAM: 274877906944, - CPUs: 64, - } - - // Compute using our function - got, err := ComputeHardwareFingerprint(desc) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - // Replicate dev-plane logic exactly - gpus := make([]GPUDescriptor, len(desc.GPUs)) - copy(gpus, desc.GPUs) - sort.Slice(gpus, func(i, j int) bool { - if gpus[i].Model == gpus[j].Model { - return gpus[i].Memory < gpus[j].Memory - } - return gpus[i].Model < gpus[j].Model - }) - // Build the same struct shape dev-plane uses - type devPlaneGPU struct { - Model string `json:"model"` - Memory int64 `json:"memory_bytes"` - } - type devPlaneDesc struct { - GPUs []devPlaneGPU `json:"gpus"` - RAM int64 `json:"ram_bytes"` - CPUs int `json:"cpus"` - } - dpGPUs := make([]devPlaneGPU, len(gpus)) - for i, g := range gpus { - dpGPUs[i] = devPlaneGPU{Model: g.Model, Memory: g.Memory} - } - dpDesc := devPlaneDesc{GPUs: dpGPUs, RAM: desc.RAM, CPUs: desc.CPUs} - payload, err := json.Marshal(dpDesc) - if err != nil { - t.Fatalf("json.Marshal failed: %v", err) - } - sum := sha256.Sum256(payload) - want := hex.EncodeToString(sum[:]) - - if got != want { - t.Errorf("fingerprint mismatch with dev-plane logic:\ngot: %s\nwant: %s", got, want) - } -} - -// Test_ComputeHardwareFingerprint_NoGPUs verifies fingerprinting works with -// no GPUs present. -func Test_ComputeHardwareFingerprint_NoGPUs(t *testing.T) { - desc := HardwareDescriptor{ - GPUs: nil, - RAM: 8589934592, - CPUs: 4, - } - fp, err := ComputeHardwareFingerprint(desc) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if fp == "" { - t.Fatal("expected non-empty fingerprint") - } -} - -// Test_ComputeHardwareFingerprint_ValidationErrors verifies validation. -func Test_ComputeHardwareFingerprint_ValidationErrors(t *testing.T) { - tests := []struct { - name string - desc HardwareDescriptor - }{ - { - name: "zero CPUs", - desc: HardwareDescriptor{RAM: 1024, CPUs: 0}, - }, - { - name: "zero RAM", - desc: HardwareDescriptor{RAM: 0, CPUs: 1}, - }, - { - name: "GPU with empty model", - desc: HardwareDescriptor{ - RAM: 1024, - CPUs: 1, - GPUs: []GPUDescriptor{{Model: "", Memory: 1024}}, - }, - }, - { - name: "GPU with zero memory", - desc: HardwareDescriptor{ - RAM: 1024, - CPUs: 1, - GPUs: []GPUDescriptor{{Model: "test", Memory: 0}}, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - _, err := ComputeHardwareFingerprint(tt.desc) - if err == nil { - t.Error("expected validation error") - } - }) - } -} - -// Test_HardwareProfileToDescriptor verifies the conversion. -func Test_HardwareProfileToDescriptor(t *testing.T) { - profile := &HardwareProfile{ - CPUCount: 12, - RAMBytes: 137438953472, - GPUs: []GPUInfo{ - {Name: "NVIDIA GB10", MemoryMB: 131072}, - }, - } - - desc := HardwareProfileToDescriptor(profile) - if desc.CPUs != 12 { - t.Errorf("expected 12 CPUs, got %d", desc.CPUs) - } - if desc.RAM != 137438953472 { - t.Errorf("expected RAM 137438953472, got %d", desc.RAM) - } - if len(desc.GPUs) != 1 { - t.Fatalf("expected 1 GPU, got %d", len(desc.GPUs)) - } - if desc.GPUs[0].Model != "NVIDIA GB10" { - t.Errorf("unexpected GPU model: %s", desc.GPUs[0].Model) - } - // 131072 MB = 131072 * 1024 * 1024 bytes - expectedMem := int64(131072) * 1024 * 1024 - if desc.GPUs[0].Memory != expectedMem { - t.Errorf("expected GPU memory %d, got %d", expectedMem, desc.GPUs[0].Memory) - } -} diff --git a/pkg/cmd/register/osfile.go b/pkg/cmd/register/osfile.go deleted file mode 100644 index b4906e31..00000000 --- a/pkg/cmd/register/osfile.go +++ /dev/null @@ -1,8 +0,0 @@ -package register - -import "os" - -// readOSFile reads a file from the real filesystem. -func readOSFile(path string) ([]byte, error) { - return os.ReadFile(path) // #nosec G304 -} diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index aa430cd9..69caac86 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -3,6 +3,7 @@ package register import ( "fmt" + "os" "os/user" "runtime" @@ -24,7 +25,11 @@ type RegisterStore interface { type OSFileReader struct{} func (r OSFileReader) ReadFile(path string) ([]byte, error) { - return readOSFile(path) + data, err := os.ReadFile(path) // #nosec G304 + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return data, nil } var ( @@ -107,7 +112,7 @@ func runRegister(t *terminal.Terminal, s RegisterStore, name string) error { //n Items: []string{"Yes, proceed", "No, cancel"}, }) if result != "Yes, proceed" { - t.Vprint("Registration cancelled.") + t.Vprint("Registration canceled.") return nil } @@ -124,30 +129,22 @@ func runRegister(t *terminal.Terminal, s RegisterStore, name string) error { //n runner := ExecCommandRunner{} reader := OSFileReader{} - profile, err := CollectHardwareProfile(runner, reader) + nodeSpec, err := CollectHardwareProfile(runner, reader) if err != nil { return fmt.Errorf("failed to collect hardware profile: %w", err) } t.Vprint(" Hardware profile:") - t.Vprint(FormatHardwareProfile(profile)) - - desc := HardwareProfileToDescriptor(profile) - fingerprint, err := ComputeHardwareFingerprint(desc) - if err != nil { - t.Vprintf(" Warning: could not compute hardware fingerprint: %v\n", err) - } else { - t.Vprintf(" Fingerprint: %s\n", fingerprint) - } + t.Vprint(FormatNodeSpec(nodeSpec)) t.Vprint("") t.Vprint(t.Yellow("[TODO] Registration API call not yet implemented.")) - t.Vprint(" Once implemented, the backend will return a node ID") + t.Vprint(" Once implemented, the backend will return an external_node_id") t.Vprintf(" that will be persisted to %s/spark_registration.json.\n", brevHome) t.Vprint("") - _ = org.ID // will be used in the registration API call - _ = name // will be sent as display_name + _ = org.ID // will be used in AddNodeRequest.organization_id + _ = name // will be sent as AddNodeRequest.name return nil } From d1765692e700084b65b12802c493028b8fd8d635 Mon Sep 17 00:00:00 2001 From: Pratik Patel Date: Wed, 25 Feb 2026 16:55:20 -0800 Subject: [PATCH 03/11] using rpc --- pkg/cmd/deregister/deregister.go | 18 +++- pkg/cmd/register/netbird.go | 14 +++ pkg/cmd/register/register.go | 60 ++++++++--- pkg/cmd/register/rpcclient.go | 165 +++++++++++++++++++++++++++++++ pkg/store/http.go | 5 + 5 files changed, 247 insertions(+), 15 deletions(-) create mode 100644 pkg/cmd/register/rpcclient.go diff --git a/pkg/cmd/deregister/deregister.go b/pkg/cmd/deregister/deregister.go index 48ca16e2..4bea1993 100644 --- a/pkg/cmd/deregister/deregister.go +++ b/pkg/cmd/deregister/deregister.go @@ -2,6 +2,7 @@ package deregister import ( + "context" "fmt" "runtime" @@ -17,6 +18,7 @@ import ( type DeregisterStore interface { GetCurrentUser() (*entity.User, error) GetBrevHomePath() (string, error) + GetAccessToken() (string, error) } var ( @@ -37,14 +39,14 @@ func NewCmdDeregister(t *terminal.Terminal, store DeregisterStore) *cobra.Comman Long: deregisterLong, Example: deregisterExample, RunE: func(cmd *cobra.Command, args []string) error { - return runDeregister(t, store) + return runDeregister(cmd.Context(), t, store) }, } return cmd } -func runDeregister(t *terminal.Terminal, s DeregisterStore) error { +func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore) error { //nolint:funlen // deregistration flow if runtime.GOOS != "linux" { return fmt.Errorf("brev deregister is only supported on Linux (DGX Spark)") } @@ -85,7 +87,15 @@ func runDeregister(t *terminal.Terminal, s DeregisterStore) error { } t.Vprint("") - t.Vprint(t.Yellow("[TODO] Deregistration API call not yet implemented.")) + t.Vprint(t.Yellow("Removing node from Brev...")) + client := register.NewConnectNodeClient(s, register.DevPlaneBaseURL) + if err := client.RemoveNode(ctx, ®ister.RemoveNodeRequest{ + ExternalNodeID: reg.ExternalNodeID, + OrganizationID: reg.OrgID, + }); err != nil { + return fmt.Errorf("failed to deregister node: %w", err) + } + t.Vprint(t.Green(" Node removed from Brev.")) t.Vprint("") if removeNetbird == "Yes, uninstall netbird" { @@ -98,7 +108,7 @@ func runDeregister(t *terminal.Terminal, s DeregisterStore) error { t.Vprint("") } - t.Vprint("Removing local registration data...") + t.Vprint("Removing registration data...") if err := register.DeleteRegistration(brevHome); err != nil { return fmt.Errorf("failed to remove registration data: %w", err) } diff --git a/pkg/cmd/register/netbird.go b/pkg/cmd/register/netbird.go index 6eb1cff6..95b2c536 100644 --- a/pkg/cmd/register/netbird.go +++ b/pkg/cmd/register/netbird.go @@ -22,6 +22,20 @@ func InstallNetbird(t *terminal.Terminal) error { return nil } +// runSetupCommands executes the setup commands returned by the AddNode RPC. +// The commands are keyed by name; values are shell commands to execute. +func runSetupCommands(commands map[string]string) error { + for name, script := range commands { + cmd := exec.Command("bash", "-c", script) // #nosec G204 + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("setup command %q failed: %w", name, err) + } + } + return nil +} + // UninstallNetbird stops, uninstalls, and removes netbird. func UninstallNetbird(t *terminal.Terminal) error { script := `netbird service stop && netbird service uninstall && sudo apt-get remove -y netbird` diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index 69caac86..4ef8decf 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -2,10 +2,14 @@ package register import ( + "context" "fmt" "os" "os/user" "runtime" + "time" + + "github.com/google/uuid" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" @@ -19,6 +23,7 @@ type RegisterStore interface { GetCurrentUser() (*entity.User, error) GetActiveOrganizationOrDefault() (*entity.Organization, error) GetBrevHomePath() (string, error) + GetAccessToken() (string, error) } // OSFileReader reads files from the real OS filesystem. @@ -53,7 +58,7 @@ func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command { Long: registerLong, Example: registerExample, RunE: func(cmd *cobra.Command, args []string) error { - return runRegister(t, store, name) + return runRegister(cmd.Context(), t, store, name) }, } @@ -63,7 +68,7 @@ func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command { return cmd } -func runRegister(t *terminal.Terminal, s RegisterStore, name string) error { //nolint:funlen // registration flow +func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string) error { //nolint:funlen // registration flow if runtime.GOOS != "linux" { return fmt.Errorf("brev register is only supported on Linux (DGX Spark)") } @@ -104,7 +109,8 @@ func runRegister(t *terminal.Terminal, s RegisterStore, name string) error { //n t.Vprint("") t.Vprint("This will perform the following steps:") t.Vprint(" 1. Install netbird (network agent)") - t.Vprint(" 2. Register this machine with Brev") + t.Vprint(" 2. Collect hardware profile") + t.Vprint(" 3. Register this machine with Brev") t.Vprint("") result := terminal.PromptSelectInput(terminal.PromptSelectContent{ @@ -117,14 +123,14 @@ func runRegister(t *terminal.Terminal, s RegisterStore, name string) error { //n } t.Vprint("") - t.Vprint(t.Yellow("[Step 1/2] Installing netbird...")) + t.Vprint(t.Yellow("[Step 1/3] Installing netbird...")) if err := InstallNetbird(t); err != nil { return fmt.Errorf("netbird installation failed: %w", err) } t.Vprint(t.Green(" Netbird installed successfully.")) t.Vprint("") - t.Vprint(t.Yellow("[Step 2/2] Collecting hardware profile...")) + t.Vprint(t.Yellow("[Step 2/3] Collecting hardware profile...")) t.Vprint("") runner := ExecCommandRunner{} @@ -138,13 +144,45 @@ func runRegister(t *terminal.Terminal, s RegisterStore, name string) error { //n t.Vprint(FormatNodeSpec(nodeSpec)) t.Vprint("") - t.Vprint(t.Yellow("[TODO] Registration API call not yet implemented.")) - t.Vprint(" Once implemented, the backend will return an external_node_id") - t.Vprintf(" that will be persisted to %s/spark_registration.json.\n", brevHome) - t.Vprint("") + t.Vprint(t.Yellow("[Step 3/3] Registering with Brev...")) + + deviceID := uuid.New().String() + client := NewConnectNodeClient(s, DevPlaneBaseURL) + addResp, err := client.AddNode(ctx, &AddNodeRequest{ + OrganizationID: org.ID, + Name: name, + DeviceID: deviceID, + NodeSpec: nodeSpec, + }) + if err != nil { + return fmt.Errorf("failed to register node: %w", err) + } - _ = org.ID // will be used in AddNodeRequest.organization_id - _ = name // will be sent as AddNodeRequest.name + reg := &SparkRegistration{ + ExternalNodeID: addResp.ExternalNode.ExternalNodeID, + DisplayName: name, + OrgID: org.ID, + DeviceID: deviceID, + RegisteredAt: time.Now().UTC().Format(time.RFC3339), + NodeSpec: *nodeSpec, + } + if err := SaveRegistration(brevHome, reg); err != nil { + return fmt.Errorf("node registered but failed to save locally: %w", err) + } + t.Vprint(t.Green(" Registration complete.")) + t.Vprintf(" Node ID: %s\n", addResp.ExternalNode.ExternalNodeID) + + if len(addResp.SetupCommands) > 0 { + t.Vprint("") + t.Vprint(" Running network setup commands...") + if err := runSetupCommands(addResp.SetupCommands); err != nil { + t.Vprintf(" Warning: setup command failed: %v\n", err) + } else { + t.Vprint(t.Green(" Network setup complete.")) + } + } + + t.Vprint("") return nil } diff --git a/pkg/cmd/register/rpcclient.go b/pkg/cmd/register/rpcclient.go new file mode 100644 index 00000000..65fe6643 --- /dev/null +++ b/pkg/cmd/register/rpcclient.go @@ -0,0 +1,165 @@ +package register + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + + breverrors "github.com/brevdev/brev-cli/pkg/errors" +) + +// DevPlaneBaseURL is the base URL for the dev-plane API. +// TODO: source from config once the URL is finalized. +const DevPlaneBaseURL = "https://brevapi.us-west-2-prod.control-plane.brev.dev" + +// TODO: Replace these local types with generated proto types once the +// ExternalNodeService is published to buf.build: +// +// import ( +// nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" +// nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" +// ) + +// AddNodeRequest matches the proto AddNodeRequest message. +type AddNodeRequest struct { + OrganizationID string `json:"organization_id"` + Name string `json:"name"` + DeviceID string `json:"device_id"` + NodeSpec *NodeSpec `json:"node_spec"` +} + +// AddNodeResponse matches the proto AddNodeResponse message. +type AddNodeResponse struct { + ExternalNode *ExternalNode `json:"external_node"` + SetupCommands map[string]string `json:"setup_commands"` +} + +// ExternalNode matches the proto ExternalNode message (subset of fields we need). +type ExternalNode struct { + ExternalNodeID string `json:"external_node_id"` + OrganizationID string `json:"organization_id"` + Name string `json:"name"` + DeviceID string `json:"device_id"` +} + +// RemoveNodeRequest matches the proto RemoveNodeRequest message. +type RemoveNodeRequest struct { + ExternalNodeID string `json:"external_node_id"` + OrganizationID string `json:"organization_id"` +} + +// ExternalNodeServiceClient defines the RPCs we call from the CLI. +// This will be replaced by the generated ConnectRPC client interface +// once the service is published. +type ExternalNodeServiceClient interface { + AddNode(ctx context.Context, req *AddNodeRequest) (*AddNodeResponse, error) + RemoveNode(ctx context.Context, req *RemoveNodeRequest) error +} + +// tokenProvider abstracts access token retrieval for the HTTP transport. +type tokenProvider interface { + GetAccessToken() (string, error) +} + +// bearerTokenTransport injects a Bearer token into every request. +type bearerTokenTransport struct { + provider tokenProvider + base http.RoundTripper +} + +func (t *bearerTokenTransport) RoundTrip(req *http.Request) (*http.Response, error) { + token, err := t.provider.GetAccessToken() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + req = req.Clone(req.Context()) + req.Header.Set("Authorization", "Bearer "+token) + return t.base.RoundTrip(req) +} + +// newAuthenticatedHTTPClient creates an http.Client that injects the bearer token +// from the given provider on every request. +func newAuthenticatedHTTPClient(provider tokenProvider) *http.Client { + return &http.Client{ + Transport: &bearerTokenTransport{ + provider: provider, + base: http.DefaultTransport, + }, + } +} + +// ConnectNodeClient is a temporary REST-based implementation of ExternalNodeServiceClient. +// It will be replaced by the generated ConnectRPC client once the service proto +// is published to buf.build. +// +// TODO: Replace with: +// +// httpClient := newAuthenticatedHTTPClient(store) +// client := nodev1connect.NewExternalNodeServiceClient(httpClient, baseURL) +type ConnectNodeClient struct { + httpClient *http.Client + baseURL string +} + +// NewConnectNodeClient creates a new ConnectNodeClient. +func NewConnectNodeClient(provider tokenProvider, baseURL string) *ConnectNodeClient { + return &ConnectNodeClient{ + httpClient: newAuthenticatedHTTPClient(provider), + baseURL: baseURL, + } +} + +func (c *ConnectNodeClient) AddNode(ctx context.Context, req *AddNodeRequest) (*AddNodeResponse, error) { + body, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("failed to marshal AddNodeRequest: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/devplaneapi.v1.ExternalNodeService/AddNode", bytes.NewReader(body)) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("AddNode request failed: %w", err) + } + defer resp.Body.Close() //nolint:errcheck // best-effort close + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("AddNode returned status %d", resp.StatusCode) + } + + var result AddNodeResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode AddNode response: %w", err) + } + return &result, nil +} + +func (c *ConnectNodeClient) RemoveNode(ctx context.Context, req *RemoveNodeRequest) error { + body, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("failed to marshal RemoveNodeRequest: %w", err) + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/devplaneapi.v1.ExternalNodeService/RemoveNode", bytes.NewReader(body)) + if err != nil { + return breverrors.WrapAndTrace(err) + } + httpReq.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(httpReq) + if err != nil { + return fmt.Errorf("RemoveNode request failed: %w", err) + } + defer resp.Body.Close() //nolint:errcheck // best-effort close + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("RemoveNode returned status %d", resp.StatusCode) + } + return nil +} diff --git a/pkg/store/http.go b/pkg/store/http.go index c64d9217..adcfeb79 100644 --- a/pkg/store/http.go +++ b/pkg/store/http.go @@ -61,6 +61,11 @@ func (s *AuthHTTPStore) GetWindowsDir() (string, error) { return s.GetWSLHostHomeDir() } +// GetAccessToken returns a fresh access token, refreshing if needed. +func (s *AuthHTTPStore) GetAccessToken() (string, error) { + return s.authHTTPClient.auth.GetAccessToken() +} + func (f *FileStore) WithAuthHTTPClient(c *AuthHTTPClient) *AuthHTTPStore { // err never returned from GetCurrentWorkspaceID id, _ := f.GetCurrentWorkspaceID() From f3a12c89e9b4e28e94da41aeda68aceef5273d91 Mon Sep 17 00:00:00 2001 From: Pratik Patel Date: Wed, 25 Feb 2026 17:44:14 -0800 Subject: [PATCH 04/11] more tests --- pkg/cmd/deregister/deregister.go | 17 +-- pkg/cmd/register/hardware.go | 16 +-- pkg/cmd/register/hardware_test.go | 202 +++++++++++++++++++++++++++++ pkg/cmd/register/identity.go | 20 ++- pkg/cmd/register/identity_test.go | 134 +++++++++++++++++++ pkg/cmd/register/netbird.go | 8 +- pkg/cmd/register/register.go | 27 ++-- pkg/cmd/register/rpcclient.go | 10 +- pkg/cmd/register/rpcclient_test.go | 180 +++++++++++++++++++++++++ pkg/store/http.go | 6 +- 10 files changed, 571 insertions(+), 49 deletions(-) create mode 100644 pkg/cmd/register/identity_test.go create mode 100644 pkg/cmd/register/rpcclient_test.go diff --git a/pkg/cmd/deregister/deregister.go b/pkg/cmd/deregister/deregister.go index 4bea1993..a07bece7 100644 --- a/pkg/cmd/deregister/deregister.go +++ b/pkg/cmd/deregister/deregister.go @@ -7,6 +7,7 @@ import ( "runtime" "github.com/brevdev/brev-cli/pkg/cmd/register" + "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/terminal" @@ -25,7 +26,7 @@ var ( deregisterLong = `Deregister your DGX Spark from NVIDIA Brev This command removes the local registration data and optionally uninstalls -netbird (network agent).` +NetBird (network agent).` deregisterExample = ` brev deregister` ) @@ -73,8 +74,8 @@ func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore) t.Vprint("") removeNetbird := terminal.PromptSelectInput(terminal.PromptSelectContent{ - Label: "Would you also like to uninstall netbird?", - Items: []string{"Yes, uninstall netbird", "No, keep netbird installed"}, + Label: "Would you also like to uninstall NetBird?", + Items: []string{"Yes, uninstall NetBird", "No, keep NetBird installed"}, }) confirm := terminal.PromptSelectInput(terminal.PromptSelectContent{ @@ -88,7 +89,7 @@ func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore) t.Vprint("") t.Vprint(t.Yellow("Removing node from Brev...")) - client := register.NewConnectNodeClient(s, register.DevPlaneBaseURL) + client := register.NewConnectNodeClient(s, config.GlobalConfig.GetBrevAPIURl()) if err := client.RemoveNode(ctx, ®ister.RemoveNodeRequest{ ExternalNodeID: reg.ExternalNodeID, OrganizationID: reg.OrgID, @@ -98,12 +99,12 @@ func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore) t.Vprint(t.Green(" Node removed from Brev.")) t.Vprint("") - if removeNetbird == "Yes, uninstall netbird" { - t.Vprint("Removing netbird...") + if removeNetbird == "Yes, uninstall NetBird" { + t.Vprint("Removing NetBird...") if err := register.UninstallNetbird(t); err != nil { - t.Vprintf(" Warning: failed to uninstall netbird: %v\n", err) + t.Vprintf(" Warning: failed to uninstall NetBird: %v\n", err) } else { - t.Vprint(t.Green(" Netbird uninstalled.")) + t.Vprint(t.Green(" NetBird uninstalled.")) } t.Vprint("") } diff --git a/pkg/cmd/register/hardware.go b/pkg/cmd/register/hardware.go index 31aa7f02..f23f8721 100644 --- a/pkg/cmd/register/hardware.go +++ b/pkg/cmd/register/hardware.go @@ -276,27 +276,27 @@ func parseStorageOutput(output string) (int64, string) { func FormatNodeSpec(s *NodeSpec) string { var b strings.Builder if s.CPUCount != nil { - fmt.Fprintf(&b, " CPU: %d cores\n", *s.CPUCount) + _, _ = fmt.Fprintf(&b, " CPU: %d cores\n", *s.CPUCount) } if s.RAMBytes != nil { - fmt.Fprintf(&b, " RAM: %d GB\n", *s.RAMBytes/(1024*1024*1024)) + _, _ = fmt.Fprintf(&b, " RAM: %d GB\n", *s.RAMBytes/(1024*1024*1024)) } for _, gpu := range s.GPUs { if gpu.MemoryBytes != nil { memGB := *gpu.MemoryBytes / (1024 * 1024 * 1024) - fmt.Fprintf(&b, " GPUs: %d x %s (%d GB)\n", gpu.Count, gpu.Model, memGB) + _, _ = fmt.Fprintf(&b, " GPUs: %d x %s (%d GB)\n", gpu.Count, gpu.Model, memGB) } else { - fmt.Fprintf(&b, " GPUs: %d x %s\n", gpu.Count, gpu.Model) + _, _ = fmt.Fprintf(&b, " GPUs: %d x %s\n", gpu.Count, gpu.Model) } } - fmt.Fprintf(&b, " Arch: %s\n", s.Architecture) + _, _ = fmt.Fprintf(&b, " Arch: %s\n", s.Architecture) if s.OS != "" || s.OSVersion != "" { - fmt.Fprintf(&b, " OS: %s %s\n", s.OS, s.OSVersion) + _, _ = fmt.Fprintf(&b, " OS: %s %s\n", s.OS, s.OSVersion) } if s.StorageBytes != nil { - fmt.Fprintf(&b, " Storage: %d GB", *s.StorageBytes/(1024*1024*1024)) + _, _ = fmt.Fprintf(&b, " Storage: %d GB", *s.StorageBytes/(1024*1024*1024)) if s.StorageType != "" { - fmt.Fprintf(&b, " (%s)", s.StorageType) + _, _ = fmt.Fprintf(&b, " (%s)", s.StorageType) } b.WriteString("\n") } diff --git a/pkg/cmd/register/hardware_test.go b/pkg/cmd/register/hardware_test.go index edf0e4ef..7e53fe7c 100644 --- a/pkg/cmd/register/hardware_test.go +++ b/pkg/cmd/register/hardware_test.go @@ -206,3 +206,205 @@ func Test_FormatNodeSpec(t *testing.T) { t.Errorf("expected GPU info in output: %s", output) } } + +func Test_FormatNodeSpec_MinimalFields(t *testing.T) { + s := &NodeSpec{ + GPUs: []NodeGPU{ + {Model: "NVIDIA GB10", Count: 1}, + }, + Architecture: "arm64", + } + output := FormatNodeSpec(s) + if strings.Contains(output, "CPU:") { + t.Errorf("should not contain CPU when nil: %s", output) + } + if strings.Contains(output, "RAM:") { + t.Errorf("should not contain RAM when nil: %s", output) + } + if !strings.Contains(output, "NVIDIA GB10") { + t.Errorf("expected GPU info: %s", output) + } + if !strings.Contains(output, "arm64") { + t.Errorf("expected arch info: %s", output) + } +} + +func Test_FormatNodeSpec_WithStorage(t *testing.T) { + storageBytes := int64(1099511627776) // 1 TB + s := &NodeSpec{ + Architecture: "amd64", + StorageBytes: &storageBytes, + StorageType: "NVMe", + } + output := FormatNodeSpec(s) + if !strings.Contains(output, "Storage:") { + t.Errorf("expected storage in output: %s", output) + } + if !strings.Contains(output, "NVMe") { + t.Errorf("expected NVMe in output: %s", output) + } +} + +func Test_parseNvidiaSMIOutput_MalformedLines(t *testing.T) { + output := ` +malformed line +NVIDIA GB10, 131072 +, , +just-a-name +NVIDIA A100, not-a-number +` + gpus := parseNvidiaSMIOutput(output) + if len(gpus) != 1 { + t.Fatalf("expected 1 valid GPU, got %d", len(gpus)) + } + if gpus[0].Model != "NVIDIA GB10" { + t.Errorf("unexpected model: %s", gpus[0].Model) + } +} + +func Test_parseStorageOutput_Empty(t *testing.T) { + totalBytes, storageType := parseStorageOutput("") + if totalBytes != 0 { + t.Errorf("expected 0 bytes, got %d", totalBytes) + } + if storageType != "" { + t.Errorf("expected empty storage type, got %s", storageType) + } +} + +func Test_parseStorageOutput_NoDiskDevices(t *testing.T) { + output := `sr0 1073741312 rom +loop0 123456 loop +` + totalBytes, storageType := parseStorageOutput(output) + if totalBytes != 0 { + t.Errorf("expected 0 bytes for non-disk devices, got %d", totalBytes) + } + if storageType != "" { + t.Errorf("expected empty storage type, got %s", storageType) + } +} + +// mockCommandRunner for testing CollectHardwareProfile +type mockCommandRunner struct { + outputs map[string][]byte + errors map[string]error +} + +func (m *mockCommandRunner) Run(name string, args ...string) ([]byte, error) { + key := name + if err, ok := m.errors[key]; ok { + return nil, err + } + if out, ok := m.outputs[key]; ok { + return out, nil + } + return nil, nil +} + +type mockFileReader struct { + files map[string][]byte +} + +func (m *mockFileReader) ReadFile(path string) ([]byte, error) { + if data, ok := m.files[path]; ok { + return data, nil + } + return nil, &mockFileNotFoundError{path: path} +} + +type mockFileNotFoundError struct{ path string } + +func (e *mockFileNotFoundError) Error() string { return "file not found: " + e.path } + +func Test_CollectHardwareProfile_WithMocks(t *testing.T) { + runner := &mockCommandRunner{ + outputs: map[string][]byte{ + "nvidia-smi": []byte("NVIDIA GB10, 131072\nNVIDIA GB10, 131072\n"), + "lsblk": []byte("nvme0n1 500107862016 disk\n"), + }, + } + reader := &mockFileReader{ + files: map[string][]byte{ + "/proc/cpuinfo": []byte("processor\t: 0\nprocessor\t: 1\n"), + "/proc/meminfo": []byte("MemTotal: 131886028 kB\n"), + "/etc/os-release": []byte("NAME=\"Ubuntu\"\nVERSION_ID=\"24.04\"\n"), + }, + } + + spec, err := CollectHardwareProfile(runner, reader) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(spec.GPUs) != 1 || spec.GPUs[0].Count != 2 { + t.Errorf("expected 1 GPU group with count 2, got %v", spec.GPUs) + } + if spec.CPUCount == nil || *spec.CPUCount != 2 { + t.Errorf("expected 2 CPUs, got %v", spec.CPUCount) + } + if spec.RAMBytes == nil || *spec.RAMBytes != 131886028*1024 { + t.Errorf("unexpected RAM: %v", spec.RAMBytes) + } + if spec.OS != "Ubuntu" || spec.OSVersion != "24.04" { + t.Errorf("unexpected OS: %s %s", spec.OS, spec.OSVersion) + } + if spec.StorageBytes == nil || *spec.StorageBytes != 500107862016 { + t.Errorf("unexpected storage: %v", spec.StorageBytes) + } +} + +func Test_CollectHardwareProfile_GPUBestEffort(t *testing.T) { + runner := &mockCommandRunner{ + errors: map[string]error{ + "nvidia-smi": &mockFileNotFoundError{path: "nvidia-smi"}, + }, + } + reader := &mockFileReader{ + files: map[string][]byte{ + "/proc/cpuinfo": []byte("processor\t: 0\n"), + "/proc/meminfo": []byte("MemTotal: 131886028 kB\n"), + }, + } + + spec, err := CollectHardwareProfile(runner, reader) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(spec.GPUs) != 0 { + t.Errorf("expected 0 GPUs when nvidia-smi fails, got %d", len(spec.GPUs)) + } + if spec.CPUCount == nil || *spec.CPUCount != 1 { + t.Errorf("expected 1 CPU, got %v", spec.CPUCount) + } +} + +func Test_CollectHardwareProfile_OptionalFieldsMissing(t *testing.T) { + runner := &mockCommandRunner{ + outputs: map[string][]byte{ + "nvidia-smi": []byte("NVIDIA GB10, 131072\n"), + }, + errors: map[string]error{ + "lsblk": &mockFileNotFoundError{path: "lsblk"}, + }, + } + reader := &mockFileReader{ + files: map[string][]byte{}, + } + + spec, err := CollectHardwareProfile(runner, reader) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if spec.CPUCount != nil { + t.Errorf("expected nil CPUCount when /proc/cpuinfo missing") + } + if spec.RAMBytes != nil { + t.Errorf("expected nil RAMBytes when /proc/meminfo missing") + } + if spec.StorageBytes != nil { + t.Errorf("expected nil StorageBytes when lsblk fails") + } + if len(spec.GPUs) != 1 { + t.Errorf("expected 1 GPU, got %d", len(spec.GPUs)) + } +} diff --git a/pkg/cmd/register/identity.go b/pkg/cmd/register/identity.go index 1835190b..9b0c741a 100644 --- a/pkg/cmd/register/identity.go +++ b/pkg/cmd/register/identity.go @@ -1,17 +1,19 @@ package register import ( + "encoding/json" "path/filepath" breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/files" + "github.com/spf13/afero" ) const registrationFileName = "spark_registration.json" -// SparkRegistration is the persistent identity file for a registered DGX Spark. +// DeviceRegistration is the persistent identity file for a registered device. // Fields align with the AddNodeResponse from dev-plane. -type SparkRegistration struct { +type DeviceRegistration struct { ExternalNodeID string `json:"external_node_id"` DisplayName string `json:"display_name"` OrgID string `json:"org_id"` @@ -25,19 +27,25 @@ func registrationPath(brevHome string) string { } // SaveRegistration writes the registration to ~/.brev/spark_registration.json. -func SaveRegistration(brevHome string, reg *SparkRegistration) error { +func SaveRegistration(brevHome string, reg *DeviceRegistration) error { path := registrationPath(brevHome) - err := files.OverwriteJSON(files.AppFs, path, reg) + data, err := json.MarshalIndent(reg, "", " ") if err != nil { return breverrors.WrapAndTrace(err) } + if err := files.AppFs.MkdirAll(filepath.Dir(path), 0o770); err != nil { + return breverrors.WrapAndTrace(err) + } + if err := afero.WriteFile(files.AppFs, path, data, 0o600); err != nil { + return breverrors.WrapAndTrace(err) + } return nil } // LoadRegistration reads the registration from ~/.brev/spark_registration.json. -func LoadRegistration(brevHome string) (*SparkRegistration, error) { +func LoadRegistration(brevHome string) (*DeviceRegistration, error) { path := registrationPath(brevHome) - var reg SparkRegistration + var reg DeviceRegistration err := files.ReadJSON(files.AppFs, path, ®) if err != nil { return nil, breverrors.WrapAndTrace(err) diff --git a/pkg/cmd/register/identity_test.go b/pkg/cmd/register/identity_test.go new file mode 100644 index 00000000..8fdfb1f0 --- /dev/null +++ b/pkg/cmd/register/identity_test.go @@ -0,0 +1,134 @@ +package register + +import ( + "testing" + + "github.com/brevdev/brev-cli/pkg/files" + "github.com/spf13/afero" +) + +func setupTestFs(t *testing.T) (string, func()) { + t.Helper() + origFs := files.AppFs + files.AppFs = afero.NewMemMapFs() + brevHome := "/home/testuser/.brev" + if err := files.AppFs.MkdirAll(brevHome, 0o770); err != nil { + t.Fatalf("failed to create test dir: %v", err) + } + return brevHome, func() { files.AppFs = origFs } +} + +func Test_SaveAndLoadRegistration_RoundTrip(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + cpuCount := int32(12) + ramBytes := int64(137438953472) + reg := &DeviceRegistration{ + ExternalNodeID: "unode_abc123", + DisplayName: "My Spark", + OrgID: "org_xyz", + DeviceID: "device-uuid-123", + RegisteredAt: "2026-02-25T00:00:00Z", + NodeSpec: NodeSpec{ + CPUCount: &cpuCount, + RAMBytes: &ramBytes, + Architecture: "arm64", + }, + } + + if err := SaveRegistration(brevHome, reg); err != nil { + t.Fatalf("SaveRegistration failed: %v", err) + } + + loaded, err := LoadRegistration(brevHome) + if err != nil { + t.Fatalf("LoadRegistration failed: %v", err) + } + + if loaded.ExternalNodeID != reg.ExternalNodeID { + t.Errorf("ExternalNodeID mismatch: got %s, want %s", loaded.ExternalNodeID, reg.ExternalNodeID) + } + if loaded.DisplayName != reg.DisplayName { + t.Errorf("DisplayName mismatch: got %s, want %s", loaded.DisplayName, reg.DisplayName) + } + if loaded.OrgID != reg.OrgID { + t.Errorf("OrgID mismatch: got %s, want %s", loaded.OrgID, reg.OrgID) + } + if loaded.DeviceID != reg.DeviceID { + t.Errorf("DeviceID mismatch: got %s, want %s", loaded.DeviceID, reg.DeviceID) + } + if loaded.NodeSpec.Architecture != "arm64" { + t.Errorf("Architecture mismatch: got %s", loaded.NodeSpec.Architecture) + } + if loaded.NodeSpec.CPUCount == nil || *loaded.NodeSpec.CPUCount != 12 { + t.Errorf("CPUCount mismatch: got %v", loaded.NodeSpec.CPUCount) + } +} + +func Test_RegistrationExists_ReturnsFalseWhenMissing(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + if RegistrationExists(brevHome) { + t.Error("expected RegistrationExists to return false") + } +} + +func Test_RegistrationExists_ReturnsTrueAfterSave(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + reg := &DeviceRegistration{ + ExternalNodeID: "unode_abc123", + DisplayName: "Test", + } + if err := SaveRegistration(brevHome, reg); err != nil { + t.Fatalf("SaveRegistration failed: %v", err) + } + + if !RegistrationExists(brevHome) { + t.Error("expected RegistrationExists to return true") + } +} + +func Test_DeleteRegistration_RemovesFile(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + reg := &DeviceRegistration{ + ExternalNodeID: "unode_abc123", + DisplayName: "Test", + } + if err := SaveRegistration(brevHome, reg); err != nil { + t.Fatalf("SaveRegistration failed: %v", err) + } + + if err := DeleteRegistration(brevHome); err != nil { + t.Fatalf("DeleteRegistration failed: %v", err) + } + + if RegistrationExists(brevHome) { + t.Error("expected RegistrationExists to return false after delete") + } +} + +func Test_LoadRegistration_FailsWhenMissing(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + _, err := LoadRegistration(brevHome) + if err == nil { + t.Error("expected error loading missing registration") + } +} + +func Test_DeleteRegistration_FailsWhenMissing(t *testing.T) { + brevHome, cleanup := setupTestFs(t) + defer cleanup() + + err := DeleteRegistration(brevHome) + if err == nil { + t.Error("expected error deleting missing registration") + } +} diff --git a/pkg/cmd/register/netbird.go b/pkg/cmd/register/netbird.go index 95b2c536..1f02ed50 100644 --- a/pkg/cmd/register/netbird.go +++ b/pkg/cmd/register/netbird.go @@ -8,7 +8,7 @@ import ( "github.com/brevdev/brev-cli/pkg/terminal" ) -// InstallNetbird downloads and installs netbird using the official install script. +// InstallNetbird downloads and installs NetBird using the official install script. func InstallNetbird(t *terminal.Terminal) error { script := `(curl -fsSL https://pkgs.netbird.io/install.sh | sh) || (curl -fsSL https://pkgs.netbird.io/install.sh | sh -s -- --update)` @@ -17,7 +17,7 @@ func InstallNetbird(t *terminal.Terminal) error { cmd.Stderr = os.Stderr if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to install netbird: %w", err) + return fmt.Errorf("failed to install NetBird: %w", err) } return nil } @@ -36,7 +36,7 @@ func runSetupCommands(commands map[string]string) error { return nil } -// UninstallNetbird stops, uninstalls, and removes netbird. +// UninstallNetbird stops, uninstalls, and removes NetBird. func UninstallNetbird(t *terminal.Terminal) error { script := `netbird service stop && netbird service uninstall && sudo apt-get remove -y netbird` @@ -45,7 +45,7 @@ func UninstallNetbird(t *terminal.Terminal) error { cmd.Stderr = os.Stderr if err := cmd.Run(); err != nil { - return fmt.Errorf("failed to uninstall netbird: %w", err) + return fmt.Errorf("failed to uninstall NetBird: %w", err) } return nil } diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index 4ef8decf..3e199f73 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -11,6 +11,7 @@ import ( "github.com/google/uuid" + "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" "github.com/brevdev/brev-cli/pkg/terminal" @@ -40,8 +41,7 @@ func (r OSFileReader) ReadFile(path string) ([]byte, error) { var ( registerLong = `Register your DGX Spark with NVIDIA Brev -This command installs netbird (network agent), collects a hardware profile, -and registers this machine with Brev.` +This command installs NetBird (network agent), and registers this machine with Brev.` registerExample = ` brev register --name "My DGX Spark"` ) @@ -54,7 +54,7 @@ func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command { Use: "register", Aliases: []string{"spark"}, DisableFlagsInUseLine: true, - Short: "Register your DGX Spark with Brev", + Short: "Register this device with Brev", Long: registerLong, Example: registerExample, RunE: func(cmd *cobra.Command, args []string) error { @@ -62,7 +62,7 @@ func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command { }, } - cmd.Flags().StringVarP(&name, "name", "n", "", "Display name for this DGX Spark (required)") + cmd.Flags().StringVarP(&name, "name", "n", "", "Display name (required)") _ = cmd.MarkFlagRequired("name") return cmd @@ -108,7 +108,7 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam t.Vprintf(" Linux user: %s\n", linuxUser) t.Vprint("") t.Vprint("This will perform the following steps:") - t.Vprint(" 1. Install netbird (network agent)") + t.Vprint(" 1. Install NetBird (network agent)") t.Vprint(" 2. Collect hardware profile") t.Vprint(" 3. Register this machine with Brev") t.Vprint("") @@ -123,11 +123,11 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam } t.Vprint("") - t.Vprint(t.Yellow("[Step 1/3] Installing netbird...")) + t.Vprint(t.Yellow("[Step 1/3] Installing NetBird...")) if err := InstallNetbird(t); err != nil { - return fmt.Errorf("netbird installation failed: %w", err) + return fmt.Errorf("NetBird installation failed: %w", err) } - t.Vprint(t.Green(" Netbird installed successfully.")) + t.Vprint(t.Green(" NetBird installed successfully.")) t.Vprint("") t.Vprint(t.Yellow("[Step 2/3] Collecting hardware profile...")) @@ -147,7 +147,7 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam t.Vprint(t.Yellow("[Step 3/3] Registering with Brev...")) deviceID := uuid.New().String() - client := NewConnectNodeClient(s, DevPlaneBaseURL) + client := NewConnectNodeClient(s, config.GlobalConfig.GetBrevAPIURl()) addResp, err := client.AddNode(ctx, &AddNodeRequest{ OrganizationID: org.ID, Name: name, @@ -158,7 +158,7 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam return fmt.Errorf("failed to register node: %w", err) } - reg := &SparkRegistration{ + reg := &DeviceRegistration{ ExternalNodeID: addResp.ExternalNode.ExternalNodeID, DisplayName: name, OrgID: org.ID, @@ -171,18 +171,11 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam } t.Vprint(t.Green(" Registration complete.")) - t.Vprintf(" Node ID: %s\n", addResp.ExternalNode.ExternalNodeID) if len(addResp.SetupCommands) > 0 { - t.Vprint("") - t.Vprint(" Running network setup commands...") if err := runSetupCommands(addResp.SetupCommands); err != nil { t.Vprintf(" Warning: setup command failed: %v\n", err) - } else { - t.Vprint(t.Green(" Network setup complete.")) } } - - t.Vprint("") return nil } diff --git a/pkg/cmd/register/rpcclient.go b/pkg/cmd/register/rpcclient.go index 65fe6643..2b2c8a4d 100644 --- a/pkg/cmd/register/rpcclient.go +++ b/pkg/cmd/register/rpcclient.go @@ -10,10 +10,6 @@ import ( breverrors "github.com/brevdev/brev-cli/pkg/errors" ) -// DevPlaneBaseURL is the base URL for the dev-plane API. -// TODO: source from config once the URL is finalized. -const DevPlaneBaseURL = "https://brevapi.us-west-2-prod.control-plane.brev.dev" - // TODO: Replace these local types with generated proto types once the // ExternalNodeService is published to buf.build: // @@ -76,7 +72,11 @@ func (t *bearerTokenTransport) RoundTrip(req *http.Request) (*http.Response, err } req = req.Clone(req.Context()) req.Header.Set("Authorization", "Bearer "+token) - return t.base.RoundTrip(req) + resp, err := t.base.RoundTrip(req) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return resp, nil } // newAuthenticatedHTTPClient creates an http.Client that injects the bearer token diff --git a/pkg/cmd/register/rpcclient_test.go b/pkg/cmd/register/rpcclient_test.go new file mode 100644 index 00000000..771d19f0 --- /dev/null +++ b/pkg/cmd/register/rpcclient_test.go @@ -0,0 +1,180 @@ +package register + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +type mockTokenProvider struct { + token string + err error +} + +func (m *mockTokenProvider) GetAccessToken() (string, error) { + return m.token, m.err +} + +func Test_bearerTokenTransport_InjectsHeader(t *testing.T) { + var gotAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + provider := &mockTokenProvider{token: "test-token-123"} + client := newAuthenticatedHTTPClient(provider) + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() //nolint:errcheck // test + + if gotAuth != "Bearer test-token-123" { + t.Errorf("expected 'Bearer test-token-123', got %q", gotAuth) + } +} + +func Test_bearerTokenTransport_PropagatesTokenError(t *testing.T) { + provider := &mockTokenProvider{err: http.ErrAbortHandler} + client := newAuthenticatedHTTPClient(provider) + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://localhost", nil) + resp, err := client.Do(req) + if err == nil { + resp.Body.Close() //nolint:errcheck // test + t.Fatal("expected error from token provider") + } +} + +func Test_ConnectNodeClient_AddNode(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/devplaneapi.v1.ExternalNodeService/AddNode" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("expected application/json, got %s", ct) + } + + var req AddNodeRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("failed to decode request: %v", err) + } + if req.OrganizationID != "org_123" { + t.Errorf("unexpected org ID: %s", req.OrganizationID) + } + if req.Name != "My Spark" { + t.Errorf("unexpected name: %s", req.Name) + } + + resp := AddNodeResponse{ + ExternalNode: &ExternalNode{ + ExternalNodeID: "unode_abc", + OrganizationID: "org_123", + Name: "My Spark", + DeviceID: req.DeviceID, + }, + SetupCommands: map[string]string{"netbird": "netbird up"}, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) //nolint:errcheck // test + })) + defer server.Close() + + provider := &mockTokenProvider{token: "tok"} + client := NewConnectNodeClient(provider, server.URL) + + resp, err := client.AddNode(context.Background(), &AddNodeRequest{ + OrganizationID: "org_123", + Name: "My Spark", + DeviceID: "dev-uuid", + NodeSpec: &NodeSpec{Architecture: "arm64"}, + }) + if err != nil { + t.Fatalf("AddNode failed: %v", err) + } + if resp.ExternalNode.ExternalNodeID != "unode_abc" { + t.Errorf("unexpected node ID: %s", resp.ExternalNode.ExternalNodeID) + } + if len(resp.SetupCommands) != 1 { + t.Errorf("expected 1 setup command, got %d", len(resp.SetupCommands)) + } +} + +func Test_ConnectNodeClient_AddNode_ServerError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + provider := &mockTokenProvider{token: "tok"} + client := NewConnectNodeClient(provider, server.URL) + + _, err := client.AddNode(context.Background(), &AddNodeRequest{ + OrganizationID: "org_123", + Name: "Test", + DeviceID: "dev", + }) + if err == nil { + t.Fatal("expected error for 500 response") + } +} + +func Test_ConnectNodeClient_RemoveNode(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/devplaneapi.v1.ExternalNodeService/RemoveNode" { + t.Errorf("unexpected path: %s", r.URL.Path) + } + + var req RemoveNodeRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("failed to decode request: %v", err) + } + if req.ExternalNodeID != "unode_abc" { + t.Errorf("unexpected node ID: %s", req.ExternalNodeID) + } + if req.OrganizationID != "org_123" { + t.Errorf("unexpected org ID: %s", req.OrganizationID) + } + + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + provider := &mockTokenProvider{token: "tok"} + client := NewConnectNodeClient(provider, server.URL) + + err := client.RemoveNode(context.Background(), &RemoveNodeRequest{ + ExternalNodeID: "unode_abc", + OrganizationID: "org_123", + }) + if err != nil { + t.Fatalf("RemoveNode failed: %v", err) + } +} + +func Test_ConnectNodeClient_RemoveNode_ServerError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer server.Close() + + provider := &mockTokenProvider{token: "tok"} + client := NewConnectNodeClient(provider, server.URL) + + err := client.RemoveNode(context.Background(), &RemoveNodeRequest{ + ExternalNodeID: "unode_missing", + OrganizationID: "org_123", + }) + if err == nil { + t.Fatal("expected error for 404 response") + } +} diff --git a/pkg/store/http.go b/pkg/store/http.go index adcfeb79..60884f81 100644 --- a/pkg/store/http.go +++ b/pkg/store/http.go @@ -63,7 +63,11 @@ func (s *AuthHTTPStore) GetWindowsDir() (string, error) { // GetAccessToken returns a fresh access token, refreshing if needed. func (s *AuthHTTPStore) GetAccessToken() (string, error) { - return s.authHTTPClient.auth.GetAccessToken() + token, err := s.authHTTPClient.auth.GetAccessToken() + if err != nil { + return "", breverrors.WrapAndTrace(err) + } + return token, nil } func (f *FileStore) WithAuthHTTPClient(c *AuthHTTPClient) *AuthHTTPStore { From 9ea8431fd26d6321e00e2fd86efd7b1e44088fe8 Mon Sep 17 00:00:00 2001 From: Pratik Patel Date: Wed, 25 Feb 2026 20:19:28 -0800 Subject: [PATCH 05/11] pulling rpc deps --- go.mod | 8 +- go.sum | 16 +- pkg/cmd/deregister/deregister.go | 13 +- pkg/cmd/register/register.go | 21 ++- pkg/cmd/register/rpcclient.go | 142 ++++---------- pkg/cmd/register/rpcclient_test.go | 286 +++++++++++++++++++---------- 6 files changed, 266 insertions(+), 220 deletions(-) diff --git a/go.mod b/go.mod index f9cad567..e0c66257 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,9 @@ module github.com/brevdev/brev-cli go 1.24.0 require ( + buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260226031750-e6fd8dbaf991.2 + buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260226031750-e6fd8dbaf991.1 + connectrpc.com/connect v1.19.1 github.com/alessio/shellescape v1.4.1 github.com/brevdev/parse v0.0.11 github.com/briandowns/spinner v1.16.0 @@ -12,7 +15,7 @@ require ( github.com/go-git/go-git/v5 v5.13.2 github.com/go-resty/resty/v2 v2.17.0 github.com/golang-jwt/jwt/v5 v5.3.0 - github.com/google/go-cmp v0.6.0 + github.com/google/go-cmp v0.7.0 github.com/google/huproxy v0.0.0-20210816191033-a131ee126ce3 github.com/google/uuid v1.6.0 github.com/gorilla/websocket v1.5.0 @@ -44,6 +47,7 @@ require ( ) require ( + buf.build/gen/go/brevdev/protoc-gen-gotag/protocolbuffers/go v1.36.11-20220906235457-8b4922735da5.1 // indirect dario.cat/mergo v1.0.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect @@ -146,7 +150,7 @@ require ( golang.org/x/sys v0.40.0 // indirect golang.org/x/term v0.39.0 // indirect golang.org/x/time v0.12.0 // indirect - google.golang.org/protobuf v1.34.2 + google.golang.org/protobuf v1.36.11 gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index b4f2cc37..a831a91f 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,9 @@ +buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260226031750-e6fd8dbaf991.2 h1:8ZrXfJx6gzHTeBU2Lfn2jdpi8q8QJMXZPo8GlVEm6+A= +buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260226031750-e6fd8dbaf991.2/go.mod h1:EGcIExX0SEtObIZr1l3pouENtdl2gsZtHjOYOfuB7ss= +buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260226031750-e6fd8dbaf991.1 h1:xkJkJcCnAq5WiEUevk7Kz3b+aFuK7aj64DyVUQM9ZQ0= +buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260226031750-e6fd8dbaf991.1/go.mod h1:V/y7Wxg0QvU4XPVwqErF5NHLobUT1QEyfgrGuQIxdPo= +buf.build/gen/go/brevdev/protoc-gen-gotag/protocolbuffers/go v1.36.11-20220906235457-8b4922735da5.1 h1:6amhprQmCKJ4wgJ6ngkh32d9V+dQcOLUZ/SfHdOnYgo= +buf.build/gen/go/brevdev/protoc-gen-gotag/protocolbuffers/go v1.36.11-20220906235457-8b4922735da5.1/go.mod h1:O+pnSHMru/naTMrm4tmpBoH3wz6PHa+R75HR7Mv8X2g= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.38.0/go.mod h1:990N+gfupTy94rShfmMCWGDn0LpTmnzTp2qbd1dvSRU= @@ -35,6 +41,8 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo= +connectrpc.com/connect v1.19.1 h1:R5M57z05+90EfEvCY1b7hBxDVOUl45PrtXtAV2fOC14= +connectrpc.com/connect v1.19.1/go.mod h1:tN20fjdGlewnSFeZxLKb0xwIZ6ozc3OQs2hTXy4du9w= dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= @@ -208,8 +216,8 @@ github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -772,8 +780,8 @@ google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= -google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/cmd/deregister/deregister.go b/pkg/cmd/deregister/deregister.go index a07bece7..6a1beddd 100644 --- a/pkg/cmd/deregister/deregister.go +++ b/pkg/cmd/deregister/deregister.go @@ -6,6 +6,9 @@ import ( "fmt" "runtime" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + "github.com/brevdev/brev-cli/pkg/cmd/register" "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" @@ -89,11 +92,11 @@ func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore) t.Vprint("") t.Vprint(t.Yellow("Removing node from Brev...")) - client := register.NewConnectNodeClient(s, config.GlobalConfig.GetBrevAPIURl()) - if err := client.RemoveNode(ctx, ®ister.RemoveNodeRequest{ - ExternalNodeID: reg.ExternalNodeID, - OrganizationID: reg.OrgID, - }); err != nil { + client := register.NewNodeServiceClient(s, config.GlobalConfig.GetBrevAPIURl()) + if _, err := client.RemoveNode(ctx, connect.NewRequest(&nodev1.RemoveNodeRequest{ + ExternalNodeId: reg.ExternalNodeID, + OrganizationId: reg.OrgID, + })); err != nil { return fmt.Errorf("failed to deregister node: %w", err) } t.Vprint(t.Green(" Node removed from Brev.")) diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index 3e199f73..78b83246 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -9,6 +9,8 @@ import ( "runtime" "time" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" "github.com/google/uuid" "github.com/brevdev/brev-cli/pkg/config" @@ -147,19 +149,20 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam t.Vprint(t.Yellow("[Step 3/3] Registering with Brev...")) deviceID := uuid.New().String() - client := NewConnectNodeClient(s, config.GlobalConfig.GetBrevAPIURl()) - addResp, err := client.AddNode(ctx, &AddNodeRequest{ - OrganizationID: org.ID, + client := NewNodeServiceClient(s, config.GlobalConfig.GetBrevAPIURl()) + addResp, err := client.AddNode(ctx, connect.NewRequest(&nodev1.AddNodeRequest{ + OrganizationId: org.ID, Name: name, - DeviceID: deviceID, - NodeSpec: nodeSpec, - }) + DeviceId: deviceID, + NodeSpec: toProtoNodeSpec(nodeSpec), + })) if err != nil { return fmt.Errorf("failed to register node: %w", err) } + node := addResp.Msg.GetExternalNode() reg := &DeviceRegistration{ - ExternalNodeID: addResp.ExternalNode.ExternalNodeID, + ExternalNodeID: node.GetExternalNodeId(), DisplayName: name, OrgID: org.ID, DeviceID: deviceID, @@ -172,8 +175,8 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam t.Vprint(t.Green(" Registration complete.")) - if len(addResp.SetupCommands) > 0 { - if err := runSetupCommands(addResp.SetupCommands); err != nil { + if cmds := addResp.Msg.GetSetupCommands(); len(cmds) > 0 { + if err := runSetupCommands(cmds); err != nil { t.Vprintf(" Warning: setup command failed: %v\n", err) } } diff --git a/pkg/cmd/register/rpcclient.go b/pkg/cmd/register/rpcclient.go index 2b2c8a4d..6ca80ba3 100644 --- a/pkg/cmd/register/rpcclient.go +++ b/pkg/cmd/register/rpcclient.go @@ -1,59 +1,14 @@ package register import ( - "bytes" - "context" - "encoding/json" - "fmt" "net/http" + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + breverrors "github.com/brevdev/brev-cli/pkg/errors" ) -// TODO: Replace these local types with generated proto types once the -// ExternalNodeService is published to buf.build: -// -// import ( -// nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" -// nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" -// ) - -// AddNodeRequest matches the proto AddNodeRequest message. -type AddNodeRequest struct { - OrganizationID string `json:"organization_id"` - Name string `json:"name"` - DeviceID string `json:"device_id"` - NodeSpec *NodeSpec `json:"node_spec"` -} - -// AddNodeResponse matches the proto AddNodeResponse message. -type AddNodeResponse struct { - ExternalNode *ExternalNode `json:"external_node"` - SetupCommands map[string]string `json:"setup_commands"` -} - -// ExternalNode matches the proto ExternalNode message (subset of fields we need). -type ExternalNode struct { - ExternalNodeID string `json:"external_node_id"` - OrganizationID string `json:"organization_id"` - Name string `json:"name"` - DeviceID string `json:"device_id"` -} - -// RemoveNodeRequest matches the proto RemoveNodeRequest message. -type RemoveNodeRequest struct { - ExternalNodeID string `json:"external_node_id"` - OrganizationID string `json:"organization_id"` -} - -// ExternalNodeServiceClient defines the RPCs we call from the CLI. -// This will be replaced by the generated ConnectRPC client interface -// once the service is published. -type ExternalNodeServiceClient interface { - AddNode(ctx context.Context, req *AddNodeRequest) (*AddNodeResponse, error) - RemoveNode(ctx context.Context, req *RemoveNodeRequest) error -} - // tokenProvider abstracts access token retrieval for the HTTP transport. type tokenProvider interface { GetAccessToken() (string, error) @@ -90,76 +45,49 @@ func newAuthenticatedHTTPClient(provider tokenProvider) *http.Client { } } -// ConnectNodeClient is a temporary REST-based implementation of ExternalNodeServiceClient. -// It will be replaced by the generated ConnectRPC client once the service proto -// is published to buf.build. -// -// TODO: Replace with: -// -// httpClient := newAuthenticatedHTTPClient(store) -// client := nodev1connect.NewExternalNodeServiceClient(httpClient, baseURL) -type ConnectNodeClient struct { - httpClient *http.Client - baseURL string +// NewNodeServiceClient creates a ConnectRPC ExternalNodeServiceClient using the +// given token provider for authentication. +func NewNodeServiceClient(provider tokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient { + return nodev1connect.NewExternalNodeServiceClient( + newAuthenticatedHTTPClient(provider), + baseURL, + ) } -// NewConnectNodeClient creates a new ConnectNodeClient. -func NewConnectNodeClient(provider tokenProvider, baseURL string) *ConnectNodeClient { - return &ConnectNodeClient{ - httpClient: newAuthenticatedHTTPClient(provider), - baseURL: baseURL, +// toProtoNodeSpec converts the local NodeSpec (used for collection, display, persistence) +// to the generated proto NodeSpec for RPC calls. +func toProtoNodeSpec(s *NodeSpec) *nodev1.NodeSpec { + if s == nil { + return nil } -} -func (c *ConnectNodeClient) AddNode(ctx context.Context, req *AddNodeRequest) (*AddNodeResponse, error) { - body, err := json.Marshal(req) - if err != nil { - return nil, fmt.Errorf("failed to marshal AddNodeRequest: %w", err) + proto := &nodev1.NodeSpec{ + RamBytes: s.RAMBytes, + CpuCount: s.CPUCount, + StorageBytes: s.StorageBytes, } - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/devplaneapi.v1.ExternalNodeService/AddNode", bytes.NewReader(body)) - if err != nil { - return nil, breverrors.WrapAndTrace(err) - } - httpReq.Header.Set("Content-Type", "application/json") - - resp, err := c.httpClient.Do(httpReq) - if err != nil { - return nil, fmt.Errorf("AddNode request failed: %w", err) + if s.Architecture != "" { + proto.Architecture = &s.Architecture } - defer resp.Body.Close() //nolint:errcheck // best-effort close - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("AddNode returned status %d", resp.StatusCode) - } - - var result AddNodeResponse - if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { - return nil, fmt.Errorf("failed to decode AddNode response: %w", err) + if s.StorageType != "" { + proto.StorageType = &s.StorageType } - return &result, nil -} - -func (c *ConnectNodeClient) RemoveNode(ctx context.Context, req *RemoveNodeRequest) error { - body, err := json.Marshal(req) - if err != nil { - return fmt.Errorf("failed to marshal RemoveNodeRequest: %w", err) + if s.OS != "" { + proto.Os = &s.OS } - - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL+"/devplaneapi.v1.ExternalNodeService/RemoveNode", bytes.NewReader(body)) - if err != nil { - return breverrors.WrapAndTrace(err) + if s.OSVersion != "" { + proto.OsVersion = &s.OSVersion } - httpReq.Header.Set("Content-Type", "application/json") - resp, err := c.httpClient.Do(httpReq) - if err != nil { - return fmt.Errorf("RemoveNode request failed: %w", err) + for _, g := range s.GPUs { + pg := &nodev1.GPUSpec{ + Model: g.Model, + Count: g.Count, + MemoryBytes: g.MemoryBytes, + } + proto.Gpus = append(proto.Gpus, pg) } - defer resp.Body.Close() //nolint:errcheck // best-effort close - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("RemoveNode returned status %d", resp.StatusCode) - } - return nil + return proto } diff --git a/pkg/cmd/register/rpcclient_test.go b/pkg/cmd/register/rpcclient_test.go index 771d19f0..a62b9843 100644 --- a/pkg/cmd/register/rpcclient_test.go +++ b/pkg/cmd/register/rpcclient_test.go @@ -2,10 +2,13 @@ package register import ( "context" - "encoding/json" "net/http" "net/http/httptest" "testing" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" ) type mockTokenProvider struct { @@ -52,129 +55,226 @@ func Test_bearerTokenTransport_PropagatesTokenError(t *testing.T) { } } -func Test_ConnectNodeClient_AddNode(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/devplaneapi.v1.ExternalNodeService/AddNode" { - t.Errorf("unexpected path: %s", r.URL.Path) - } - if r.Method != http.MethodPost { - t.Errorf("expected POST, got %s", r.Method) - } - if ct := r.Header.Get("Content-Type"); ct != "application/json" { - t.Errorf("expected application/json, got %s", ct) - } - - var req AddNodeRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - t.Fatalf("failed to decode request: %v", err) - } - if req.OrganizationID != "org_123" { - t.Errorf("unexpected org ID: %s", req.OrganizationID) - } - if req.Name != "My Spark" { - t.Errorf("unexpected name: %s", req.Name) - } - - resp := AddNodeResponse{ - ExternalNode: &ExternalNode{ - ExternalNodeID: "unode_abc", - OrganizationID: "org_123", - Name: "My Spark", - DeviceID: req.DeviceID, - }, - SetupCommands: map[string]string{"netbird": "netbird up"}, - } - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) //nolint:errcheck // test - })) +func Test_toProtoNodeSpec(t *testing.T) { + cpuCount := int32(12) + ramBytes := int64(137438953472) + memBytes := int64(137438953472) + storageBytes := int64(500107862016) + + local := &NodeSpec{ + GPUs: []NodeGPU{ + {Model: "NVIDIA GB10", Count: 2, MemoryBytes: &memBytes}, + }, + RAMBytes: &ramBytes, + CPUCount: &cpuCount, + Architecture: "arm64", + StorageBytes: &storageBytes, + StorageType: "NVMe", + OS: "Ubuntu", + OSVersion: "24.04", + } + + proto := toProtoNodeSpec(local) + + if proto.GetCpuCount() != 12 { + t.Errorf("expected CpuCount 12, got %d", proto.GetCpuCount()) + } + if proto.GetRamBytes() != 137438953472 { + t.Errorf("expected RamBytes, got %d", proto.GetRamBytes()) + } + if proto.GetArchitecture() != "arm64" { + t.Errorf("expected arm64, got %s", proto.GetArchitecture()) + } + if proto.GetOs() != "Ubuntu" { + t.Errorf("expected Ubuntu, got %s", proto.GetOs()) + } + if proto.GetOsVersion() != "24.04" { + t.Errorf("expected 24.04, got %s", proto.GetOsVersion()) + } + if proto.GetStorageBytes() != 500107862016 { + t.Errorf("expected StorageBytes, got %d", proto.GetStorageBytes()) + } + if proto.GetStorageType() != "NVMe" { + t.Errorf("expected NVMe, got %s", proto.GetStorageType()) + } + if len(proto.GetGpus()) != 1 { + t.Fatalf("expected 1 GPU, got %d", len(proto.GetGpus())) + } + gpu := proto.GetGpus()[0] + if gpu.GetModel() != "NVIDIA GB10" { + t.Errorf("expected NVIDIA GB10, got %s", gpu.GetModel()) + } + if gpu.GetCount() != 2 { + t.Errorf("expected count 2, got %d", gpu.GetCount()) + } + if gpu.GetMemoryBytes() != 137438953472 { + t.Errorf("expected memory bytes, got %d", gpu.GetMemoryBytes()) + } +} + +func Test_toProtoNodeSpec_Nil(t *testing.T) { + if toProtoNodeSpec(nil) != nil { + t.Error("expected nil for nil input") + } +} + +func Test_toProtoNodeSpec_MinimalFields(t *testing.T) { + local := &NodeSpec{ + Architecture: "amd64", + } + proto := toProtoNodeSpec(local) + if proto.GetArchitecture() != "amd64" { + t.Errorf("expected amd64, got %s", proto.GetArchitecture()) + } + if proto.RamBytes != nil { + t.Error("expected nil RamBytes") + } + if proto.CpuCount != nil { + t.Error("expected nil CpuCount") + } + if len(proto.GetGpus()) != 0 { + t.Error("expected no GPUs") + } +} + +// fakeNodeService implements the server side of ExternalNodeService for testing. +type fakeNodeService struct { + nodev1connect.UnimplementedExternalNodeServiceHandler + addNodeFn func(*nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) + removeNodeFn func(*nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) +} + +func (f *fakeNodeService) AddNode(_ context.Context, req *connect.Request[nodev1.AddNodeRequest]) (*connect.Response[nodev1.AddNodeResponse], error) { + resp, err := f.addNodeFn(req.Msg) + if err != nil { + return nil, err + } + return connect.NewResponse(resp), nil +} + +func (f *fakeNodeService) RemoveNode(_ context.Context, req *connect.Request[nodev1.RemoveNodeRequest]) (*connect.Response[nodev1.RemoveNodeResponse], error) { + resp, err := f.removeNodeFn(req.Msg) + if err != nil { + return nil, err + } + return connect.NewResponse(resp), nil +} + +func Test_NewNodeServiceClient_AddNode(t *testing.T) { + svc := &fakeNodeService{ + addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + if req.GetOrganizationId() != "org_123" { + t.Errorf("unexpected org ID: %s", req.GetOrganizationId()) + } + if req.GetName() != "My Spark" { + t.Errorf("unexpected name: %s", req.GetName()) + } + return &nodev1.AddNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: "unode_abc", + OrganizationId: "org_123", + Name: req.GetName(), + DeviceId: req.GetDeviceId(), + }, + SetupCommands: map[string]string{"netbird": "netbird up"}, + }, nil + }, + } + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) defer server.Close() - provider := &mockTokenProvider{token: "tok"} - client := NewConnectNodeClient(provider, server.URL) + client := NewNodeServiceClient(&mockTokenProvider{token: "tok"}, server.URL) - resp, err := client.AddNode(context.Background(), &AddNodeRequest{ - OrganizationID: "org_123", + resp, err := client.AddNode(context.Background(), connect.NewRequest(&nodev1.AddNodeRequest{ + OrganizationId: "org_123", Name: "My Spark", - DeviceID: "dev-uuid", - NodeSpec: &NodeSpec{Architecture: "arm64"}, - }) + DeviceId: "dev-uuid", + NodeSpec: &nodev1.NodeSpec{Architecture: strPtr("arm64")}, + })) if err != nil { t.Fatalf("AddNode failed: %v", err) } - if resp.ExternalNode.ExternalNodeID != "unode_abc" { - t.Errorf("unexpected node ID: %s", resp.ExternalNode.ExternalNodeID) + if resp.Msg.GetExternalNode().GetExternalNodeId() != "unode_abc" { + t.Errorf("unexpected node ID: %s", resp.Msg.GetExternalNode().GetExternalNodeId()) } - if len(resp.SetupCommands) != 1 { - t.Errorf("expected 1 setup command, got %d", len(resp.SetupCommands)) + if len(resp.Msg.GetSetupCommands()) != 1 { + t.Errorf("expected 1 setup command, got %d", len(resp.Msg.GetSetupCommands())) } } -func Test_ConnectNodeClient_AddNode_ServerError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusInternalServerError) - })) +func Test_NewNodeServiceClient_AddNode_ServerError(t *testing.T) { + svc := &fakeNodeService{ + addNodeFn: func(_ *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + return nil, connect.NewError(connect.CodeInternal, nil) + }, + } + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) defer server.Close() - provider := &mockTokenProvider{token: "tok"} - client := NewConnectNodeClient(provider, server.URL) + client := NewNodeServiceClient(&mockTokenProvider{token: "tok"}, server.URL) - _, err := client.AddNode(context.Background(), &AddNodeRequest{ - OrganizationID: "org_123", + _, err := client.AddNode(context.Background(), connect.NewRequest(&nodev1.AddNodeRequest{ + OrganizationId: "org_123", Name: "Test", - DeviceID: "dev", - }) + DeviceId: "dev", + })) if err == nil { - t.Fatal("expected error for 500 response") + t.Fatal("expected error for server error response") } } -func Test_ConnectNodeClient_RemoveNode(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/devplaneapi.v1.ExternalNodeService/RemoveNode" { - t.Errorf("unexpected path: %s", r.URL.Path) - } - - var req RemoveNodeRequest - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - t.Fatalf("failed to decode request: %v", err) - } - if req.ExternalNodeID != "unode_abc" { - t.Errorf("unexpected node ID: %s", req.ExternalNodeID) - } - if req.OrganizationID != "org_123" { - t.Errorf("unexpected org ID: %s", req.OrganizationID) - } +func Test_NewNodeServiceClient_RemoveNode(t *testing.T) { + svc := &fakeNodeService{ + removeNodeFn: func(req *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + if req.GetExternalNodeId() != "unode_abc" { + t.Errorf("unexpected node ID: %s", req.GetExternalNodeId()) + } + if req.GetOrganizationId() != "org_123" { + t.Errorf("unexpected org ID: %s", req.GetOrganizationId()) + } + return &nodev1.RemoveNodeResponse{}, nil + }, + } - w.WriteHeader(http.StatusOK) - })) + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) defer server.Close() - provider := &mockTokenProvider{token: "tok"} - client := NewConnectNodeClient(provider, server.URL) + client := NewNodeServiceClient(&mockTokenProvider{token: "tok"}, server.URL) - err := client.RemoveNode(context.Background(), &RemoveNodeRequest{ - ExternalNodeID: "unode_abc", - OrganizationID: "org_123", - }) + _, err := client.RemoveNode(context.Background(), connect.NewRequest(&nodev1.RemoveNodeRequest{ + ExternalNodeId: "unode_abc", + OrganizationId: "org_123", + })) if err != nil { t.Fatalf("RemoveNode failed: %v", err) } } -func Test_ConnectNodeClient_RemoveNode_ServerError(t *testing.T) { - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotFound) - })) +func Test_NewNodeServiceClient_RemoveNode_ServerError(t *testing.T) { + svc := &fakeNodeService{ + removeNodeFn: func(_ *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + return nil, connect.NewError(connect.CodeNotFound, nil) + }, + } + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) defer server.Close() - provider := &mockTokenProvider{token: "tok"} - client := NewConnectNodeClient(provider, server.URL) + client := NewNodeServiceClient(&mockTokenProvider{token: "tok"}, server.URL) - err := client.RemoveNode(context.Background(), &RemoveNodeRequest{ - ExternalNodeID: "unode_missing", - OrganizationID: "org_123", - }) + _, err := client.RemoveNode(context.Background(), connect.NewRequest(&nodev1.RemoveNodeRequest{ + ExternalNodeId: "unode_missing", + OrganizationId: "org_123", + })) if err == nil { - t.Fatal("expected error for 404 response") + t.Fatal("expected error for not found response") } } + +func strPtr(s string) *string { return &s } From feb204d235cad6957a89bb7ab9d5b05fd8f3a210 Mon Sep 17 00:00:00 2001 From: Pratik Patel Date: Thu, 26 Feb 2026 13:18:54 -0800 Subject: [PATCH 06/11] fix setup --- Makefile | 6 +++--- go.mod | 4 ++-- go.sum | 4 ++++ pkg/cmd/register/netbird.go | 17 +++++++---------- pkg/cmd/register/register.go | 4 ++-- pkg/cmd/register/rpcclient_test.go | 6 +++--- 6 files changed, 21 insertions(+), 20 deletions(-) diff --git a/Makefile b/Makefile index e9ddc9a6..e8514d4f 100644 --- a/Makefile +++ b/Makefile @@ -8,12 +8,12 @@ fast-build: ## go build -o brev CGO_ENABLED=0 go build -o brev -ldflags "-X github.com/brevdev/brev-cli/pkg/cmd/version.Version=${VERSION}" .PHONY: local -local: ## build with env wrapper (use: make local env=dev0|dev1|dev2|stg, or make local for defaults) +local: ## build with env wrapper (use: make local env=dev0|dev1|dev2|stg arch=linux/amd64, or make local for defaults) $(call print-target) ifdef env @echo "Building with env=$(env) wrapper..." @echo ${VERSION} - CGO_ENABLED=0 go build -o brev-local -ldflags "-X github.com/brevdev/brev-cli/pkg/cmd/version.Version=${VERSION}" + $(if $(arch),GOOS=$(word 1,$(subst /, ,$(arch))) GOARCH=$(word 2,$(subst /, ,$(arch))),) CGO_ENABLED=0 go build -o brev-local -ldflags "-X github.com/brevdev/brev-cli/pkg/cmd/version.Version=${VERSION}" @echo '#!/bin/sh' > brev @echo '# Auto-generated wrapper with environment overrides' >> brev @echo 'export BREV_CONSOLE_URL="https://localhost.nvidia.com:3000"' >> brev @@ -25,7 +25,7 @@ ifdef env @chmod +x brev else @echo "Building without environment overrides (using config.go defaults)..." - $(MAKE) fast-build + $(if $(arch),GOOS=$(word 1,$(subst /, ,$(arch))) GOARCH=$(word 2,$(subst /, ,$(arch))),) CGO_ENABLED=0 go build -o brev -ldflags "-X github.com/brevdev/brev-cli/pkg/cmd/version.Version=${VERSION}" endif .PHONY: install-dev diff --git a/go.mod b/go.mod index e0c66257..352ecb2f 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module github.com/brevdev/brev-cli go 1.24.0 require ( - buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260226031750-e6fd8dbaf991.2 - buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260226031750-e6fd8dbaf991.1 + buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260226200709-e1ac2ea142d1.2 + buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260226200709-e1ac2ea142d1.1 connectrpc.com/connect v1.19.1 github.com/alessio/shellescape v1.4.1 github.com/brevdev/parse v0.0.11 diff --git a/go.sum b/go.sum index a831a91f..33a3e19e 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,11 @@ buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260226031750-e6fd8dbaf991.2 h1:8ZrXfJx6gzHTeBU2Lfn2jdpi8q8QJMXZPo8GlVEm6+A= buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260226031750-e6fd8dbaf991.2/go.mod h1:EGcIExX0SEtObIZr1l3pouENtdl2gsZtHjOYOfuB7ss= +buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260226200709-e1ac2ea142d1.2 h1:7ld1AqzV9YsRWP5I4FvMUxDiq5fMCEEsNMECCcUJE/s= +buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260226200709-e1ac2ea142d1.2/go.mod h1:ZqSAMH+RVqnfQsnUQ5OpJI7dUWx0UzPUWcceudIHWmI= buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260226031750-e6fd8dbaf991.1 h1:xkJkJcCnAq5WiEUevk7Kz3b+aFuK7aj64DyVUQM9ZQ0= buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260226031750-e6fd8dbaf991.1/go.mod h1:V/y7Wxg0QvU4XPVwqErF5NHLobUT1QEyfgrGuQIxdPo= +buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260226200709-e1ac2ea142d1.1 h1:pWlngsd33oF5xFhfTbxYProXMihFXmthzTAcgd3zXKg= +buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260226200709-e1ac2ea142d1.1/go.mod h1:V/y7Wxg0QvU4XPVwqErF5NHLobUT1QEyfgrGuQIxdPo= buf.build/gen/go/brevdev/protoc-gen-gotag/protocolbuffers/go v1.36.11-20220906235457-8b4922735da5.1 h1:6amhprQmCKJ4wgJ6ngkh32d9V+dQcOLUZ/SfHdOnYgo= buf.build/gen/go/brevdev/protoc-gen-gotag/protocolbuffers/go v1.36.11-20220906235457-8b4922735da5.1/go.mod h1:O+pnSHMru/naTMrm4tmpBoH3wz6PHa+R75HR7Mv8X2g= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= diff --git a/pkg/cmd/register/netbird.go b/pkg/cmd/register/netbird.go index 1f02ed50..4ed9ca9f 100644 --- a/pkg/cmd/register/netbird.go +++ b/pkg/cmd/register/netbird.go @@ -22,16 +22,13 @@ func InstallNetbird(t *terminal.Terminal) error { return nil } -// runSetupCommands executes the setup commands returned by the AddNode RPC. -// The commands are keyed by name; values are shell commands to execute. -func runSetupCommands(commands map[string]string) error { - for name, script := range commands { - cmd := exec.Command("bash", "-c", script) // #nosec G204 - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("setup command %q failed: %w", name, err) - } +// runSetupCommand executes the setup command returned by the AddNode RPC. +func runSetupCommand(script string) error { + cmd := exec.Command("bash", "-c", script) // #nosec G204 + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("setup command failed: %w", err) } return nil } diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index 78b83246..996bac38 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -175,8 +175,8 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam t.Vprint(t.Green(" Registration complete.")) - if cmds := addResp.Msg.GetSetupCommands(); len(cmds) > 0 { - if err := runSetupCommands(cmds); err != nil { + if cmd := addResp.Msg.GetSetupCommand(); cmd != "" { + if err := runSetupCommand(cmd); err != nil { t.Vprintf(" Warning: setup command failed: %v\n", err) } } diff --git a/pkg/cmd/register/rpcclient_test.go b/pkg/cmd/register/rpcclient_test.go index a62b9843..85384c95 100644 --- a/pkg/cmd/register/rpcclient_test.go +++ b/pkg/cmd/register/rpcclient_test.go @@ -176,7 +176,7 @@ func Test_NewNodeServiceClient_AddNode(t *testing.T) { Name: req.GetName(), DeviceId: req.GetDeviceId(), }, - SetupCommands: map[string]string{"netbird": "netbird up"}, + SetupCommand: "netbird up", }, nil }, } @@ -199,8 +199,8 @@ func Test_NewNodeServiceClient_AddNode(t *testing.T) { if resp.Msg.GetExternalNode().GetExternalNodeId() != "unode_abc" { t.Errorf("unexpected node ID: %s", resp.Msg.GetExternalNode().GetExternalNodeId()) } - if len(resp.Msg.GetSetupCommands()) != 1 { - t.Errorf("expected 1 setup command, got %d", len(resp.Msg.GetSetupCommands())) + if resp.Msg.GetSetupCommand() != "netbird up" { + t.Errorf("expected setup command 'netbird up', got %q", resp.Msg.GetSetupCommand()) } } From c9dbf1d0c14c6133823314c9526ee6226ea65f81 Mon Sep 17 00:00:00 2001 From: Pratik Patel Date: Thu, 26 Feb 2026 14:10:12 -0800 Subject: [PATCH 07/11] more tests --- pkg/cmd/deregister/deregister.go | 69 ++++-- pkg/cmd/deregister/deregister_test.go | 295 +++++++++++++++++++++++ pkg/cmd/register/hardware.go | 2 - pkg/cmd/register/hardware_test.go | 9 +- pkg/cmd/register/identity.go | 16 +- pkg/cmd/register/identity_test.go | 18 +- pkg/cmd/register/register.go | 69 ++++-- pkg/cmd/register/register_test.go | 321 ++++++++++++++++++++++++++ pkg/cmd/register/rpcclient.go | 10 +- 9 files changed, 754 insertions(+), 55 deletions(-) create mode 100644 pkg/cmd/deregister/deregister_test.go create mode 100644 pkg/cmd/register/register_test.go diff --git a/pkg/cmd/deregister/deregister.go b/pkg/cmd/deregister/deregister.go index 6a1beddd..37fb8e30 100644 --- a/pkg/cmd/deregister/deregister.go +++ b/pkg/cmd/deregister/deregister.go @@ -6,6 +6,7 @@ import ( "fmt" "runtime" + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" "connectrpc.com/connect" @@ -25,6 +26,35 @@ type DeregisterStore interface { GetAccessToken() (string, error) } +// deregisterDeps bundles the side-effecting dependencies of runDeregister so +// they can be replaced in tests. +type deregisterDeps struct { + goos string + promptSelect func(label string, items []string) string + uninstallNetbird func(t *terminal.Terminal) error + newNodeClient func(provider register.TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient + registrationExists func(brevHome string) (bool, error) + loadRegistration func(brevHome string) (*register.DeviceRegistration, error) + deleteRegistration func(brevHome string) error +} + +func prodDeregisterDeps() deregisterDeps { + return deregisterDeps{ + goos: runtime.GOOS, + promptSelect: func(label string, items []string) string { + return terminal.PromptSelectInput(terminal.PromptSelectContent{ + Label: label, + Items: items, + }) + }, + uninstallNetbird: register.UninstallNetbird, + newNodeClient: register.NewNodeServiceClient, + registrationExists: register.RegistrationExists, + loadRegistration: register.LoadRegistration, + deleteRegistration: register.DeleteRegistration, + } +} + var ( deregisterLong = `Deregister your DGX Spark from NVIDIA Brev @@ -43,15 +73,15 @@ func NewCmdDeregister(t *terminal.Terminal, store DeregisterStore) *cobra.Comman Long: deregisterLong, Example: deregisterExample, RunE: func(cmd *cobra.Command, args []string) error { - return runDeregister(cmd.Context(), t, store) + return runDeregister(cmd.Context(), t, store, prodDeregisterDeps()) }, } return cmd } -func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore) error { //nolint:funlen // deregistration flow - if runtime.GOOS != "linux" { +func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore, deps deregisterDeps) error { //nolint:funlen // deregistration flow + if deps.goos != "linux" { return fmt.Errorf("brev deregister is only supported on Linux (DGX Spark)") } @@ -60,11 +90,15 @@ func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore) return breverrors.WrapAndTrace(err) } - if !register.RegistrationExists(brevHome) { + registered, err := deps.registrationExists(brevHome) + if err != nil { + return breverrors.WrapAndTrace(err) + } + if !registered { return fmt.Errorf("no registration found; this machine does not appear to be registered\nRun 'brev register' to register your DGX Spark") } - reg, err := register.LoadRegistration(brevHome) + reg, err := deps.loadRegistration(brevHome) if err != nil { return fmt.Errorf("failed to read registration file: %w", err) } @@ -76,15 +110,15 @@ func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore) t.Vprintf(" Name: %s\n", reg.DisplayName) t.Vprint("") - removeNetbird := terminal.PromptSelectInput(terminal.PromptSelectContent{ - Label: "Would you also like to uninstall NetBird?", - Items: []string{"Yes, uninstall NetBird", "No, keep NetBird installed"}, - }) + removeNetbird := deps.promptSelect( + "Would you also like to uninstall NetBird?", + []string{"Yes, uninstall NetBird", "No, keep NetBird installed"}, + ) - confirm := terminal.PromptSelectInput(terminal.PromptSelectContent{ - Label: "Proceed with deregistration?", - Items: []string{"Yes, proceed", "No, cancel"}, - }) + confirm := deps.promptSelect( + "Proceed with deregistration?", + []string{"Yes, proceed", "No, cancel"}, + ) if confirm != "Yes, proceed" { t.Vprint("Deregistration canceled.") return nil @@ -92,7 +126,7 @@ func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore) t.Vprint("") t.Vprint(t.Yellow("Removing node from Brev...")) - client := register.NewNodeServiceClient(s, config.GlobalConfig.GetBrevAPIURl()) + client := deps.newNodeClient(s, config.GlobalConfig.GetBrevAPIURl()) if _, err := client.RemoveNode(ctx, connect.NewRequest(&nodev1.RemoveNodeRequest{ ExternalNodeId: reg.ExternalNodeID, OrganizationId: reg.OrgID, @@ -104,7 +138,7 @@ func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore) if removeNetbird == "Yes, uninstall NetBird" { t.Vprint("Removing NetBird...") - if err := register.UninstallNetbird(t); err != nil { + if err := deps.uninstallNetbird(t); err != nil { t.Vprintf(" Warning: failed to uninstall NetBird: %v\n", err) } else { t.Vprint(t.Green(" NetBird uninstalled.")) @@ -113,8 +147,9 @@ func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore) } t.Vprint("Removing registration data...") - if err := register.DeleteRegistration(brevHome); err != nil { - return fmt.Errorf("failed to remove registration data: %w", err) + if err := deps.deleteRegistration(brevHome); err != nil { + t.Vprintf(" Warning: failed to remove local registration file: %v\n", err) + t.Vprint(" You can manually remove it with: rm ~/.brev/spark_registration.json") } t.Vprint(t.Green("Deregistration complete.")) diff --git a/pkg/cmd/deregister/deregister_test.go b/pkg/cmd/deregister/deregister_test.go new file mode 100644 index 00000000..ed4dfd5d --- /dev/null +++ b/pkg/cmd/deregister/deregister_test.go @@ -0,0 +1,295 @@ +package deregister + +import ( + "context" + "net/http/httptest" + "testing" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + + "github.com/brevdev/brev-cli/pkg/cmd/register" + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/files" + "github.com/brevdev/brev-cli/pkg/terminal" + "github.com/spf13/afero" +) + +type mockDeregisterStore struct { + user *entity.User + home string + token string + err error +} + +func (m *mockDeregisterStore) GetCurrentUser() (*entity.User, error) { + if m.err != nil { + return nil, m.err + } + return m.user, nil +} + +func (m *mockDeregisterStore) GetBrevHomePath() (string, error) { return m.home, nil } +func (m *mockDeregisterStore) GetAccessToken() (string, error) { return m.token, nil } + +// fakeNodeService implements the server side of ExternalNodeService for testing. +type fakeNodeService struct { + nodev1connect.UnimplementedExternalNodeServiceHandler + removeNodeFn func(*nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) +} + +func (f *fakeNodeService) RemoveNode(_ context.Context, req *connect.Request[nodev1.RemoveNodeRequest]) (*connect.Response[nodev1.RemoveNodeResponse], error) { + resp, err := f.removeNodeFn(req.Msg) + if err != nil { + return nil, err + } + return connect.NewResponse(resp), nil +} + +func setupDeregisterTestFs(t *testing.T) (string, func()) { + t.Helper() + origFs := files.AppFs + files.AppFs = afero.NewMemMapFs() + brevHome := "/home/testuser/.brev" + if err := files.AppFs.MkdirAll(brevHome, 0o700); err != nil { + t.Fatalf("failed to create test dir: %v", err) + } + return brevHome, func() { files.AppFs = origFs } +} + +// testDeregisterDeps returns deps with all side-effects stubbed. The +// promptSelect defaults to confirming all prompts. +func testDeregisterDeps(t *testing.T, svc *fakeNodeService) (deregisterDeps, *httptest.Server) { + t.Helper() + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) + + return deregisterDeps{ + goos: "linux", + promptSelect: func(_ string, items []string) string { + // Default: pick first item (Yes, ...) + if len(items) > 0 { + return items[0] + } + return "" + }, + uninstallNetbird: func(_ *terminal.Terminal) error { return nil }, + newNodeClient: func(provider register.TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient { + return register.NewNodeServiceClient(provider, server.URL) + }, + registrationExists: register.RegistrationExists, + loadRegistration: register.LoadRegistration, + deleteRegistration: register.DeleteRegistration, + }, server +} + +func Test_runDeregister_HappyPath(t *testing.T) { + brevHome, cleanup := setupDeregisterTestFs(t) + defer cleanup() + + // Pre-save a registration + _ = register.SaveRegistration(brevHome, ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + DeviceID: "dev-uuid", + }) + + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: brevHome, + token: "tok", + } + + var gotNodeID, gotOrgID string + svc := &fakeNodeService{ + removeNodeFn: func(req *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + gotNodeID = req.GetExternalNodeId() + gotOrgID = req.GetOrganizationId() + return &nodev1.RemoveNodeResponse{}, nil + }, + } + + deps, server := testDeregisterDeps(t, svc) + defer server.Close() + + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err != nil { + t.Fatalf("runDeregister failed: %v", err) + } + + if gotNodeID != "unode_abc" { + t.Errorf("expected node ID unode_abc, got %s", gotNodeID) + } + if gotOrgID != "org_123" { + t.Errorf("expected org ID org_123, got %s", gotOrgID) + } + + // Registration file should be deleted + exists, err := register.RegistrationExists(brevHome) + if err != nil { + t.Fatalf("RegistrationExists error: %v", err) + } + if exists { + t.Error("expected registration file to be deleted after deregister") + } +} + +func Test_runDeregister_UserCancels(t *testing.T) { + brevHome, cleanup := setupDeregisterTestFs(t) + defer cleanup() + + _ = register.SaveRegistration(brevHome, ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }) + + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: brevHome, + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testDeregisterDeps(t, svc) + defer server.Close() + + callCount := 0 + deps.promptSelect = func(_ string, _ []string) string { + callCount++ + if callCount == 2 { + // Second prompt is the confirmation — cancel it + return "No, cancel" + } + return "No, keep NetBird installed" + } + + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err != nil { + t.Fatalf("expected nil error on cancel, got: %v", err) + } + + // Registration file should still exist + exists, err := register.RegistrationExists(brevHome) + if err != nil { + t.Fatalf("RegistrationExists error: %v", err) + } + if !exists { + t.Error("registration should still exist after cancel") + } +} + +func Test_runDeregister_NotRegistered(t *testing.T) { + brevHome, cleanup := setupDeregisterTestFs(t) + defer cleanup() + + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: brevHome, + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testDeregisterDeps(t, svc) + defer server.Close() + + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err == nil { + t.Fatal("expected error when not registered") + } +} + +func Test_runDeregister_RemoveNodeFails(t *testing.T) { + brevHome, cleanup := setupDeregisterTestFs(t) + defer cleanup() + + _ = register.SaveRegistration(brevHome, ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }) + + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: brevHome, + token: "tok", + } + + svc := &fakeNodeService{ + removeNodeFn: func(_ *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + return nil, connect.NewError(connect.CodeInternal, nil) + }, + } + + deps, server := testDeregisterDeps(t, svc) + defer server.Close() + + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err == nil { + t.Fatal("expected error when RemoveNode fails") + } + + // Registration file should still exist (server-side removal failed) + exists, err := register.RegistrationExists(brevHome) + if err != nil { + t.Fatalf("RegistrationExists error: %v", err) + } + if !exists { + t.Error("registration should still exist when RemoveNode fails") + } +} + +func Test_runDeregister_SkipsNetbirdUninstall(t *testing.T) { + brevHome, cleanup := setupDeregisterTestFs(t) + defer cleanup() + + _ = register.SaveRegistration(brevHome, ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }) + + store := &mockDeregisterStore{ + user: &entity.User{ID: "user_1"}, + home: brevHome, + token: "tok", + } + + svc := &fakeNodeService{ + removeNodeFn: func(_ *nodev1.RemoveNodeRequest) (*nodev1.RemoveNodeResponse, error) { + return &nodev1.RemoveNodeResponse{}, nil + }, + } + + uninstallCalled := false + deps, server := testDeregisterDeps(t, svc) + defer server.Close() + + deps.promptSelect = func(label string, items []string) string { + if label == "Would you also like to uninstall NetBird?" { + return "No, keep NetBird installed" + } + return "Yes, proceed" + } + deps.uninstallNetbird = func(_ *terminal.Terminal) error { + uninstallCalled = true + return nil + } + + term := terminal.New() + err := runDeregister(context.Background(), term, store, deps) + if err != nil { + t.Fatalf("runDeregister failed: %v", err) + } + + if uninstallCalled { + t.Error("NetBird uninstall should not be called when user declines") + } +} diff --git a/pkg/cmd/register/hardware.go b/pkg/cmd/register/hardware.go index f23f8721..5130294b 100644 --- a/pkg/cmd/register/hardware.go +++ b/pkg/cmd/register/hardware.go @@ -264,8 +264,6 @@ func parseStorageOutput(output string) (int64, string) { if storageType == "" { if strings.HasPrefix(fields[0], "nvme") { storageType = "NVMe" - } else { - storageType = "SSD" } } } diff --git a/pkg/cmd/register/hardware_test.go b/pkg/cmd/register/hardware_test.go index 7e53fe7c..5370bf6a 100644 --- a/pkg/cmd/register/hardware_test.go +++ b/pkg/cmd/register/hardware_test.go @@ -153,9 +153,12 @@ sda 2048 rom func Test_parseStorageOutput_SDA(t *testing.T) { output := `sda 500107862016 disk ` - _, storageType := parseStorageOutput(output) - if storageType != "SSD" { - t.Errorf("expected SSD, got %s", storageType) + totalBytes, storageType := parseStorageOutput(output) + if totalBytes != 500107862016 { + t.Errorf("expected 500107862016 bytes, got %d", totalBytes) + } + if storageType != "" { + t.Errorf("expected empty storage type for non-nvme disk, got %s", storageType) } } diff --git a/pkg/cmd/register/identity.go b/pkg/cmd/register/identity.go index 9b0c741a..18455295 100644 --- a/pkg/cmd/register/identity.go +++ b/pkg/cmd/register/identity.go @@ -2,6 +2,7 @@ package register import ( "encoding/json" + "os" "path/filepath" breverrors "github.com/brevdev/brev-cli/pkg/errors" @@ -33,7 +34,7 @@ func SaveRegistration(brevHome string, reg *DeviceRegistration) error { if err != nil { return breverrors.WrapAndTrace(err) } - if err := files.AppFs.MkdirAll(filepath.Dir(path), 0o770); err != nil { + if err := files.AppFs.MkdirAll(filepath.Dir(path), 0o700); err != nil { return breverrors.WrapAndTrace(err) } if err := afero.WriteFile(files.AppFs, path, data, 0o600); err != nil { @@ -64,8 +65,15 @@ func DeleteRegistration(brevHome string) error { } // RegistrationExists checks if a registration file exists. -func RegistrationExists(brevHome string) bool { +// Returns (exists, error) so callers can distinguish "not found" from real errors. +func RegistrationExists(brevHome string) (bool, error) { path := registrationPath(brevHome) - exists, _ := files.AppFs.Stat(path) - return exists != nil + _, err := files.AppFs.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, breverrors.WrapAndTrace(err) } diff --git a/pkg/cmd/register/identity_test.go b/pkg/cmd/register/identity_test.go index 8fdfb1f0..d5d08f85 100644 --- a/pkg/cmd/register/identity_test.go +++ b/pkg/cmd/register/identity_test.go @@ -70,7 +70,11 @@ func Test_RegistrationExists_ReturnsFalseWhenMissing(t *testing.T) { brevHome, cleanup := setupTestFs(t) defer cleanup() - if RegistrationExists(brevHome) { + exists, err := RegistrationExists(brevHome) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exists { t.Error("expected RegistrationExists to return false") } } @@ -87,7 +91,11 @@ func Test_RegistrationExists_ReturnsTrueAfterSave(t *testing.T) { t.Fatalf("SaveRegistration failed: %v", err) } - if !RegistrationExists(brevHome) { + exists, err := RegistrationExists(brevHome) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !exists { t.Error("expected RegistrationExists to return true") } } @@ -108,7 +116,11 @@ func Test_DeleteRegistration_RemovesFile(t *testing.T) { t.Fatalf("DeleteRegistration failed: %v", err) } - if RegistrationExists(brevHome) { + exists, err := RegistrationExists(brevHome) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if exists { t.Error("expected RegistrationExists to return false after delete") } } diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index 996bac38..68dd8cba 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -9,6 +9,7 @@ import ( "runtime" "time" + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" "connectrpc.com/connect" "github.com/google/uuid" @@ -40,6 +41,36 @@ func (r OSFileReader) ReadFile(path string) ([]byte, error) { return data, nil } +// registerDeps bundles the side-effecting dependencies of runRegister so they +// can be replaced in tests. +type registerDeps struct { + goos string + promptYesNo func(label string) bool + installNetbird func(t *terminal.Terminal) error + runSetupCommand func(script string) error + newNodeClient func(provider TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient + commandRunner CommandRunner + fileReader FileReader +} + +func prodRegisterDeps() registerDeps { + return registerDeps{ + goos: runtime.GOOS, + promptYesNo: func(label string) bool { + result := terminal.PromptSelectInput(terminal.PromptSelectContent{ + Label: label, + Items: []string{"Yes, proceed", "No, cancel"}, + }) + return result == "Yes, proceed" + }, + installNetbird: InstallNetbird, + runSetupCommand: runSetupCommand, + newNodeClient: NewNodeServiceClient, + commandRunner: ExecCommandRunner{}, + fileReader: OSFileReader{}, + } +} + var ( registerLong = `Register your DGX Spark with NVIDIA Brev @@ -60,7 +91,7 @@ func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command { Long: registerLong, Example: registerExample, RunE: func(cmd *cobra.Command, args []string) error { - return runRegister(cmd.Context(), t, store, name) + return runRegister(cmd.Context(), t, store, name, prodRegisterDeps()) }, } @@ -70,12 +101,12 @@ func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command { return cmd } -func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string) error { //nolint:funlen // registration flow - if runtime.GOOS != "linux" { +func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string, deps registerDeps) error { //nolint:funlen // registration flow + if deps.goos != "linux" { return fmt.Errorf("brev register is only supported on Linux (DGX Spark)") } - currentUser, err := s.GetCurrentUser() + _, err := s.GetCurrentUser() // ensure active token if err != nil { return breverrors.WrapAndTrace(err) } @@ -93,14 +124,16 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam return breverrors.WrapAndTrace(err) } - if RegistrationExists(brevHome) { + alreadyRegistered, err := RegistrationExists(brevHome) + if err != nil { + return breverrors.WrapAndTrace(err) + } + if alreadyRegistered { return fmt.Errorf("this machine is already registered; run 'brev deregister' first to re-register") } - linuxUser := currentUser.Username - if u, err := user.Current(); err == nil { - linuxUser = u.Username - } + u, _ := user.Current() + linuxUser := u.Username t.Vprint("") t.Vprint(t.Green("Registering your DGX Spark with Brev")) @@ -110,23 +143,19 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam t.Vprintf(" Linux user: %s\n", linuxUser) t.Vprint("") t.Vprint("This will perform the following steps:") - t.Vprint(" 1. Install NetBird (network agent)") + t.Vprint(" 1. Install NetBird") t.Vprint(" 2. Collect hardware profile") t.Vprint(" 3. Register this machine with Brev") t.Vprint("") - result := terminal.PromptSelectInput(terminal.PromptSelectContent{ - Label: "Proceed with registration?", - Items: []string{"Yes, proceed", "No, cancel"}, - }) - if result != "Yes, proceed" { + if !deps.promptYesNo("Proceed with registration?") { t.Vprint("Registration canceled.") return nil } t.Vprint("") t.Vprint(t.Yellow("[Step 1/3] Installing NetBird...")) - if err := InstallNetbird(t); err != nil { + if err := deps.installNetbird(t); err != nil { return fmt.Errorf("NetBird installation failed: %w", err) } t.Vprint(t.Green(" NetBird installed successfully.")) @@ -135,9 +164,7 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam t.Vprint(t.Yellow("[Step 2/3] Collecting hardware profile...")) t.Vprint("") - runner := ExecCommandRunner{} - reader := OSFileReader{} - nodeSpec, err := CollectHardwareProfile(runner, reader) + nodeSpec, err := CollectHardwareProfile(deps.commandRunner, deps.fileReader) if err != nil { return fmt.Errorf("failed to collect hardware profile: %w", err) } @@ -149,7 +176,7 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam t.Vprint(t.Yellow("[Step 3/3] Registering with Brev...")) deviceID := uuid.New().String() - client := NewNodeServiceClient(s, config.GlobalConfig.GetBrevAPIURl()) + client := deps.newNodeClient(s, config.GlobalConfig.GetBrevAPIURl()) addResp, err := client.AddNode(ctx, connect.NewRequest(&nodev1.AddNodeRequest{ OrganizationId: org.ID, Name: name, @@ -176,7 +203,7 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam t.Vprint(t.Green(" Registration complete.")) if cmd := addResp.Msg.GetSetupCommand(); cmd != "" { - if err := runSetupCommand(cmd); err != nil { + if err := deps.runSetupCommand(cmd); err != nil { t.Vprintf(" Warning: setup command failed: %v\n", err) } } diff --git a/pkg/cmd/register/register_test.go b/pkg/cmd/register/register_test.go new file mode 100644 index 00000000..9a26eba5 --- /dev/null +++ b/pkg/cmd/register/register_test.go @@ -0,0 +1,321 @@ +package register + +import ( + "context" + "net/http/httptest" + "testing" + + nodev1connect "buf.build/gen/go/brevdev/devplane/connectrpc/go/devplaneapi/v1/devplaneapiv1connect" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + + "github.com/brevdev/brev-cli/pkg/entity" + "github.com/brevdev/brev-cli/pkg/files" + "github.com/brevdev/brev-cli/pkg/terminal" + "github.com/spf13/afero" +) + +// mockRegisterStore satisfies RegisterStore for orchestration tests. +type mockRegisterStore struct { + user *entity.User + org *entity.Organization + home string + token string + err error +} + +func (m *mockRegisterStore) GetCurrentUser() (*entity.User, error) { + if m.err != nil { + return nil, m.err + } + return m.user, nil +} + +func (m *mockRegisterStore) GetActiveOrganizationOrDefault() (*entity.Organization, error) { + return m.org, nil +} + +func (m *mockRegisterStore) GetBrevHomePath() (string, error) { return m.home, nil } +func (m *mockRegisterStore) GetAccessToken() (string, error) { return m.token, nil } + +// testRegisterDeps returns deps with all side-effects stubbed out, and a fake +// ConnectRPC server backed by the provided fakeNodeService. +func testRegisterDeps(t *testing.T, svc *fakeNodeService) (registerDeps, *httptest.Server) { + t.Helper() + + _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) + server := httptest.NewServer(handler) + + return registerDeps{ + goos: "linux", + promptYesNo: func(_ string) bool { return true }, + installNetbird: func(_ *terminal.Terminal) error { return nil }, + runSetupCommand: func(_ string) error { return nil }, + newNodeClient: func(provider TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient { + return NewNodeServiceClient(provider, server.URL) + }, + commandRunner: &mockCommandRunner{ + outputs: map[string][]byte{ + "nvidia-smi": []byte("NVIDIA GB10, 131072\n"), + "lsblk": []byte("nvme0n1 500107862016 disk\n"), + }, + }, + fileReader: &mockFileReader{ + files: map[string][]byte{ + "/proc/cpuinfo": []byte("processor\t: 0\nprocessor\t: 1\n"), + "/proc/meminfo": []byte("MemTotal: 131886028 kB\n"), + "/etc/os-release": []byte("NAME=\"Ubuntu\"\nVERSION_ID=\"24.04\"\n"), + }, + }, + }, server +} + +func setupRegisterTestFs(t *testing.T) (string, func()) { + t.Helper() + origFs := files.AppFs + files.AppFs = afero.NewMemMapFs() + brevHome := "/home/testuser/.brev" + if err := files.AppFs.MkdirAll(brevHome, 0o700); err != nil { + t.Fatalf("failed to create test dir: %v", err) + } + return brevHome, func() { files.AppFs = origFs } +} + +func Test_runRegister_HappyPath(t *testing.T) { + brevHome, cleanup := setupRegisterTestFs(t) + defer cleanup() + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: brevHome, + token: "tok", + } + + var gotSetupCmd string + svc := &fakeNodeService{ + addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + if req.GetOrganizationId() != "org_123" { + t.Errorf("unexpected org: %s", req.GetOrganizationId()) + } + if req.GetName() != "My Spark" { + t.Errorf("unexpected name: %s", req.GetName()) + } + return &nodev1.AddNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: "unode_abc", + OrganizationId: "org_123", + Name: req.GetName(), + DeviceId: req.GetDeviceId(), + }, + SetupCommand: "netbird up --key abc", + }, nil + }, + } + + deps, server := testRegisterDeps(t, svc) + defer server.Close() + + deps.runSetupCommand = func(cmd string) error { + gotSetupCmd = cmd + return nil + } + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err != nil { + t.Fatalf("runRegister failed: %v", err) + } + + // Verify registration was persisted + exists, err := RegistrationExists(brevHome) + if err != nil { + t.Fatalf("RegistrationExists error: %v", err) + } + if !exists { + t.Fatal("expected registration file to exist after successful register") + } + + reg, err := LoadRegistration(brevHome) + if err != nil { + t.Fatalf("LoadRegistration failed: %v", err) + } + if reg.ExternalNodeID != "unode_abc" { + t.Errorf("expected ExternalNodeID unode_abc, got %s", reg.ExternalNodeID) + } + if reg.DisplayName != "My Spark" { + t.Errorf("expected display name 'My Spark', got %s", reg.DisplayName) + } + if reg.OrgID != "org_123" { + t.Errorf("expected org org_123, got %s", reg.OrgID) + } + + // Verify setup command was executed + if gotSetupCmd != "netbird up --key abc" { + t.Errorf("expected setup command 'netbird up --key abc', got %q", gotSetupCmd) + } +} + +func Test_runRegister_UserCancels(t *testing.T) { + brevHome, cleanup := setupRegisterTestFs(t) + defer cleanup() + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: brevHome, + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testRegisterDeps(t, svc) + defer server.Close() + + deps.promptYesNo = func(_ string) bool { return false } + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err != nil { + t.Fatalf("expected nil error on cancel, got: %v", err) + } + + // Registration file should not exist + exists, err := RegistrationExists(brevHome) + if err != nil { + t.Fatalf("RegistrationExists error: %v", err) + } + if exists { + t.Error("registration file should not exist after cancel") + } +} + +func Test_runRegister_AlreadyRegistered(t *testing.T) { + brevHome, cleanup := setupRegisterTestFs(t) + defer cleanup() + + // Save an existing registration + _ = SaveRegistration(brevHome, &DeviceRegistration{ + ExternalNodeID: "unode_existing", + DisplayName: "Existing", + }) + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: brevHome, + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testRegisterDeps(t, svc) + defer server.Close() + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err == nil { + t.Fatal("expected error for already-registered machine") + } +} + +func Test_runRegister_NoOrganization(t *testing.T) { + brevHome, cleanup := setupRegisterTestFs(t) + defer cleanup() + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: nil, + home: brevHome, + token: "tok", + } + + svc := &fakeNodeService{} + deps, server := testRegisterDeps(t, svc) + defer server.Close() + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err == nil { + t.Fatal("expected error when no org exists") + } +} + +func Test_runRegister_AddNodeFails(t *testing.T) { + brevHome, cleanup := setupRegisterTestFs(t) + defer cleanup() + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: brevHome, + token: "tok", + } + + svc := &fakeNodeService{ + addNodeFn: func(_ *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + return nil, connect.NewError(connect.CodeInternal, nil) + }, + } + + deps, server := testRegisterDeps(t, svc) + defer server.Close() + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err == nil { + t.Fatal("expected error when AddNode fails") + } + + // Registration file should not exist on failure + exists, err := RegistrationExists(brevHome) + if err != nil { + t.Fatalf("RegistrationExists error: %v", err) + } + if exists { + t.Error("registration file should not exist after AddNode failure") + } +} + +func Test_runRegister_NoSetupCommand(t *testing.T) { + brevHome, cleanup := setupRegisterTestFs(t) + defer cleanup() + + store := &mockRegisterStore{ + user: &entity.User{ID: "user_1"}, + org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, + home: brevHome, + token: "tok", + } + + setupCalled := false + svc := &fakeNodeService{ + addNodeFn: func(req *nodev1.AddNodeRequest) (*nodev1.AddNodeResponse, error) { + return &nodev1.AddNodeResponse{ + ExternalNode: &nodev1.ExternalNode{ + ExternalNodeId: "unode_abc", + OrganizationId: "org_123", + Name: req.GetName(), + DeviceId: req.GetDeviceId(), + }, + // No SetupCommand + }, nil + }, + } + + deps, server := testRegisterDeps(t, svc) + defer server.Close() + + deps.runSetupCommand = func(_ string) error { + setupCalled = true + return nil + } + + term := terminal.New() + err := runRegister(context.Background(), term, store, "My Spark", deps) + if err != nil { + t.Fatalf("runRegister failed: %v", err) + } + + if setupCalled { + t.Error("setup command should not be called when empty") + } +} diff --git a/pkg/cmd/register/rpcclient.go b/pkg/cmd/register/rpcclient.go index 6ca80ba3..2b07d83d 100644 --- a/pkg/cmd/register/rpcclient.go +++ b/pkg/cmd/register/rpcclient.go @@ -9,14 +9,14 @@ import ( breverrors "github.com/brevdev/brev-cli/pkg/errors" ) -// tokenProvider abstracts access token retrieval for the HTTP transport. -type tokenProvider interface { +// TokenProvider abstracts access token retrieval for the HTTP transport. +type TokenProvider interface { GetAccessToken() (string, error) } // bearerTokenTransport injects a Bearer token into every request. type bearerTokenTransport struct { - provider tokenProvider + provider TokenProvider base http.RoundTripper } @@ -36,7 +36,7 @@ func (t *bearerTokenTransport) RoundTrip(req *http.Request) (*http.Response, err // newAuthenticatedHTTPClient creates an http.Client that injects the bearer token // from the given provider on every request. -func newAuthenticatedHTTPClient(provider tokenProvider) *http.Client { +func newAuthenticatedHTTPClient(provider TokenProvider) *http.Client { return &http.Client{ Transport: &bearerTokenTransport{ provider: provider, @@ -47,7 +47,7 @@ func newAuthenticatedHTTPClient(provider tokenProvider) *http.Client { // NewNodeServiceClient creates a ConnectRPC ExternalNodeServiceClient using the // given token provider for authentication. -func NewNodeServiceClient(provider tokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient { +func NewNodeServiceClient(provider TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient { return nodev1connect.NewExternalNodeServiceClient( newAuthenticatedHTTPClient(provider), baseURL, From 97a8e38914e47813001076bb20c035b33ecfe229 Mon Sep 17 00:00:00 2001 From: Pratik Patel Date: Thu, 26 Feb 2026 14:52:14 -0800 Subject: [PATCH 08/11] storage --- pkg/cmd/register/hardware.go | 70 ++++++++++++----------- pkg/cmd/register/hardware_test.go | 90 +++++++++++++++++------------- pkg/cmd/register/register_test.go | 2 +- pkg/cmd/register/rpcclient.go | 26 +++++++-- pkg/cmd/register/rpcclient_test.go | 14 ++--- 5 files changed, 116 insertions(+), 86 deletions(-) diff --git a/pkg/cmd/register/hardware.go b/pkg/cmd/register/hardware.go index 5130294b..721d32c9 100644 --- a/pkg/cmd/register/hardware.go +++ b/pkg/cmd/register/hardware.go @@ -30,14 +30,19 @@ func (r ExecCommandRunner) Run(name string, args ...string) ([]byte, error) { // NodeSpec matches the proto NodeSpec message from dev-plane. // All fields are best-effort. type NodeSpec struct { - GPUs []NodeGPU `json:"gpus"` - RAMBytes *int64 `json:"ram_bytes,omitempty"` - CPUCount *int32 `json:"cpu_count,omitempty"` - Architecture string `json:"architecture,omitempty"` - StorageBytes *int64 `json:"storage_bytes,omitempty"` - StorageType string `json:"storage_type,omitempty"` - OS string `json:"os,omitempty"` - OSVersion string `json:"os_version,omitempty"` + GPUs []NodeGPU `json:"gpus"` + RAMBytes *int64 `json:"ram_bytes,omitempty"` + CPUCount *int32 `json:"cpu_count,omitempty"` + Architecture string `json:"architecture,omitempty"` + Storage []NodeStorage `json:"storage,omitempty"` + OS string `json:"os,omitempty"` + OSVersion string `json:"os_version,omitempty"` +} + +// NodeStorage represents a single storage device with its size and type. +type NodeStorage struct { + StorageBytes int64 `json:"storage_bytes"` + StorageType string `json:"storage_type,omitempty"` // "SSD" or "HDD" } // NodeGPU matches the proto NodeGPU message. @@ -76,11 +81,7 @@ func CollectHardwareProfile(runner CommandRunner, reader FileReader) (*NodeSpec, spec.OS = osName spec.OSVersion = osVersion - storageBytes, storageType := collectStorage(runner) - if storageBytes > 0 { - spec.StorageBytes = &storageBytes - spec.StorageType = storageType - } + spec.Storage = collectStorage(runner) return spec, nil } @@ -235,39 +236,42 @@ func parseNvidiaSMIOutput(output string) []NodeGPU { return gpus } -// collectStorage sums disk devices from lsblk to get total storage bytes -// and infers a storage type from the device names. -func collectStorage(runner CommandRunner) (int64, string) { - out, err := runner.Run("lsblk", "-b", "-d", "-n", "-o", "NAME,SIZE,TYPE") +// collectStorage returns per-device storage entries from lsblk, +// using the ROTA column to determine device type. +func collectStorage(runner CommandRunner) []NodeStorage { + out, err := runner.Run("lsblk", "-b", "-d", "-n", "-o", "NAME,SIZE,TYPE,ROTA") if err != nil { - return 0, "" + return nil } return parseStorageOutput(string(out)) } -// parseStorageOutput parses lsblk output, summing disk device sizes and -// inferring storage type. -func parseStorageOutput(output string) (int64, string) { - var totalBytes int64 - storageType := "" +// parseStorageOutput parses lsblk output (NAME,SIZE,TYPE,ROTA columns), +// returning one NodeStorage entry per disk device. ROTA=0 → SSD, ROTA=1 → HDD. +func parseStorageOutput(output string) []NodeStorage { + var devices []NodeStorage scanner := bufio.NewScanner(strings.NewReader(output)) for scanner.Scan() { fields := strings.Fields(scanner.Text()) - if len(fields) < 3 || fields[2] != "disk" { + if len(fields) < 4 || fields[2] != "disk" { continue } size, err := strconv.ParseInt(fields[1], 10, 64) if err != nil { continue } - totalBytes += size - if storageType == "" { - if strings.HasPrefix(fields[0], "nvme") { - storageType = "NVMe" + entry := NodeStorage{StorageBytes: size} + rota, err := strconv.Atoi(fields[3]) + if err == nil { + if rota == 0 { + entry.StorageType = "SSD" + } else { + entry.StorageType = "HDD" } } + devices = append(devices, entry) } - return totalBytes, storageType + return devices } // FormatNodeSpec returns a human-readable summary of the hardware profile. @@ -291,10 +295,10 @@ func FormatNodeSpec(s *NodeSpec) string { if s.OS != "" || s.OSVersion != "" { _, _ = fmt.Fprintf(&b, " OS: %s %s\n", s.OS, s.OSVersion) } - if s.StorageBytes != nil { - _, _ = fmt.Fprintf(&b, " Storage: %d GB", *s.StorageBytes/(1024*1024*1024)) - if s.StorageType != "" { - _, _ = fmt.Fprintf(&b, " (%s)", s.StorageType) + for _, st := range s.Storage { + _, _ = fmt.Fprintf(&b, " Storage: %d GB", st.StorageBytes/(1024*1024*1024)) + if st.StorageType != "" { + _, _ = fmt.Fprintf(&b, " (%s)", st.StorageType) } b.WriteString("\n") } diff --git a/pkg/cmd/register/hardware_test.go b/pkg/cmd/register/hardware_test.go index 5370bf6a..92fd8fc6 100644 --- a/pkg/cmd/register/hardware_test.go +++ b/pkg/cmd/register/hardware_test.go @@ -136,29 +136,40 @@ func Test_parseNvidiaSMIOutput_Empty(t *testing.T) { } func Test_parseStorageOutput(t *testing.T) { - output := `nvme0n1 500107862016 disk -nvme1n1 1000204886016 disk -sda 2048 rom + output := `nvme0n1 500107862016 disk 0 +nvme1n1 1000204886016 disk 0 +sda 2048 rom 1 ` - totalBytes, storageType := parseStorageOutput(output) - expected := int64(500107862016 + 1000204886016) - if totalBytes != expected { - t.Errorf("expected %d bytes, got %d", expected, totalBytes) + devices := parseStorageOutput(output) + if len(devices) != 2 { + t.Fatalf("expected 2 devices, got %d", len(devices)) } - if storageType != "NVMe" { - t.Errorf("expected NVMe, got %s", storageType) + if devices[0].StorageBytes != 500107862016 { + t.Errorf("expected 500107862016, got %d", devices[0].StorageBytes) + } + if devices[0].StorageType != "SSD" { + t.Errorf("expected SSD, got %s", devices[0].StorageType) + } + if devices[1].StorageBytes != 1000204886016 { + t.Errorf("expected 1000204886016, got %d", devices[1].StorageBytes) + } + if devices[1].StorageType != "SSD" { + t.Errorf("expected SSD, got %s", devices[1].StorageType) } } func Test_parseStorageOutput_SDA(t *testing.T) { - output := `sda 500107862016 disk + output := `sda 500107862016 disk 1 ` - totalBytes, storageType := parseStorageOutput(output) - if totalBytes != 500107862016 { - t.Errorf("expected 500107862016 bytes, got %d", totalBytes) + devices := parseStorageOutput(output) + if len(devices) != 1 { + t.Fatalf("expected 1 device, got %d", len(devices)) } - if storageType != "" { - t.Errorf("expected empty storage type for non-nvme disk, got %s", storageType) + if devices[0].StorageBytes != 500107862016 { + t.Errorf("expected 500107862016 bytes, got %d", devices[0].StorageBytes) + } + if devices[0].StorageType != "HDD" { + t.Errorf("expected HDD, got %s", devices[0].StorageType) } } @@ -233,18 +244,22 @@ func Test_FormatNodeSpec_MinimalFields(t *testing.T) { } func Test_FormatNodeSpec_WithStorage(t *testing.T) { - storageBytes := int64(1099511627776) // 1 TB s := &NodeSpec{ Architecture: "amd64", - StorageBytes: &storageBytes, - StorageType: "NVMe", + Storage: []NodeStorage{ + {StorageBytes: 500107862016, StorageType: "SSD"}, + {StorageBytes: 1000204886016, StorageType: "HDD"}, + }, } output := FormatNodeSpec(s) if !strings.Contains(output, "Storage:") { t.Errorf("expected storage in output: %s", output) } - if !strings.Contains(output, "NVMe") { - t.Errorf("expected NVMe in output: %s", output) + if !strings.Contains(output, "SSD") { + t.Errorf("expected SSD in output: %s", output) + } + if !strings.Contains(output, "HDD") { + t.Errorf("expected HDD in output: %s", output) } } @@ -266,25 +281,19 @@ NVIDIA A100, not-a-number } func Test_parseStorageOutput_Empty(t *testing.T) { - totalBytes, storageType := parseStorageOutput("") - if totalBytes != 0 { - t.Errorf("expected 0 bytes, got %d", totalBytes) - } - if storageType != "" { - t.Errorf("expected empty storage type, got %s", storageType) + devices := parseStorageOutput("") + if len(devices) != 0 { + t.Errorf("expected 0 devices, got %d", len(devices)) } } func Test_parseStorageOutput_NoDiskDevices(t *testing.T) { - output := `sr0 1073741312 rom -loop0 123456 loop + output := `sr0 1073741312 rom 1 +loop0 123456 loop 0 ` - totalBytes, storageType := parseStorageOutput(output) - if totalBytes != 0 { - t.Errorf("expected 0 bytes for non-disk devices, got %d", totalBytes) - } - if storageType != "" { - t.Errorf("expected empty storage type, got %s", storageType) + devices := parseStorageOutput(output) + if len(devices) != 0 { + t.Errorf("expected 0 devices for non-disk entries, got %d", len(devices)) } } @@ -324,7 +333,7 @@ func Test_CollectHardwareProfile_WithMocks(t *testing.T) { runner := &mockCommandRunner{ outputs: map[string][]byte{ "nvidia-smi": []byte("NVIDIA GB10, 131072\nNVIDIA GB10, 131072\n"), - "lsblk": []byte("nvme0n1 500107862016 disk\n"), + "lsblk": []byte("nvme0n1 500107862016 disk 0\n"), }, } reader := &mockFileReader{ @@ -351,8 +360,11 @@ func Test_CollectHardwareProfile_WithMocks(t *testing.T) { if spec.OS != "Ubuntu" || spec.OSVersion != "24.04" { t.Errorf("unexpected OS: %s %s", spec.OS, spec.OSVersion) } - if spec.StorageBytes == nil || *spec.StorageBytes != 500107862016 { - t.Errorf("unexpected storage: %v", spec.StorageBytes) + if len(spec.Storage) != 1 || spec.Storage[0].StorageBytes != 500107862016 { + t.Errorf("unexpected storage: %v", spec.Storage) + } + if spec.Storage[0].StorageType != "SSD" { + t.Errorf("expected SSD, got %s", spec.Storage[0].StorageType) } } @@ -404,8 +416,8 @@ func Test_CollectHardwareProfile_OptionalFieldsMissing(t *testing.T) { if spec.RAMBytes != nil { t.Errorf("expected nil RAMBytes when /proc/meminfo missing") } - if spec.StorageBytes != nil { - t.Errorf("expected nil StorageBytes when lsblk fails") + if len(spec.Storage) != 0 { + t.Errorf("expected empty Storage when lsblk fails, got %v", spec.Storage) } if len(spec.GPUs) != 1 { t.Errorf("expected 1 GPU, got %d", len(spec.GPUs)) diff --git a/pkg/cmd/register/register_test.go b/pkg/cmd/register/register_test.go index 9a26eba5..f89f8cbd 100644 --- a/pkg/cmd/register/register_test.go +++ b/pkg/cmd/register/register_test.go @@ -57,7 +57,7 @@ func testRegisterDeps(t *testing.T, svc *fakeNodeService) (registerDeps, *httpte commandRunner: &mockCommandRunner{ outputs: map[string][]byte{ "nvidia-smi": []byte("NVIDIA GB10, 131072\n"), - "lsblk": []byte("nvme0n1 500107862016 disk\n"), + "lsblk": []byte("nvme0n1 500107862016 disk 0\n"), }, }, fileReader: &mockFileReader{ diff --git a/pkg/cmd/register/rpcclient.go b/pkg/cmd/register/rpcclient.go index 2b07d83d..0f74bbd8 100644 --- a/pkg/cmd/register/rpcclient.go +++ b/pkg/cmd/register/rpcclient.go @@ -62,17 +62,31 @@ func toProtoNodeSpec(s *NodeSpec) *nodev1.NodeSpec { } proto := &nodev1.NodeSpec{ - RamBytes: s.RAMBytes, - CpuCount: s.CPUCount, - StorageBytes: s.StorageBytes, + RamBytes: s.RAMBytes, + CpuCount: s.CPUCount, + } + + // Bridge: sum storage array into the scalar proto fields until the + // proto is updated with repeated StorageSpec. Delete this block when + // we `go get` the new buf module commit. + if len(s.Storage) > 0 { + var totalBytes int64 + var firstType string + for _, st := range s.Storage { + totalBytes += st.StorageBytes + if firstType == "" && st.StorageType != "" { + firstType = st.StorageType + } + } + proto.StorageBytes = &totalBytes + if firstType != "" { + proto.StorageType = &firstType + } } if s.Architecture != "" { proto.Architecture = &s.Architecture } - if s.StorageType != "" { - proto.StorageType = &s.StorageType - } if s.OS != "" { proto.Os = &s.OS } diff --git a/pkg/cmd/register/rpcclient_test.go b/pkg/cmd/register/rpcclient_test.go index 85384c95..9b24c697 100644 --- a/pkg/cmd/register/rpcclient_test.go +++ b/pkg/cmd/register/rpcclient_test.go @@ -59,7 +59,6 @@ func Test_toProtoNodeSpec(t *testing.T) { cpuCount := int32(12) ramBytes := int64(137438953472) memBytes := int64(137438953472) - storageBytes := int64(500107862016) local := &NodeSpec{ GPUs: []NodeGPU{ @@ -68,10 +67,11 @@ func Test_toProtoNodeSpec(t *testing.T) { RAMBytes: &ramBytes, CPUCount: &cpuCount, Architecture: "arm64", - StorageBytes: &storageBytes, - StorageType: "NVMe", - OS: "Ubuntu", - OSVersion: "24.04", + Storage: []NodeStorage{ + {StorageBytes: 500107862016, StorageType: "SSD"}, + }, + OS: "Ubuntu", + OSVersion: "24.04", } proto := toProtoNodeSpec(local) @@ -94,8 +94,8 @@ func Test_toProtoNodeSpec(t *testing.T) { if proto.GetStorageBytes() != 500107862016 { t.Errorf("expected StorageBytes, got %d", proto.GetStorageBytes()) } - if proto.GetStorageType() != "NVMe" { - t.Errorf("expected NVMe, got %s", proto.GetStorageType()) + if proto.GetStorageType() != "SSD" { + t.Errorf("expected SSD, got %s", proto.GetStorageType()) } if len(proto.GetGpus()) != 1 { t.Fatalf("expected 1 GPU, got %d", len(proto.GetGpus())) From 0ee64278581350dd181c7e8eae4e1e484df85fa7 Mon Sep 17 00:00:00 2001 From: Pratik Patel Date: Thu, 26 Feb 2026 18:18:17 -0800 Subject: [PATCH 09/11] fixing abstractions and tests --- go.mod | 4 +- go.sum | 4 + pkg/cmd/deregister/deregister.go | 55 +++---- pkg/cmd/deregister/deregister_test.go | 150 +++++++++--------- pkg/cmd/register/netbird.go | 12 +- pkg/cmd/register/register.go | 104 ++++++------ pkg/cmd/register/register_test.go | 126 ++++++++------- .../register/{identity.go => registration.go} | 45 ++++-- ...{identity_test.go => registration_test.go} | 48 +++--- pkg/cmd/register/rpcclient.go | 21 +-- pkg/cmd/register/rpcclient_test.go | 11 +- 11 files changed, 310 insertions(+), 270 deletions(-) rename pkg/cmd/register/{identity.go => registration.go} (57%) rename pkg/cmd/register/{identity_test.go => registration_test.go} (75%) diff --git a/go.mod b/go.mod index 352ecb2f..608c9509 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,8 @@ module github.com/brevdev/brev-cli go 1.24.0 require ( - buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260226200709-e1ac2ea142d1.2 - buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260226200709-e1ac2ea142d1.1 + buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260226234124-59cddad562f0.2 + buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260226234124-59cddad562f0.1 connectrpc.com/connect v1.19.1 github.com/alessio/shellescape v1.4.1 github.com/brevdev/parse v0.0.11 diff --git a/go.sum b/go.sum index 33a3e19e..db2d581b 100644 --- a/go.sum +++ b/go.sum @@ -2,10 +2,14 @@ buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260226031750-e6fd8dbaf buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260226031750-e6fd8dbaf991.2/go.mod h1:EGcIExX0SEtObIZr1l3pouENtdl2gsZtHjOYOfuB7ss= buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260226200709-e1ac2ea142d1.2 h1:7ld1AqzV9YsRWP5I4FvMUxDiq5fMCEEsNMECCcUJE/s= buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260226200709-e1ac2ea142d1.2/go.mod h1:ZqSAMH+RVqnfQsnUQ5OpJI7dUWx0UzPUWcceudIHWmI= +buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260226234124-59cddad562f0.2 h1:B+GNU2e5fb54KUw11+kOUXNuzxWM40J2GiSmONL8VEA= +buf.build/gen/go/brevdev/devplane/connectrpc/go v1.19.1-20260226234124-59cddad562f0.2/go.mod h1:k1PtdOGpCm4AS2SszBDYyA2pbj9Y39TYRwdhlA17slw= buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260226031750-e6fd8dbaf991.1 h1:xkJkJcCnAq5WiEUevk7Kz3b+aFuK7aj64DyVUQM9ZQ0= buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260226031750-e6fd8dbaf991.1/go.mod h1:V/y7Wxg0QvU4XPVwqErF5NHLobUT1QEyfgrGuQIxdPo= buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260226200709-e1ac2ea142d1.1 h1:pWlngsd33oF5xFhfTbxYProXMihFXmthzTAcgd3zXKg= buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260226200709-e1ac2ea142d1.1/go.mod h1:V/y7Wxg0QvU4XPVwqErF5NHLobUT1QEyfgrGuQIxdPo= +buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260226234124-59cddad562f0.1 h1:7+YIWe9KK1AJjpzFThk402ginqJ51bgtjTulw97a4fo= +buf.build/gen/go/brevdev/devplane/protocolbuffers/go v1.36.11-20260226234124-59cddad562f0.1/go.mod h1:V/y7Wxg0QvU4XPVwqErF5NHLobUT1QEyfgrGuQIxdPo= buf.build/gen/go/brevdev/protoc-gen-gotag/protocolbuffers/go v1.36.11-20220906235457-8b4922735da5.1 h1:6amhprQmCKJ4wgJ6ngkh32d9V+dQcOLUZ/SfHdOnYgo= buf.build/gen/go/brevdev/protoc-gen-gotag/protocolbuffers/go v1.36.11-20220906235457-8b4922735da5.1/go.mod h1:O+pnSHMru/naTMrm4tmpBoH3wz6PHa+R75HR7Mv8X2g= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= diff --git a/pkg/cmd/deregister/deregister.go b/pkg/cmd/deregister/deregister.go index 37fb8e30..d50cf326 100644 --- a/pkg/cmd/deregister/deregister.go +++ b/pkg/cmd/deregister/deregister.go @@ -1,4 +1,4 @@ -// Package deregister provides the brev deregister command for DGX Spark deregistration +// Package deregister provides the brev deregister command for device deregistration package deregister import ( @@ -29,16 +29,14 @@ type DeregisterStore interface { // deregisterDeps bundles the side-effecting dependencies of runDeregister so // they can be replaced in tests. type deregisterDeps struct { - goos string - promptSelect func(label string, items []string) string - uninstallNetbird func(t *terminal.Terminal) error - newNodeClient func(provider register.TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient - registrationExists func(brevHome string) (bool, error) - loadRegistration func(brevHome string) (*register.DeviceRegistration, error) - deleteRegistration func(brevHome string) error + goos string + promptSelect func(label string, items []string) string + uninstallNetbird func() error + newNodeClient func(provider register.TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient + registrationStore register.RegistrationStore } -func prodDeregisterDeps() deregisterDeps { +func prodDeregisterDeps(brevHome string) deregisterDeps { return deregisterDeps{ goos: runtime.GOOS, promptSelect: func(label string, items []string) string { @@ -47,16 +45,14 @@ func prodDeregisterDeps() deregisterDeps { Items: items, }) }, - uninstallNetbird: register.UninstallNetbird, - newNodeClient: register.NewNodeServiceClient, - registrationExists: register.RegistrationExists, - loadRegistration: register.LoadRegistration, - deleteRegistration: register.DeleteRegistration, + uninstallNetbird: register.UninstallNetbird, + newNodeClient: register.NewNodeServiceClient, + registrationStore: register.NewFileRegistrationStore(brevHome), } } var ( - deregisterLong = `Deregister your DGX Spark from NVIDIA Brev + deregisterLong = `Deregister your device from NVIDIA Brev This command removes the local registration data and optionally uninstalls NetBird (network agent).` @@ -69,11 +65,15 @@ func NewCmdDeregister(t *terminal.Terminal, store DeregisterStore) *cobra.Comman Annotations: map[string]string{"configuration": ""}, Use: "deregister", DisableFlagsInUseLine: true, - Short: "Deregister your DGX Spark from Brev", + Short: "Deregister your device from Brev", Long: deregisterLong, Example: deregisterExample, RunE: func(cmd *cobra.Command, args []string) error { - return runDeregister(cmd.Context(), t, store, prodDeregisterDeps()) + brevHome, err := store.GetBrevHomePath() + if err != nil { + return breverrors.WrapAndTrace(err) + } + return runDeregister(cmd.Context(), t, store, prodDeregisterDeps(brevHome)) }, } @@ -82,29 +82,24 @@ func NewCmdDeregister(t *terminal.Terminal, store DeregisterStore) *cobra.Comman func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore, deps deregisterDeps) error { //nolint:funlen // deregistration flow if deps.goos != "linux" { - return fmt.Errorf("brev deregister is only supported on Linux (DGX Spark)") + return fmt.Errorf("brev deregister is only supported on Linux") } - brevHome, err := s.GetBrevHomePath() - if err != nil { - return breverrors.WrapAndTrace(err) - } - - registered, err := deps.registrationExists(brevHome) + registered, err := deps.registrationStore.Exists() if err != nil { return breverrors.WrapAndTrace(err) } if !registered { - return fmt.Errorf("no registration found; this machine does not appear to be registered\nRun 'brev register' to register your DGX Spark") + return fmt.Errorf("no registration found; this machine does not appear to be registered\nRun 'brev register' to register your device") } - reg, err := deps.loadRegistration(brevHome) + reg, err := deps.registrationStore.Load() if err != nil { return fmt.Errorf("failed to read registration file: %w", err) } t.Vprint("") - t.Vprint(t.Green("Deregistering DGX Spark")) + t.Vprint(t.Green("Deregistering device")) t.Vprint("") t.Vprintf(" Node ID: %s\n", reg.ExternalNodeID) t.Vprintf(" Name: %s\n", reg.DisplayName) @@ -138,7 +133,7 @@ func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore, if removeNetbird == "Yes, uninstall NetBird" { t.Vprint("Removing NetBird...") - if err := deps.uninstallNetbird(t); err != nil { + if err := deps.uninstallNetbird(); err != nil { t.Vprintf(" Warning: failed to uninstall NetBird: %v\n", err) } else { t.Vprint(t.Green(" NetBird uninstalled.")) @@ -147,9 +142,9 @@ func runDeregister(ctx context.Context, t *terminal.Terminal, s DeregisterStore, } t.Vprint("Removing registration data...") - if err := deps.deleteRegistration(brevHome); err != nil { + if err := deps.registrationStore.Delete(); err != nil { t.Vprintf(" Warning: failed to remove local registration file: %v\n", err) - t.Vprint(" You can manually remove it with: rm ~/.brev/spark_registration.json") + t.Vprint(" You can manually remove it with: rm ~/.brev/device_registration.json") } t.Vprint(t.Green("Deregistration complete.")) diff --git a/pkg/cmd/deregister/deregister_test.go b/pkg/cmd/deregister/deregister_test.go index ed4dfd5d..0c9e24a6 100644 --- a/pkg/cmd/deregister/deregister_test.go +++ b/pkg/cmd/deregister/deregister_test.go @@ -2,6 +2,7 @@ package deregister import ( "context" + "fmt" "net/http/httptest" "testing" @@ -11,9 +12,7 @@ import ( "github.com/brevdev/brev-cli/pkg/cmd/register" "github.com/brevdev/brev-cli/pkg/entity" - "github.com/brevdev/brev-cli/pkg/files" "github.com/brevdev/brev-cli/pkg/terminal" - "github.com/spf13/afero" ) type mockDeregisterStore struct { @@ -47,20 +46,35 @@ func (f *fakeNodeService) RemoveNode(_ context.Context, req *connect.Request[nod return connect.NewResponse(resp), nil } -func setupDeregisterTestFs(t *testing.T) (string, func()) { - t.Helper() - origFs := files.AppFs - files.AppFs = afero.NewMemMapFs() - brevHome := "/home/testuser/.brev" - if err := files.AppFs.MkdirAll(brevHome, 0o700); err != nil { - t.Fatalf("failed to create test dir: %v", err) +// mockRegistrationStore satisfies register.RegistrationStore for deregister tests. +type mockRegistrationStore struct { + reg *register.DeviceRegistration +} + +func (m *mockRegistrationStore) Save(reg *register.DeviceRegistration) error { + m.reg = reg + return nil +} + +func (m *mockRegistrationStore) Load() (*register.DeviceRegistration, error) { + if m.reg == nil { + return nil, fmt.Errorf("no registration") } - return brevHome, func() { files.AppFs = origFs } + return m.reg, nil +} + +func (m *mockRegistrationStore) Delete() error { + m.reg = nil + return nil +} + +func (m *mockRegistrationStore) Exists() (bool, error) { + return m.reg != nil, nil } // testDeregisterDeps returns deps with all side-effects stubbed. The // promptSelect defaults to confirming all prompts. -func testDeregisterDeps(t *testing.T, svc *fakeNodeService) (deregisterDeps, *httptest.Server) { +func testDeregisterDeps(t *testing.T, svc *fakeNodeService, regStore register.RegistrationStore) (deregisterDeps, *httptest.Server) { t.Helper() _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) @@ -75,31 +89,27 @@ func testDeregisterDeps(t *testing.T, svc *fakeNodeService) (deregisterDeps, *ht } return "" }, - uninstallNetbird: func(_ *terminal.Terminal) error { return nil }, + uninstallNetbird: func() error { return nil }, newNodeClient: func(provider register.TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient { return register.NewNodeServiceClient(provider, server.URL) }, - registrationExists: register.RegistrationExists, - loadRegistration: register.LoadRegistration, - deleteRegistration: register.DeleteRegistration, + registrationStore: regStore, }, server } func Test_runDeregister_HappyPath(t *testing.T) { - brevHome, cleanup := setupDeregisterTestFs(t) - defer cleanup() - - // Pre-save a registration - _ = register.SaveRegistration(brevHome, ®ister.DeviceRegistration{ - ExternalNodeID: "unode_abc", - DisplayName: "My Spark", - OrgID: "org_123", - DeviceID: "dev-uuid", - }) + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + DeviceID: "dev-uuid", + }, + } store := &mockDeregisterStore{ user: &entity.User{ID: "user_1"}, - home: brevHome, + home: "/home/testuser/.brev", token: "tok", } @@ -112,7 +122,7 @@ func Test_runDeregister_HappyPath(t *testing.T) { }, } - deps, server := testDeregisterDeps(t, svc) + deps, server := testDeregisterDeps(t, svc, regStore) defer server.Close() term := terminal.New() @@ -128,34 +138,33 @@ func Test_runDeregister_HappyPath(t *testing.T) { t.Errorf("expected org ID org_123, got %s", gotOrgID) } - // Registration file should be deleted - exists, err := register.RegistrationExists(brevHome) + // Registration should be deleted + exists, err := regStore.Exists() if err != nil { - t.Fatalf("RegistrationExists error: %v", err) + t.Fatalf("Exists error: %v", err) } if exists { - t.Error("expected registration file to be deleted after deregister") + t.Error("expected registration to be deleted after deregister") } } func Test_runDeregister_UserCancels(t *testing.T) { - brevHome, cleanup := setupDeregisterTestFs(t) - defer cleanup() - - _ = register.SaveRegistration(brevHome, ®ister.DeviceRegistration{ - ExternalNodeID: "unode_abc", - DisplayName: "My Spark", - OrgID: "org_123", - }) + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }, + } store := &mockDeregisterStore{ user: &entity.User{ID: "user_1"}, - home: brevHome, + home: "/home/testuser/.brev", token: "tok", } svc := &fakeNodeService{} - deps, server := testDeregisterDeps(t, svc) + deps, server := testDeregisterDeps(t, svc, regStore) defer server.Close() callCount := 0 @@ -174,10 +183,10 @@ func Test_runDeregister_UserCancels(t *testing.T) { t.Fatalf("expected nil error on cancel, got: %v", err) } - // Registration file should still exist - exists, err := register.RegistrationExists(brevHome) + // Registration should still exist + exists, err := regStore.Exists() if err != nil { - t.Fatalf("RegistrationExists error: %v", err) + t.Fatalf("Exists error: %v", err) } if !exists { t.Error("registration should still exist after cancel") @@ -185,17 +194,16 @@ func Test_runDeregister_UserCancels(t *testing.T) { } func Test_runDeregister_NotRegistered(t *testing.T) { - brevHome, cleanup := setupDeregisterTestFs(t) - defer cleanup() + regStore := &mockRegistrationStore{} store := &mockDeregisterStore{ user: &entity.User{ID: "user_1"}, - home: brevHome, + home: "/home/testuser/.brev", token: "tok", } svc := &fakeNodeService{} - deps, server := testDeregisterDeps(t, svc) + deps, server := testDeregisterDeps(t, svc, regStore) defer server.Close() term := terminal.New() @@ -206,18 +214,17 @@ func Test_runDeregister_NotRegistered(t *testing.T) { } func Test_runDeregister_RemoveNodeFails(t *testing.T) { - brevHome, cleanup := setupDeregisterTestFs(t) - defer cleanup() - - _ = register.SaveRegistration(brevHome, ®ister.DeviceRegistration{ - ExternalNodeID: "unode_abc", - DisplayName: "My Spark", - OrgID: "org_123", - }) + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }, + } store := &mockDeregisterStore{ user: &entity.User{ID: "user_1"}, - home: brevHome, + home: "/home/testuser/.brev", token: "tok", } @@ -227,7 +234,7 @@ func Test_runDeregister_RemoveNodeFails(t *testing.T) { }, } - deps, server := testDeregisterDeps(t, svc) + deps, server := testDeregisterDeps(t, svc, regStore) defer server.Close() term := terminal.New() @@ -236,10 +243,10 @@ func Test_runDeregister_RemoveNodeFails(t *testing.T) { t.Fatal("expected error when RemoveNode fails") } - // Registration file should still exist (server-side removal failed) - exists, err := register.RegistrationExists(brevHome) + // Registration should still exist (server-side removal failed) + exists, err := regStore.Exists() if err != nil { - t.Fatalf("RegistrationExists error: %v", err) + t.Fatalf("Exists error: %v", err) } if !exists { t.Error("registration should still exist when RemoveNode fails") @@ -247,18 +254,17 @@ func Test_runDeregister_RemoveNodeFails(t *testing.T) { } func Test_runDeregister_SkipsNetbirdUninstall(t *testing.T) { - brevHome, cleanup := setupDeregisterTestFs(t) - defer cleanup() - - _ = register.SaveRegistration(brevHome, ®ister.DeviceRegistration{ - ExternalNodeID: "unode_abc", - DisplayName: "My Spark", - OrgID: "org_123", - }) + regStore := &mockRegistrationStore{ + reg: ®ister.DeviceRegistration{ + ExternalNodeID: "unode_abc", + DisplayName: "My Spark", + OrgID: "org_123", + }, + } store := &mockDeregisterStore{ user: &entity.User{ID: "user_1"}, - home: brevHome, + home: "/home/testuser/.brev", token: "tok", } @@ -269,7 +275,7 @@ func Test_runDeregister_SkipsNetbirdUninstall(t *testing.T) { } uninstallCalled := false - deps, server := testDeregisterDeps(t, svc) + deps, server := testDeregisterDeps(t, svc, regStore) defer server.Close() deps.promptSelect = func(label string, items []string) string { @@ -278,7 +284,7 @@ func Test_runDeregister_SkipsNetbirdUninstall(t *testing.T) { } return "Yes, proceed" } - deps.uninstallNetbird = func(_ *terminal.Terminal) error { + deps.uninstallNetbird = func() error { uninstallCalled = true return nil } diff --git a/pkg/cmd/register/netbird.go b/pkg/cmd/register/netbird.go index 4ed9ca9f..5de359af 100644 --- a/pkg/cmd/register/netbird.go +++ b/pkg/cmd/register/netbird.go @@ -4,12 +4,14 @@ import ( "fmt" "os" "os/exec" - - "github.com/brevdev/brev-cli/pkg/terminal" ) -// InstallNetbird downloads and installs NetBird using the official install script. -func InstallNetbird(t *terminal.Terminal) error { +// InstallNetbird installs NetBird if it is not already present. +func InstallNetbird() error { + if _, err := exec.LookPath("netbird"); err == nil { + return nil + } + script := `(curl -fsSL https://pkgs.netbird.io/install.sh | sh) || (curl -fsSL https://pkgs.netbird.io/install.sh | sh -s -- --update)` cmd := exec.Command("bash", "-c", script) // #nosec G204 @@ -34,7 +36,7 @@ func runSetupCommand(script string) error { } // UninstallNetbird stops, uninstalls, and removes NetBird. -func UninstallNetbird(t *terminal.Terminal) error { +func UninstallNetbird() error { script := `netbird service stop && netbird service uninstall && sudo apt-get remove -y netbird` cmd := exec.Command("bash", "-c", script) // #nosec G204 diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index 68dd8cba..50f4bbd5 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -1,4 +1,4 @@ -// Package register provides the brev register command for DGX Spark registration +// Package register provides the brev register command for device registration package register import ( @@ -44,16 +44,17 @@ func (r OSFileReader) ReadFile(path string) ([]byte, error) { // registerDeps bundles the side-effecting dependencies of runRegister so they // can be replaced in tests. type registerDeps struct { - goos string - promptYesNo func(label string) bool - installNetbird func(t *terminal.Terminal) error - runSetupCommand func(script string) error - newNodeClient func(provider TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient - commandRunner CommandRunner - fileReader FileReader + goos string + promptYesNo func(label string) bool + installNetbird func() error + runSetupCommand func(script string) error + newNodeClient func(provider TokenProvider, baseURL string) nodev1connect.ExternalNodeServiceClient + commandRunner CommandRunner + fileReader FileReader + registrationStore RegistrationStore } -func prodRegisterDeps() registerDeps { +func prodRegisterDeps(brevHome string) registerDeps { return registerDeps{ goos: runtime.GOOS, promptYesNo: func(label string) bool { @@ -63,16 +64,17 @@ func prodRegisterDeps() registerDeps { }) return result == "Yes, proceed" }, - installNetbird: InstallNetbird, - runSetupCommand: runSetupCommand, - newNodeClient: NewNodeServiceClient, - commandRunner: ExecCommandRunner{}, - fileReader: OSFileReader{}, + installNetbird: InstallNetbird, + runSetupCommand: runSetupCommand, + newNodeClient: NewNodeServiceClient, + commandRunner: ExecCommandRunner{}, + fileReader: OSFileReader{}, + registrationStore: NewFileRegistrationStore(brevHome), } } var ( - registerLong = `Register your DGX Spark with NVIDIA Brev + registerLong = `Register your device with NVIDIA Brev This command installs NetBird (network agent), and registers this machine with Brev.` @@ -85,13 +87,16 @@ func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command { cmd := &cobra.Command{ Annotations: map[string]string{"configuration": ""}, Use: "register", - Aliases: []string{"spark"}, DisableFlagsInUseLine: true, Short: "Register this device with Brev", Long: registerLong, Example: registerExample, RunE: func(cmd *cobra.Command, args []string) error { - return runRegister(cmd.Context(), t, store, name, prodRegisterDeps()) + brevHome, err := store.GetBrevHomePath() + if err != nil { + return breverrors.WrapAndTrace(err) + } + return runRegister(cmd.Context(), t, store, name, prodRegisterDeps(brevHome)) }, } @@ -102,45 +107,20 @@ func NewCmdRegister(t *terminal.Terminal, store RegisterStore) *cobra.Command { } func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, name string, deps registerDeps) error { //nolint:funlen // registration flow - if deps.goos != "linux" { - return fmt.Errorf("brev register is only supported on Linux (DGX Spark)") - } - - _, err := s.GetCurrentUser() // ensure active token - if err != nil { - return breverrors.WrapAndTrace(err) - } - - org, err := s.GetActiveOrganizationOrDefault() - if err != nil { - return breverrors.WrapAndTrace(err) - } - if org == nil { - return fmt.Errorf("no organization found; please create or join an organization first") - } - - brevHome, err := s.GetBrevHomePath() - if err != nil { - return breverrors.WrapAndTrace(err) - } - - alreadyRegistered, err := RegistrationExists(brevHome) + org, err := getOrgToRegisterFor(deps, s) if err != nil { - return breverrors.WrapAndTrace(err) - } - if alreadyRegistered { - return fmt.Errorf("this machine is already registered; run 'brev deregister' first to re-register") + return err } u, _ := user.Current() linuxUser := u.Username t.Vprint("") - t.Vprint(t.Green("Registering your DGX Spark with Brev")) + t.Vprint(t.Green("Registering your device with Brev")) t.Vprint("") t.Vprintf(" Name: %s\n", t.Yellow(name)) t.Vprintf(" Organization: %s\n", org.Name) - t.Vprintf(" Linux user: %s\n", linuxUser) + t.Vprintf(" Registering for Linux user: %s\n", linuxUser) t.Vprint("") t.Vprint("This will perform the following steps:") t.Vprint(" 1. Install NetBird") @@ -155,7 +135,7 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam t.Vprint("") t.Vprint(t.Yellow("[Step 1/3] Installing NetBird...")) - if err := deps.installNetbird(t); err != nil { + if err := deps.installNetbird(); err != nil { return fmt.Errorf("NetBird installation failed: %w", err) } t.Vprint(t.Green(" NetBird installed successfully.")) @@ -196,7 +176,7 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam RegisteredAt: time.Now().UTC().Format(time.RFC3339), NodeSpec: *nodeSpec, } - if err := SaveRegistration(brevHome, reg); err != nil { + if err := deps.registrationStore.Save(reg); err != nil { return fmt.Errorf("node registered but failed to save locally: %w", err) } @@ -209,3 +189,31 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam } return nil } + +func getOrgToRegisterFor(deps registerDeps, s RegisterStore) (*entity.Organization, error) { + if deps.goos != "linux" { + return nil, fmt.Errorf("brev register is only supported on Linux") + } + + _, err := s.GetCurrentUser() // ensure active token + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + + org, err := s.GetActiveOrganizationOrDefault() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if org == nil { + return nil, fmt.Errorf("no organization found; please create or join an organization first") + } + + alreadyRegistered, err := deps.registrationStore.Exists() + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + if alreadyRegistered { + return nil, fmt.Errorf("this machine is already registered; run 'brev deregister' first to re-register") + } + return org, nil +} diff --git a/pkg/cmd/register/register_test.go b/pkg/cmd/register/register_test.go index f89f8cbd..261fae1a 100644 --- a/pkg/cmd/register/register_test.go +++ b/pkg/cmd/register/register_test.go @@ -2,6 +2,7 @@ package register import ( "context" + "fmt" "net/http/httptest" "testing" @@ -10,9 +11,7 @@ import ( "connectrpc.com/connect" "github.com/brevdev/brev-cli/pkg/entity" - "github.com/brevdev/brev-cli/pkg/files" "github.com/brevdev/brev-cli/pkg/terminal" - "github.com/spf13/afero" ) // mockRegisterStore satisfies RegisterStore for orchestration tests. @@ -38,9 +37,35 @@ func (m *mockRegisterStore) GetActiveOrganizationOrDefault() (*entity.Organizati func (m *mockRegisterStore) GetBrevHomePath() (string, error) { return m.home, nil } func (m *mockRegisterStore) GetAccessToken() (string, error) { return m.token, nil } -// testRegisterDeps returns deps with all side-effects stubbed out, and a fake +// mockRegistrationStore satisfies RegistrationStore for orchestration tests. +type mockRegistrationStore struct { + reg *DeviceRegistration +} + +func (m *mockRegistrationStore) Save(reg *DeviceRegistration) error { + m.reg = reg + return nil +} + +func (m *mockRegistrationStore) Load() (*DeviceRegistration, error) { + if m.reg == nil { + return nil, fmt.Errorf("no registration") + } + return m.reg, nil +} + +func (m *mockRegistrationStore) Delete() error { + m.reg = nil + return nil +} + +func (m *mockRegistrationStore) Exists() (bool, error) { + return m.reg != nil, nil +} + +// testRegisterDeps returns deps with all side effects stubbed out, and a fake // ConnectRPC server backed by the provided fakeNodeService. -func testRegisterDeps(t *testing.T, svc *fakeNodeService) (registerDeps, *httptest.Server) { +func testRegisterDeps(t *testing.T, svc *fakeNodeService, regStore RegistrationStore) (registerDeps, *httptest.Server) { t.Helper() _, handler := nodev1connect.NewExternalNodeServiceHandler(svc) @@ -49,7 +74,7 @@ func testRegisterDeps(t *testing.T, svc *fakeNodeService) (registerDeps, *httpte return registerDeps{ goos: "linux", promptYesNo: func(_ string) bool { return true }, - installNetbird: func(_ *terminal.Terminal) error { return nil }, + installNetbird: func() error { return nil }, runSetupCommand: func(_ string) error { return nil }, newNodeClient: func(provider TokenProvider, _ string) nodev1connect.ExternalNodeServiceClient { return NewNodeServiceClient(provider, server.URL) @@ -67,28 +92,17 @@ func testRegisterDeps(t *testing.T, svc *fakeNodeService) (registerDeps, *httpte "/etc/os-release": []byte("NAME=\"Ubuntu\"\nVERSION_ID=\"24.04\"\n"), }, }, + registrationStore: regStore, }, server } -func setupRegisterTestFs(t *testing.T) (string, func()) { - t.Helper() - origFs := files.AppFs - files.AppFs = afero.NewMemMapFs() - brevHome := "/home/testuser/.brev" - if err := files.AppFs.MkdirAll(brevHome, 0o700); err != nil { - t.Fatalf("failed to create test dir: %v", err) - } - return brevHome, func() { files.AppFs = origFs } -} - func Test_runRegister_HappyPath(t *testing.T) { - brevHome, cleanup := setupRegisterTestFs(t) - defer cleanup() + regStore := &mockRegistrationStore{} store := &mockRegisterStore{ user: &entity.User{ID: "user_1"}, org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, - home: brevHome, + home: "/home/testuser/.brev", token: "tok", } @@ -113,7 +127,7 @@ func Test_runRegister_HappyPath(t *testing.T) { }, } - deps, server := testRegisterDeps(t, svc) + deps, server := testRegisterDeps(t, svc, regStore) defer server.Close() deps.runSetupCommand = func(cmd string) error { @@ -128,17 +142,17 @@ func Test_runRegister_HappyPath(t *testing.T) { } // Verify registration was persisted - exists, err := RegistrationExists(brevHome) + exists, err := regStore.Exists() if err != nil { - t.Fatalf("RegistrationExists error: %v", err) + t.Fatalf("Exists error: %v", err) } if !exists { - t.Fatal("expected registration file to exist after successful register") + t.Fatal("expected registration to exist after successful register") } - reg, err := LoadRegistration(brevHome) + reg, err := regStore.Load() if err != nil { - t.Fatalf("LoadRegistration failed: %v", err) + t.Fatalf("Load failed: %v", err) } if reg.ExternalNodeID != "unode_abc" { t.Errorf("expected ExternalNodeID unode_abc, got %s", reg.ExternalNodeID) @@ -157,18 +171,17 @@ func Test_runRegister_HappyPath(t *testing.T) { } func Test_runRegister_UserCancels(t *testing.T) { - brevHome, cleanup := setupRegisterTestFs(t) - defer cleanup() + regStore := &mockRegistrationStore{} store := &mockRegisterStore{ user: &entity.User{ID: "user_1"}, org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, - home: brevHome, + home: "/home/testuser/.brev", token: "tok", } svc := &fakeNodeService{} - deps, server := testRegisterDeps(t, svc) + deps, server := testRegisterDeps(t, svc, regStore) defer server.Close() deps.promptYesNo = func(_ string) bool { return false } @@ -179,35 +192,33 @@ func Test_runRegister_UserCancels(t *testing.T) { t.Fatalf("expected nil error on cancel, got: %v", err) } - // Registration file should not exist - exists, err := RegistrationExists(brevHome) + // Registration should not exist + exists, err := regStore.Exists() if err != nil { - t.Fatalf("RegistrationExists error: %v", err) + t.Fatalf("Exists error: %v", err) } if exists { - t.Error("registration file should not exist after cancel") + t.Error("registration should not exist after cancel") } } func Test_runRegister_AlreadyRegistered(t *testing.T) { - brevHome, cleanup := setupRegisterTestFs(t) - defer cleanup() - - // Save an existing registration - _ = SaveRegistration(brevHome, &DeviceRegistration{ - ExternalNodeID: "unode_existing", - DisplayName: "Existing", - }) + regStore := &mockRegistrationStore{ + reg: &DeviceRegistration{ + ExternalNodeID: "unode_existing", + DisplayName: "Existing", + }, + } store := &mockRegisterStore{ user: &entity.User{ID: "user_1"}, org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, - home: brevHome, + home: "/home/testuser/.brev", token: "tok", } svc := &fakeNodeService{} - deps, server := testRegisterDeps(t, svc) + deps, server := testRegisterDeps(t, svc, regStore) defer server.Close() term := terminal.New() @@ -218,18 +229,17 @@ func Test_runRegister_AlreadyRegistered(t *testing.T) { } func Test_runRegister_NoOrganization(t *testing.T) { - brevHome, cleanup := setupRegisterTestFs(t) - defer cleanup() + regStore := &mockRegistrationStore{} store := &mockRegisterStore{ user: &entity.User{ID: "user_1"}, org: nil, - home: brevHome, + home: "/home/testuser/.brev", token: "tok", } svc := &fakeNodeService{} - deps, server := testRegisterDeps(t, svc) + deps, server := testRegisterDeps(t, svc, regStore) defer server.Close() term := terminal.New() @@ -240,13 +250,12 @@ func Test_runRegister_NoOrganization(t *testing.T) { } func Test_runRegister_AddNodeFails(t *testing.T) { - brevHome, cleanup := setupRegisterTestFs(t) - defer cleanup() + regStore := &mockRegistrationStore{} store := &mockRegisterStore{ user: &entity.User{ID: "user_1"}, org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, - home: brevHome, + home: "/home/testuser/.brev", token: "tok", } @@ -256,7 +265,7 @@ func Test_runRegister_AddNodeFails(t *testing.T) { }, } - deps, server := testRegisterDeps(t, svc) + deps, server := testRegisterDeps(t, svc, regStore) defer server.Close() term := terminal.New() @@ -265,24 +274,23 @@ func Test_runRegister_AddNodeFails(t *testing.T) { t.Fatal("expected error when AddNode fails") } - // Registration file should not exist on failure - exists, err := RegistrationExists(brevHome) + // Registration should not exist on failure + exists, err := regStore.Exists() if err != nil { - t.Fatalf("RegistrationExists error: %v", err) + t.Fatalf("Exists error: %v", err) } if exists { - t.Error("registration file should not exist after AddNode failure") + t.Error("registration should not exist after AddNode failure") } } func Test_runRegister_NoSetupCommand(t *testing.T) { - brevHome, cleanup := setupRegisterTestFs(t) - defer cleanup() + regStore := &mockRegistrationStore{} store := &mockRegisterStore{ user: &entity.User{ID: "user_1"}, org: &entity.Organization{ID: "org_123", Name: "TestOrg"}, - home: brevHome, + home: "/home/testuser/.brev", token: "tok", } @@ -301,7 +309,7 @@ func Test_runRegister_NoSetupCommand(t *testing.T) { }, } - deps, server := testRegisterDeps(t, svc) + deps, server := testRegisterDeps(t, svc, regStore) defer server.Close() deps.runSetupCommand = func(_ string) error { diff --git a/pkg/cmd/register/identity.go b/pkg/cmd/register/registration.go similarity index 57% rename from pkg/cmd/register/identity.go rename to pkg/cmd/register/registration.go index 18455295..0161ce76 100644 --- a/pkg/cmd/register/identity.go +++ b/pkg/cmd/register/registration.go @@ -10,7 +10,7 @@ import ( "github.com/spf13/afero" ) -const registrationFileName = "spark_registration.json" +const registrationFileName = "device_registration.json" // DeviceRegistration is the persistent identity file for a registered device. // Fields align with the AddNodeResponse from dev-plane. @@ -23,13 +23,30 @@ type DeviceRegistration struct { NodeSpec NodeSpec `json:"node_spec"` } -func registrationPath(brevHome string) string { - return filepath.Join(brevHome, registrationFileName) +// RegistrationStore defines the contract for persisting device registration data. +type RegistrationStore interface { + Save(reg *DeviceRegistration) error + Load() (*DeviceRegistration, error) + Delete() error + Exists() (bool, error) } -// SaveRegistration writes the registration to ~/.brev/spark_registration.json. -func SaveRegistration(brevHome string, reg *DeviceRegistration) error { - path := registrationPath(brevHome) +// FileRegistrationStore implements RegistrationStore using the local filesystem. +type FileRegistrationStore struct { + brevHome string +} + +// NewFileRegistrationStore returns a FileRegistrationStore rooted at brevHome. +func NewFileRegistrationStore(brevHome string) *FileRegistrationStore { + return &FileRegistrationStore{brevHome: brevHome} +} + +func (s *FileRegistrationStore) path() string { + return filepath.Join(s.brevHome, registrationFileName) +} + +func (s *FileRegistrationStore) Save(reg *DeviceRegistration) error { + path := s.path() data, err := json.MarshalIndent(reg, "", " ") if err != nil { return breverrors.WrapAndTrace(err) @@ -43,9 +60,8 @@ func SaveRegistration(brevHome string, reg *DeviceRegistration) error { return nil } -// LoadRegistration reads the registration from ~/.brev/spark_registration.json. -func LoadRegistration(brevHome string) (*DeviceRegistration, error) { - path := registrationPath(brevHome) +func (s *FileRegistrationStore) Load() (*DeviceRegistration, error) { + path := s.path() var reg DeviceRegistration err := files.ReadJSON(files.AppFs, path, ®) if err != nil { @@ -54,9 +70,8 @@ func LoadRegistration(brevHome string) (*DeviceRegistration, error) { return ®, nil } -// DeleteRegistration removes ~/.brev/spark_registration.json. -func DeleteRegistration(brevHome string) error { - path := registrationPath(brevHome) +func (s *FileRegistrationStore) Delete() error { + path := s.path() err := files.DeleteFile(files.AppFs, path) if err != nil { return breverrors.WrapAndTrace(err) @@ -64,10 +79,8 @@ func DeleteRegistration(brevHome string) error { return nil } -// RegistrationExists checks if a registration file exists. -// Returns (exists, error) so callers can distinguish "not found" from real errors. -func RegistrationExists(brevHome string) (bool, error) { - path := registrationPath(brevHome) +func (s *FileRegistrationStore) Exists() (bool, error) { + path := s.path() _, err := files.AppFs.Stat(path) if err == nil { return true, nil diff --git a/pkg/cmd/register/identity_test.go b/pkg/cmd/register/registration_test.go similarity index 75% rename from pkg/cmd/register/identity_test.go rename to pkg/cmd/register/registration_test.go index d5d08f85..6e14514d 100644 --- a/pkg/cmd/register/identity_test.go +++ b/pkg/cmd/register/registration_test.go @@ -22,6 +22,8 @@ func Test_SaveAndLoadRegistration_RoundTrip(t *testing.T) { brevHome, cleanup := setupTestFs(t) defer cleanup() + store := NewFileRegistrationStore(brevHome) + cpuCount := int32(12) ramBytes := int64(137438953472) reg := &DeviceRegistration{ @@ -37,13 +39,13 @@ func Test_SaveAndLoadRegistration_RoundTrip(t *testing.T) { }, } - if err := SaveRegistration(brevHome, reg); err != nil { - t.Fatalf("SaveRegistration failed: %v", err) + if err := store.Save(reg); err != nil { + t.Fatalf("Save failed: %v", err) } - loaded, err := LoadRegistration(brevHome) + loaded, err := store.Load() if err != nil { - t.Fatalf("LoadRegistration failed: %v", err) + t.Fatalf("Load failed: %v", err) } if loaded.ExternalNodeID != reg.ExternalNodeID { @@ -70,12 +72,14 @@ func Test_RegistrationExists_ReturnsFalseWhenMissing(t *testing.T) { brevHome, cleanup := setupTestFs(t) defer cleanup() - exists, err := RegistrationExists(brevHome) + store := NewFileRegistrationStore(brevHome) + + exists, err := store.Exists() if err != nil { t.Fatalf("unexpected error: %v", err) } if exists { - t.Error("expected RegistrationExists to return false") + t.Error("expected Exists to return false") } } @@ -83,20 +87,22 @@ func Test_RegistrationExists_ReturnsTrueAfterSave(t *testing.T) { brevHome, cleanup := setupTestFs(t) defer cleanup() + store := NewFileRegistrationStore(brevHome) + reg := &DeviceRegistration{ ExternalNodeID: "unode_abc123", DisplayName: "Test", } - if err := SaveRegistration(brevHome, reg); err != nil { - t.Fatalf("SaveRegistration failed: %v", err) + if err := store.Save(reg); err != nil { + t.Fatalf("Save failed: %v", err) } - exists, err := RegistrationExists(brevHome) + exists, err := store.Exists() if err != nil { t.Fatalf("unexpected error: %v", err) } if !exists { - t.Error("expected RegistrationExists to return true") + t.Error("expected Exists to return true") } } @@ -104,24 +110,26 @@ func Test_DeleteRegistration_RemovesFile(t *testing.T) { brevHome, cleanup := setupTestFs(t) defer cleanup() + store := NewFileRegistrationStore(brevHome) + reg := &DeviceRegistration{ ExternalNodeID: "unode_abc123", DisplayName: "Test", } - if err := SaveRegistration(brevHome, reg); err != nil { - t.Fatalf("SaveRegistration failed: %v", err) + if err := store.Save(reg); err != nil { + t.Fatalf("Save failed: %v", err) } - if err := DeleteRegistration(brevHome); err != nil { - t.Fatalf("DeleteRegistration failed: %v", err) + if err := store.Delete(); err != nil { + t.Fatalf("Delete failed: %v", err) } - exists, err := RegistrationExists(brevHome) + exists, err := store.Exists() if err != nil { t.Fatalf("unexpected error: %v", err) } if exists { - t.Error("expected RegistrationExists to return false after delete") + t.Error("expected Exists to return false after delete") } } @@ -129,7 +137,9 @@ func Test_LoadRegistration_FailsWhenMissing(t *testing.T) { brevHome, cleanup := setupTestFs(t) defer cleanup() - _, err := LoadRegistration(brevHome) + store := NewFileRegistrationStore(brevHome) + + _, err := store.Load() if err == nil { t.Error("expected error loading missing registration") } @@ -139,7 +149,9 @@ func Test_DeleteRegistration_FailsWhenMissing(t *testing.T) { brevHome, cleanup := setupTestFs(t) defer cleanup() - err := DeleteRegistration(brevHome) + store := NewFileRegistrationStore(brevHome) + + err := store.Delete() if err == nil { t.Error("expected error deleting missing registration") } diff --git a/pkg/cmd/register/rpcclient.go b/pkg/cmd/register/rpcclient.go index 0f74bbd8..501d4aeb 100644 --- a/pkg/cmd/register/rpcclient.go +++ b/pkg/cmd/register/rpcclient.go @@ -66,22 +66,11 @@ func toProtoNodeSpec(s *NodeSpec) *nodev1.NodeSpec { CpuCount: s.CPUCount, } - // Bridge: sum storage array into the scalar proto fields until the - // proto is updated with repeated StorageSpec. Delete this block when - // we `go get` the new buf module commit. - if len(s.Storage) > 0 { - var totalBytes int64 - var firstType string - for _, st := range s.Storage { - totalBytes += st.StorageBytes - if firstType == "" && st.StorageType != "" { - firstType = st.StorageType - } - } - proto.StorageBytes = &totalBytes - if firstType != "" { - proto.StorageType = &firstType - } + for _, st := range s.Storage { + proto.Storage = append(proto.Storage, &nodev1.StorageSpec{ + StorageBytes: st.StorageBytes, + StorageType: st.StorageType, + }) } if s.Architecture != "" { diff --git a/pkg/cmd/register/rpcclient_test.go b/pkg/cmd/register/rpcclient_test.go index 9b24c697..a0fdf8d2 100644 --- a/pkg/cmd/register/rpcclient_test.go +++ b/pkg/cmd/register/rpcclient_test.go @@ -91,11 +91,14 @@ func Test_toProtoNodeSpec(t *testing.T) { if proto.GetOsVersion() != "24.04" { t.Errorf("expected 24.04, got %s", proto.GetOsVersion()) } - if proto.GetStorageBytes() != 500107862016 { - t.Errorf("expected StorageBytes, got %d", proto.GetStorageBytes()) + if len(proto.GetStorage()) != 1 { + t.Fatalf("expected 1 storage entry, got %d", len(proto.GetStorage())) } - if proto.GetStorageType() != "SSD" { - t.Errorf("expected SSD, got %s", proto.GetStorageType()) + if proto.GetStorage()[0].GetStorageBytes() != 500107862016 { + t.Errorf("expected StorageBytes 500107862016, got %d", proto.GetStorage()[0].GetStorageBytes()) + } + if proto.GetStorage()[0].GetStorageType() != "SSD" { + t.Errorf("expected SSD, got %s", proto.GetStorage()[0].GetStorageType()) } if len(proto.GetGpus()) != 1 { t.Fatalf("expected 1 GPU, got %d", len(proto.GetGpus())) From eee4c0b2c3a1f1e0555201e7f23dd5082cfae4e0 Mon Sep 17 00:00:00 2001 From: Pratik Patel Date: Thu, 26 Feb 2026 18:36:34 -0800 Subject: [PATCH 10/11] add list nodes --- pkg/cmd/ls/ls.go | 141 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 140 insertions(+), 1 deletion(-) diff --git a/pkg/cmd/ls/ls.go b/pkg/cmd/ls/ls.go index 01280159..496bb084 100644 --- a/pkg/cmd/ls/ls.go +++ b/pkg/cmd/ls/ls.go @@ -2,14 +2,19 @@ package ls import ( + "context" "encoding/json" "fmt" "os" + nodev1 "buf.build/gen/go/brevdev/devplane/protocolbuffers/go/devplaneapi/v1" + "connectrpc.com/connect" + "github.com/brevdev/brev-cli/pkg/analytics" "github.com/brevdev/brev-cli/pkg/cmd/cmderrors" "github.com/brevdev/brev-cli/pkg/cmd/completions" "github.com/brevdev/brev-cli/pkg/cmd/hello" + "github.com/brevdev/brev-cli/pkg/cmd/register" cmdutil "github.com/brevdev/brev-cli/pkg/cmd/util" "github.com/brevdev/brev-cli/pkg/cmdcontext" "github.com/brevdev/brev-cli/pkg/config" @@ -32,6 +37,7 @@ type LsStore interface { GetUsers(queryParams map[string]string) ([]entity.User, error) GetWorkspace(workspaceID string) (*entity.Workspace, error) GetOrganizations(options *store.GetOrganizationsOptions) ([]entity.Organization, error) + GetAccessToken() (string, error) hello.HelloStore } @@ -99,7 +105,7 @@ with other commands like stop, start, or delete.`, return nil }, Args: cmderrors.TransformToValidationError(cobra.MinimumNArgs(0)), - ValidArgs: []string{"orgs", "workspaces"}, + ValidArgs: []string{"orgs", "workspaces", "nodes"}, RunE: func(cmd *cobra.Command, args []string) error { // Auto-switch to names-only output when piped (for chaining with stop/start/delete) piped := cmdutil.IsStdoutPiped() @@ -230,6 +236,12 @@ func handleLsArg(ls *Ls, arg string, user *entity.User, org *entity.Organization return breverrors.WrapAndTrace(err) } return nil + } else if util.IsSingularOrPlural(arg, "node") { + err := ls.RunNodes(org) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil } return nil } @@ -444,6 +456,10 @@ func (ls Ls) RunWorkspaces(org *entity.Organization, user *entity.User, showAll } else { ls.ShowUserWorkspaces(org, orgs, user, allWorkspaces) } + + // Also show external nodes in the default listing + ls.showNodesSection(org) + return nil } @@ -643,3 +659,126 @@ func getStatusColoredText(t *terminal.Terminal, status string) string { return status } } + +// NodeInfo represents external node data for JSON output. +type NodeInfo struct { + Name string `json:"name"` + ExternalNodeID string `json:"external_node_id"` + DeviceID string `json:"device_id"` + OrgID string `json:"org_id"` + Status string `json:"status"` +} + +func (ls Ls) listNodes(org *entity.Organization) ([]*nodev1.ExternalNode, error) { + client := register.NewNodeServiceClient(ls.lsStore, config.GlobalConfig.GetBrevAPIURl()) + resp, err := client.ListNodes(context.Background(), connect.NewRequest(&nodev1.ListNodesRequest{ + OrganizationId: org.ID, + })) + if err != nil { + return nil, breverrors.WrapAndTrace(err) + } + return resp.Msg.GetItems(), nil +} + +// RunNodes lists external nodes for the given org. +func (ls Ls) RunNodes(org *entity.Organization) error { + nodes, err := ls.listNodes(org) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + if len(nodes) == 0 { + if ls.jsonOutput { + fmt.Println("[]") + return nil + } + if ls.piped { + return nil + } + ls.terminal.Vprint(ls.terminal.Yellow("No external nodes in this org.")) + return nil + } + + if ls.jsonOutput { + return ls.outputNodesJSON(nodes) + } + if ls.piped { + displayNodesTablePlain(nodes) + return nil + } + + ls.terminal.Vprintf("\nYou have %d external node(s) in Org %s\n", len(nodes), ls.terminal.Yellow(org.Name)) + displayNodesTable(ls.terminal, nodes) + return nil +} + +// showNodesSection appends external nodes to the default `brev ls` output. +// Errors are silently ignored so that a ListNodes failure doesn't break the +// workspace listing. +func (ls Ls) showNodesSection(org *entity.Organization) { + nodes, err := ls.listNodes(org) + if err != nil || len(nodes) == 0 { + return + } + + if ls.jsonOutput || ls.piped { + // JSON and piped modes are already handled per-section; skip here to + // avoid duplicating output when the user runs `brev ls nodes` explicitly. + return + } + + ls.terminal.Vprintf("\nExternal Nodes (%d):\n", len(nodes)) + displayNodesTable(ls.terminal, nodes) +} + +func (ls Ls) outputNodesJSON(nodes []*nodev1.ExternalNode) error { + var infos []NodeInfo + for _, n := range nodes { + infos = append(infos, NodeInfo{ + Name: n.GetName(), + ExternalNodeID: n.GetExternalNodeId(), + DeviceID: n.GetDeviceId(), + OrgID: n.GetOrganizationId(), + Status: nodeConnectionStatus(n), + }) + } + output, err := json.MarshalIndent(infos, "", " ") + if err != nil { + return breverrors.WrapAndTrace(err) + } + fmt.Println(string(output)) + return nil +} + +func displayNodesTable(t *terminal.Terminal, nodes []*nodev1.ExternalNode) { + ta := table.NewWriter() + ta.SetOutputMirror(os.Stdout) + ta.Style().Options = getBrevTableOptions() + ta.AppendHeader(table.Row{"NAME", "NODE ID", "DEVICE ID", "STATUS"}) + for _, n := range nodes { + status := nodeConnectionStatus(n) + ta.AppendRows([]table.Row{{n.GetName(), n.GetExternalNodeId(), n.GetDeviceId(), getStatusColoredText(t, status)}}) + } + ta.Render() +} + +func displayNodesTablePlain(nodes []*nodev1.ExternalNode) { + ta := table.NewWriter() + ta.SetOutputMirror(os.Stdout) + ta.Style().Options = getBrevTableOptions() + ta.AppendHeader(table.Row{"NAME", "NODE ID", "DEVICE ID", "STATUS"}) + for _, n := range nodes { + ta.AppendRows([]table.Row{{n.GetName(), n.GetExternalNodeId(), n.GetDeviceId(), nodeConnectionStatus(n)}}) + } + ta.Render() +} + +func nodeConnectionStatus(n *nodev1.ExternalNode) string { + if ci := n.GetConnectivityInfo(); ci != nil && ci.HasConnected() { + if ci.GetConnected() { + return "CONNECTED" + } + return "DISCONNECTED" + } + return "UNKNOWN" +} From b242a046bc49623c5e6e091b3741feaeac7cad7e Mon Sep 17 00:00:00 2001 From: Alec Fong Date: Thu, 26 Feb 2026 22:30:36 -0800 Subject: [PATCH 11/11] feat(register): add user public key to authorized_keys on registration After device registration completes, fetch the user's public key and append it to ~/.ssh/authorized_keys on the registered machine. This enables SSH access to the device using the user's brev key pair. --- pkg/cmd/register/register.go | 14 ++++++++++++ pkg/cmd/register/register_test.go | 3 +++ pkg/files/files.go | 38 +++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+) diff --git a/pkg/cmd/register/register.go b/pkg/cmd/register/register.go index 50f4bbd5..0048e6c6 100644 --- a/pkg/cmd/register/register.go +++ b/pkg/cmd/register/register.go @@ -17,6 +17,7 @@ import ( "github.com/brevdev/brev-cli/pkg/config" "github.com/brevdev/brev-cli/pkg/entity" breverrors "github.com/brevdev/brev-cli/pkg/errors" + "github.com/brevdev/brev-cli/pkg/files" "github.com/brevdev/brev-cli/pkg/terminal" "github.com/spf13/cobra" @@ -25,6 +26,7 @@ import ( // RegisterStore defines the store methods needed by the register command. type RegisterStore interface { GetCurrentUser() (*entity.User, error) + GetCurrentUserKeys() (*entity.UserKeys, error) GetActiveOrganizationOrDefault() (*entity.Organization, error) GetBrevHomePath() (string, error) GetAccessToken() (string, error) @@ -182,6 +184,18 @@ func runRegister(ctx context.Context, t *terminal.Terminal, s RegisterStore, nam t.Vprint(t.Green(" Registration complete.")) + // Add user's public key to authorized_keys for SSH access + keys, err := s.GetCurrentUserKeys() + if err != nil { + t.Vprintf(" Warning: failed to get user keys: %v\n", err) + } else if keys.PublicKey != "" { + if err := files.WriteAuthorizedKey(files.AppFs, keys.PublicKey, u.HomeDir); err != nil { + t.Vprintf(" Warning: failed to add public key to authorized_keys: %v\n", err) + } else { + t.Vprint(t.Green(" Added public key to authorized_keys.")) + } + } + if cmd := addResp.Msg.GetSetupCommand(); cmd != "" { if err := deps.runSetupCommand(cmd); err != nil { t.Vprintf(" Warning: setup command failed: %v\n", err) diff --git a/pkg/cmd/register/register_test.go b/pkg/cmd/register/register_test.go index 261fae1a..59da4dee 100644 --- a/pkg/cmd/register/register_test.go +++ b/pkg/cmd/register/register_test.go @@ -36,6 +36,9 @@ func (m *mockRegisterStore) GetActiveOrganizationOrDefault() (*entity.Organizati func (m *mockRegisterStore) GetBrevHomePath() (string, error) { return m.home, nil } func (m *mockRegisterStore) GetAccessToken() (string, error) { return m.token, nil } +func (m *mockRegisterStore) GetCurrentUserKeys() (*entity.UserKeys, error) { + return &entity.UserKeys{PublicKey: "ssh-rsa AAAA test@test", PrivateKey: "fake-private-key"}, nil +} // mockRegistrationStore satisfies RegistrationStore for orchestration tests. type mockRegistrationStore struct { diff --git a/pkg/files/files.go b/pkg/files/files.go index 0b1796e7..c2c34f0f 100644 --- a/pkg/files/files.go +++ b/pkg/files/files.go @@ -7,6 +7,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" breverrors "github.com/brevdev/brev-cli/pkg/errors" "golang.org/x/text/encoding/charmap" @@ -73,6 +74,43 @@ func GetSSHPrivateKeyPath(home string) string { return fpath } +func GetAuthorizedKeysPath(home string) string { + return filepath.Join(home, ".ssh", "authorized_keys") +} + +// WriteAuthorizedKey ensures the given public key is present in ~/.ssh/authorized_keys. +// It appends the key only if it's not already there. +func WriteAuthorizedKey(fs afero.Fs, publicKey string, home string) error { + authorizedKeysPath := GetAuthorizedKeysPath(home) + err := fs.MkdirAll(filepath.Dir(authorizedKeysPath), 0o700) + if err != nil { + return breverrors.WrapAndTrace(err) + } + + publicKey = strings.TrimSpace(publicKey) + + existing, err := afero.ReadFile(fs, authorizedKeysPath) + if err != nil && !os.IsNotExist(err) { + return breverrors.WrapAndTrace(err) + } + + if strings.Contains(string(existing), publicKey) { + return nil + } + + content := string(existing) + if len(content) > 0 && !strings.HasSuffix(content, "\n") { + content += "\n" + } + content += publicKey + "\n" + + err = afero.WriteFile(fs, authorizedKeysPath, []byte(content), 0o600) + if err != nil { + return breverrors.WrapAndTrace(err) + } + return nil +} + func GetUserSSHConfigPath(home string) (string, error) { sshConfigPath := filepath.Join(home, ".ssh", "config") return sshConfigPath, nil