diff --git a/commands/dedicated_inference.go b/commands/dedicated_inference.go index d02b7d474..bbdab1c57 100644 --- a/commands/dedicated_inference.go +++ b/commands/dedicated_inference.go @@ -82,6 +82,22 @@ For more information, see https://docs.digitalocean.com/reference/api/digitaloce AddBoolFlag(cmdDelete, doctl.ArgForce, doctl.ArgShortForce, false, "Delete the dedicated inference endpoint without a confirmation prompt") cmdDelete.Example = `The following example deletes a dedicated inference endpoint: doctl dedicated-inference delete 12345678-1234-1234-1234-123456789012` + cmdUpdate := CmdBuilder( + cmd, + RunDedicatedInferenceUpdate, + "update <id>", + "Update a dedicated inference endpoint", + `Updates a dedicated inference endpoint using a spec file in JSON or YAML format. +Use the `+"`"+`--spec`+"`"+` flag to provide the path to the updated spec file. +Optionally provide a Hugging Face access token using `+"`"+`--hugging-face-token`+"`"+`.`, + Writer, + aliasOpt("u"), + displayerType(&displayers.DedicatedInference{}), + ) + AddStringFlag(cmdUpdate, doctl.ArgDedicatedInferenceSpec, "", "", `Path to a dedicated inference spec in JSON or YAML format. Set to "-" to read from stdin.`, requiredOpt()) + AddStringFlag(cmdUpdate, doctl.ArgDedicatedInferenceHuggingFaceToken, "", "", "Hugging Face token for accessing gated models (optional)") + cmdUpdate.Example = `The following example updates a dedicated inference endpoint using a spec file: doctl dedicated-inference update 12345678-1234-1234-1234-123456789012 --spec spec.yaml` + cmdListAccelerators := CmdBuilder( cmd, RunDedicatedInferenceListAccelerators, @@ -183,6 +199,41 @@ func RunDedicatedInferenceGet(c *CmdConfig) error { return c.Display(&displayers.DedicatedInference{DedicatedInferences: do.DedicatedInferences{*endpoint}}) } +// RunDedicatedInferenceUpdate updates a dedicated inference endpoint. 
+func RunDedicatedInferenceUpdate(c *CmdConfig) error { + if len(c.Args) < 1 { + return doctl.NewMissingArgsErr(c.NS) + } + id := c.Args[0] + + specPath, err := c.Doit.GetString(c.NS, doctl.ArgDedicatedInferenceSpec) + if err != nil { + return err + } + + spec, err := readDedicatedInferenceSpec(os.Stdin, specPath) + if err != nil { + return err + } + + req := &godo.DedicatedInferenceUpdateRequest{ + Spec: spec, + } + + hfToken, _ := c.Doit.GetString(c.NS, doctl.ArgDedicatedInferenceHuggingFaceToken) + if hfToken != "" { + req.Secrets = &godo.DedicatedInferenceSecrets{ + HuggingFaceToken: hfToken, + } + } + + endpoint, err := c.DedicatedInferences().Update(id, req) + if err != nil { + return err + } + return c.Display(&displayers.DedicatedInference{DedicatedInferences: do.DedicatedInferences{*endpoint}}) +} + // RunDedicatedInferenceListAccelerators lists accelerators for a dedicated inference endpoint. func RunDedicatedInferenceListAccelerators(c *CmdConfig) error { if len(c.Args) < 1 { diff --git a/commands/dedicated_inference_test.go b/commands/dedicated_inference_test.go index 24b3a2b43..bf966a10c 100644 --- a/commands/dedicated_inference_test.go +++ b/commands/dedicated_inference_test.go @@ -72,6 +72,7 @@ func TestDedicatedInferenceCommand(t *testing.T) { } assert.True(t, subcommands["create"], "Expected create subcommand") assert.True(t, subcommands["get"], "Expected get subcommand") + assert.True(t, subcommands["update"], "Expected update subcommand") assert.True(t, subcommands["delete"], "Expected delete subcommand") assert.True(t, subcommands["list-accelerators"], "Expected list-accelerators subcommand") } @@ -190,6 +191,93 @@ func TestRunDedicatedInferenceDelete_MissingID(t *testing.T) { }) } +func TestRunDedicatedInferenceUpdate(t *testing.T) { + withTestClient(t, func(config *CmdConfig, tm *tcMocks) { + specJSON := `{ + "version": 0, + "name": "test-dedicated-inference", + "region": "nyc2", + "vpc": {"uuid": "00000000-0000-4000-8000-000000000001"}, + 
"enable_public_endpoint": true, + "model_deployments": [ + { + "model_slug": "mistral/mistral-7b-instruct-v3", + "model_provider": "hugging_face", + "accelerators": [ + {"scale": 2, "type": "prefill", "accelerator_slug": "gpu-mi300x1-192gb"}, + {"scale": 4, "type": "decode", "accelerator_slug": "gpu-mi300x1-192gb"} + ] + } + ] + }` + tmpFile := t.TempDir() + "/spec.json" + err := os.WriteFile(tmpFile, []byte(specJSON), 0644) + assert.NoError(t, err) + + config.Args = append(config.Args, "00000000-0000-4000-8000-000000000000") + config.Doit.Set(config.NS, doctl.ArgDedicatedInferenceSpec, tmpFile) + + expectedReq := &godo.DedicatedInferenceUpdateRequest{ + Spec: testDedicatedInferenceSpecRequest, + } + + updatedDI := testDedicatedInference + tm.dedicatedInferences.EXPECT().Update("00000000-0000-4000-8000-000000000000", expectedReq).Return(&updatedDI, nil) + + err = RunDedicatedInferenceUpdate(config) + assert.NoError(t, err) + }) +} + +func TestRunDedicatedInferenceUpdate_WithHuggingFaceToken(t *testing.T) { + withTestClient(t, func(config *CmdConfig, tm *tcMocks) { + specJSON := `{ + "version": 0, + "name": "test-dedicated-inference", + "region": "nyc2", + "vpc": {"uuid": "00000000-0000-4000-8000-000000000001"}, + "enable_public_endpoint": true, + "model_deployments": [ + { + "model_slug": "mistral/mistral-7b-instruct-v3", + "model_provider": "hugging_face", + "accelerators": [ + {"scale": 2, "type": "prefill", "accelerator_slug": "gpu-mi300x1-192gb"}, + {"scale": 4, "type": "decode", "accelerator_slug": "gpu-mi300x1-192gb"} + ] + } + ] + }` + tmpFile := t.TempDir() + "/spec.json" + err := os.WriteFile(tmpFile, []byte(specJSON), 0644) + assert.NoError(t, err) + + config.Args = append(config.Args, "00000000-0000-4000-8000-000000000000") + config.Doit.Set(config.NS, doctl.ArgDedicatedInferenceSpec, tmpFile) + config.Doit.Set(config.NS, doctl.ArgDedicatedInferenceHuggingFaceToken, "hf_test_token") + + expectedReq := &godo.DedicatedInferenceUpdateRequest{ + Spec: 
testDedicatedInferenceSpecRequest, + Secrets: &godo.DedicatedInferenceSecrets{ + HuggingFaceToken: "hf_test_token", + }, + } + + updatedDI := testDedicatedInference + tm.dedicatedInferences.EXPECT().Update("00000000-0000-4000-8000-000000000000", expectedReq).Return(&updatedDI, nil) + + err = RunDedicatedInferenceUpdate(config) + assert.NoError(t, err) + }) +} + +func TestRunDedicatedInferenceUpdate_MissingID(t *testing.T) { + withTestClient(t, func(config *CmdConfig, tm *tcMocks) { + err := RunDedicatedInferenceUpdate(config) + assert.Error(t, err) + }) +} + func TestRunDedicatedInferenceListAccelerators(t *testing.T) { withTestClient(t, func(config *CmdConfig, tm *tcMocks) { testAccelerators := do.DedicatedInferenceAcceleratorInfos{ diff --git a/do/dedicated_inference.go b/do/dedicated_inference.go index 102601092..92c106f8f 100644 --- a/do/dedicated_inference.go +++ b/do/dedicated_inference.go @@ -44,6 +44,7 @@ type DedicatedInferenceAcceleratorInfos []DedicatedInferenceAcceleratorInfo type DedicatedInferenceService interface { Create(req *godo.DedicatedInferenceCreateRequest) (*DedicatedInference, *DedicatedInferenceToken, error) Get(id string) (*DedicatedInference, error) + Update(id string, req *godo.DedicatedInferenceUpdateRequest) (*DedicatedInference, error) Delete(id string) error ListAccelerators(diID string, slug string) (DedicatedInferenceAcceleratorInfos, error) } @@ -83,6 +84,15 @@ func (s *dedicatedInferenceService) Get(id string) (*DedicatedInference, error) return &DedicatedInference{DedicatedInference: d}, nil } +// Update updates a dedicated inference endpoint by ID. +func (s *dedicatedInferenceService) Update(id string, req *godo.DedicatedInferenceUpdateRequest) (*DedicatedInference, error) { + d, _, err := s.client.DedicatedInference.Update(context.TODO(), id, req) + if err != nil { + return nil, err + } + return &DedicatedInference{DedicatedInference: d}, nil +} + // Delete deletes a dedicated inference endpoint by ID. 
func (s *dedicatedInferenceService) Delete(id string) error { _, err := s.client.DedicatedInference.Delete(context.TODO(), id) diff --git a/do/mocks/DedicatedInferenceService.go b/do/mocks/DedicatedInferenceService.go index 948e3cd60..df5432eea 100644 --- a/do/mocks/DedicatedInferenceService.go +++ b/do/mocks/DedicatedInferenceService.go @@ -100,3 +100,18 @@ func (mr *MockDedicatedInferenceServiceMockRecorder) Get(id any) *gomock.Call { mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockDedicatedInferenceService)(nil).Get), id) } + +// Update mocks base method. +func (m *MockDedicatedInferenceService) Update(id string, req *godo.DedicatedInferenceUpdateRequest) (*do.DedicatedInference, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Update", id, req) + ret0, _ := ret[0].(*do.DedicatedInference) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Update indicates an expected call of Update. +func (mr *MockDedicatedInferenceServiceMockRecorder) Update(id any, req any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Update", reflect.TypeOf((*MockDedicatedInferenceService)(nil).Update), id, req) +} diff --git a/integration/dedicated_inference_update_test.go b/integration/dedicated_inference_update_test.go new file mode 100644 index 000000000..5c6d308c6 --- /dev/null +++ b/integration/dedicated_inference_update_test.go @@ -0,0 +1,259 @@ +package integration + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/http/httputil" + "os" + "os/exec" + "strings" + "testing" + + "github.com/sclevine/spec" + "github.com/stretchr/testify/require" +) + +var _ = suite("dedicated-inference/update", func(t *testing.T, when spec.G, it spec.S) { + var ( + expect *require.Assertions + cmd *exec.Cmd + server *httptest.Server + ) + + it.Before(func() { + expect = require.New(t) + + server = httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, req *http.Request) { + switch req.URL.Path { + case "/v2/dedicated-inferences/00000000-0000-4000-8000-000000000000": + auth := req.Header.Get("Authorization") + if auth != "Bearer some-magic-token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + if req.Method != http.MethodPatch { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + body, err := io.ReadAll(req.Body) + if err != nil { + t.Fatal("failed to read request body") + } + + var payload map[string]interface{} + err = json.Unmarshal(body, &payload) + if err != nil { + t.Fatalf("failed to parse request body: %s", err) + } + + // Verify the spec is present in the request + if _, ok := payload["spec"]; !ok { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"id":"bad_request","message":"spec is required"}`)) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(dedicatedInferenceUpdateResponse)) + case "/v2/dedicated-inferences/99999999-9999-4999-8999-999999999999": + auth := req.Header.Get("Authorization") + if auth != "Bearer some-magic-token" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + if req.Method != http.MethodPatch { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotFound) + w.Write([]byte(`{"id":"not_found","message":"The resource you requested could not be found."}`)) + default: + dump, err := httputil.DumpRequest(req, true) + if err != nil { + t.Fatal("failed to dump request") + } + + t.Fatalf("received unknown request: %s", dump) + } + })) + }) + + when("valid ID and spec are provided", func() { + it("updates the dedicated inference endpoint", func() { + specFile := createDedicatedInferenceUpdateSpecFile(t) + + aliases := []string{"update", "u"} + + for _, alias := range aliases { + cmd = exec.Command(builtBinaryPath, + "-t", "some-magic-token", + "-u", server.URL, + "dedicated-inference", + alias, + 
"00000000-0000-4000-8000-000000000000", + "--spec", specFile, + ) + + output, err := cmd.CombinedOutput() + expect.NoError(err, fmt.Sprintf("received error output for alias %q: %s", alias, output)) + expect.Equal(strings.TrimSpace(dedicatedInferenceUpdateOutput), strings.TrimSpace(string(output))) + } + }) + }) + + when("dedicated inference ID is missing", func() { + it("returns an error", func() { + specFile := createDedicatedInferenceUpdateSpecFile(t) + + cmd = exec.Command(builtBinaryPath, + "-t", "some-magic-token", + "-u", server.URL, + "dedicated-inference", + "update", + "--spec", specFile, + ) + + output, err := cmd.CombinedOutput() + expect.Error(err) + expect.Contains(string(output), "missing") + }) + }) + + when("spec flag is missing", func() { + it("returns an error", func() { + cmd = exec.Command(builtBinaryPath, + "-t", "some-magic-token", + "-u", server.URL, + "dedicated-inference", + "update", + "00000000-0000-4000-8000-000000000000", + ) + + output, err := cmd.CombinedOutput() + expect.Error(err) + expect.Contains(string(output), "spec") + }) + }) + + when("dedicated inference does not exist", func() { + it("returns a not found error", func() { + specFile := createDedicatedInferenceUpdateSpecFile(t) + + cmd = exec.Command(builtBinaryPath, + "-t", "some-magic-token", + "-u", server.URL, + "dedicated-inference", + "update", + "99999999-9999-4999-8999-999999999999", + "--spec", specFile, + ) + + output, err := cmd.CombinedOutput() + expect.Error(err) + expect.Contains(string(output), "404") + }) + }) + + when("using the di alias", func() { + it("updates the dedicated inference endpoint", func() { + specFile := createDedicatedInferenceUpdateSpecFile(t) + + cmd = exec.Command(builtBinaryPath, + "-t", "some-magic-token", + "-u", server.URL, + "di", + "update", + "00000000-0000-4000-8000-000000000000", + "--spec", specFile, + ) + + output, err := cmd.CombinedOutput() + expect.NoError(err, fmt.Sprintf("received error output: %s", output)) + 
expect.Equal(strings.TrimSpace(dedicatedInferenceUpdateOutput), strings.TrimSpace(string(output))) + }) + }) + + when("passing a format flag", func() { + it("displays only those columns", func() { + specFile := createDedicatedInferenceUpdateSpecFile(t) + + cmd = exec.Command(builtBinaryPath, + "-t", "some-magic-token", + "-u", server.URL, + "dedicated-inference", + "update", + "00000000-0000-4000-8000-000000000000", + "--spec", specFile, + "--format", "ID,Name,Status", + ) + + output, err := cmd.CombinedOutput() + expect.NoError(err, fmt.Sprintf("received error output: %s", output)) + expect.Equal(strings.TrimSpace(dedicatedInferenceUpdateFormatOutput), strings.TrimSpace(string(output))) + }) + }) +}) + +func createDedicatedInferenceUpdateSpecFile(t *testing.T) string { + t.Helper() + specJSON := `{ + "version": 0, + "name": "updated-dedicated-inference", + "region": "nyc2", + "vpc": {"uuid": "00000000-0000-4000-8000-000000000001"}, + "enable_public_endpoint": true, + "model_deployments": [ + { + "model_slug": "mistral/mistral-7b-instruct-v3", + "model_provider": "hugging_face", + "accelerators": [ + {"scale": 3, "type": "prefill", "accelerator_slug": "gpu-mi300x1-192gb"}, + {"scale": 6, "type": "decode", "accelerator_slug": "gpu-mi300x1-192gb"} + ] + } + ] + }` + tmpFile := t.TempDir() + "/update-spec.json" + err := os.WriteFile(tmpFile, []byte(specJSON), 0644) + if err != nil { + t.Fatalf("failed to write spec file: %s", err) + } + return tmpFile +} + +const ( + dedicatedInferenceUpdateOutput = ` +ID Name Region Status VPC UUID Public Endpoint Private Endpoint Created At Updated At +00000000-0000-4000-8000-000000000000 updated-dedicated-inference nyc2 UPDATING 00000000-0000-4000-8000-000000000001 public.dedicated-inference.example.com private.dedicated-inference.example.com 2023-01-01 00:00:00 +0000 UTC 2023-01-02 00:00:00 +0000 UTC +` + dedicatedInferenceUpdateFormatOutput = ` +ID Name Status +00000000-0000-4000-8000-000000000000 updated-dedicated-inference 
UPDATING +` + + dedicatedInferenceUpdateResponse = ` +{ + "dedicated_inference": { + "id": "00000000-0000-4000-8000-000000000000", + "name": "updated-dedicated-inference", + "region": "nyc2", + "status": "UPDATING", + "vpc_uuid": "00000000-0000-4000-8000-000000000001", + "endpoints": { + "public_endpoint_fqdn": "public.dedicated-inference.example.com", + "private_endpoint_fqdn": "private.dedicated-inference.example.com" + }, + "created_at": "2023-01-01T00:00:00Z", + "updated_at": "2023-01-02T00:00:00Z" + } +} +` +)