diff --git a/pkg/cloudagents/client.go b/pkg/cloudagents/client.go index 2957374f..088e3eba 100644 --- a/pkg/cloudagents/client.go +++ b/pkg/cloudagents/client.go @@ -122,6 +122,19 @@ func (c *Client) DeployAgent( return c.uploadAndBuild(ctx, agentID, resp.PresignedUrl, resp.PresignedPostRequest, source, excludeFiles, buildLogStreamWriter) } +// RegisterAgent creates an agent record without uploading source or triggering a build. +// Use this when you intend to push a prebuilt image immediately after via GetPushTarget. +func (c *Client) RegisterAgent(ctx context.Context, secrets []*lkproto.AgentSecret, regions []string) (string, error) { + resp, err := c.AgentClient.CreateAgent(ctx, &lkproto.CreateAgentRequest{ + Secrets: secrets, + Regions: regions, + }) + if err != nil { + return "", err + } + return resp.AgentId, nil +} + // CreatePrivateLink creates a new private link for cloud agents. func (c *Client) CreatePrivateLink(ctx context.Context, req *lkproto.CreatePrivateLinkRequest) (*lkproto.CreatePrivateLinkResponse, error) { return c.AgentClient.CreatePrivateLink(ctx, req) diff --git a/pkg/cloudagents/push.go b/pkg/cloudagents/push.go new file mode 100644 index 00000000..c2ed9543 --- /dev/null +++ b/pkg/cloudagents/push.go @@ -0,0 +1,85 @@ +// Copyright 2025 LiveKit, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cloudagents + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" +) + +// PushTarget describes where and how the CLI should push a prebuilt image. +type PushTarget struct { + // ProxyHost is the OCI registry host exposed by cloud-agents (e.g. "agents.livekit.io"). + ProxyHost string `json:"proxy_host"` + // Name is the OCI repository name to use in /v2/{name}/... paths. + Name string `json:"name"` + // Tag is the version tag cloud-agents generated; use this as the image tag. + Tag string `json:"tag"` +} + +// GetPushTarget asks cloud-agents for the OCI proxy location for the given agent. +// The caller should then push the image to ProxyHost/Name:Tag using a transport +// returned by NewRegistryTransport. +func (c *Client) GetPushTarget(ctx context.Context, agentID string) (*PushTarget, error) { + params := url.Values{} + params.Add("agent_id", agentID) + fullURL := fmt.Sprintf("%s/push-target?%s", c.agentsURL, params.Encode()) + + req, err := c.newRequestWithContext(ctx, http.MethodGet, fullURL, nil) + if err != nil { + return nil, err + } + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to call push-target: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("push-target returned %d: %s", resp.StatusCode, body) + } + var target PushTarget + if err := json.NewDecoder(resp.Body).Decode(&target); err != nil { + return nil, fmt.Errorf("failed to decode push-target response: %w", err) + } + return &target, nil +} + +// NewRegistryTransport returns an http.RoundTripper that injects the LiveKit JWT on every +// request. Pass this to crane via crane.WithTransport when pushing to the cloud-agents +// OCI proxy so the proxy's auth middleware accepts the requests. +func (c *Client) NewRegistryTransport() http.RoundTripper { + return &lkRegistryTransport{base: http.DefaultTransport, client: c} +} + +// lkRegistryTransport injects LK auth headers on every HTTP request, allowing crane +// to push through the cloud-agents OCI proxy without doing OCI token negotiation. +type lkRegistryTransport struct { + base http.RoundTripper + client *Client +} + +func (t *lkRegistryTransport) RoundTrip(req *http.Request) (*http.Response, error) { + req = req.Clone(req.Context()) + if err := t.client.setAuthToken(req); err != nil { + return nil, err + } + t.client.setLivekitHeaders(req) + return t.base.RoundTrip(req) +} diff --git a/pkg/cloudagents/push_test.go b/pkg/cloudagents/push_test.go new file mode 100644 index 00000000..2383e4b3 --- /dev/null +++ b/pkg/cloudagents/push_test.go @@ -0,0 +1,82 @@ +package cloudagents + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/livekit/protocol/logger" +) + +func TestGetPushTargetSuccess(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, `{"proxy_host":"agents.livekit.io","name":"livekit","tag":"v1"}`) + })) + defer server.Close() + + client := &Client{ + httpClient: server.Client(), + logger: logger.GetLogger(), + apiKey: "test-api-key", + apiSecret: "test-api-secret", + agentsURL: server.URL, + } + + target, err := client.GetPushTarget(context.Background(), "test-agent") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if target.ProxyHost != "agents.livekit.io" || target.Name != "livekit" || target.Tag != "v1" { + t.Fatalf("unexpected push target: %+v", target) + } +} + +func TestGetPushTargetNonOK(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintf(w, "oops") + })) + defer server.Close() + + client := &Client{ + httpClient: server.Client(), + logger: logger.GetLogger(), + apiKey: "test-api-key", + apiSecret: "test-api-secret", + agentsURL: server.URL, + } + + _, err := client.GetPushTarget(context.Background(), "test-agent") + if err == nil { + t.Fatal("expected error when push-target returns non-OK status") + } + if !strings.Contains(err.Error(), "push-target returned 400") { + t.Fatalf("unexpected error message: %v", err) + } +} + +func TestGetPushTargetDecodeError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "invalid-json") + })) + defer server.Close() + + client := &Client{ + httpClient: server.Client(), + logger: logger.GetLogger(), + apiKey: "test-api-key", + apiSecret: "test-api-secret", + agentsURL: server.URL, + } + + _, err := client.GetPushTarget(context.Background(), "test-agent") + if err == nil { + t.Fatal("expected decode error") + } + if !strings.Contains(err.Error(), "failed to decode push-target response") { + t.Fatalf("unexpected error message: %v", err) + } +}