diff --git a/decentralized-api/apiconfig/config.go b/decentralized-api/apiconfig/config.go index 8977ded4f..74044d155 100644 --- a/decentralized-api/apiconfig/config.go +++ b/decentralized-api/apiconfig/config.go @@ -2,6 +2,7 @@ package apiconfig import ( "fmt" + "net/url" "strings" ) @@ -90,6 +91,8 @@ type InferenceNodeConfig struct { InferencePort int `koanf:"inference_port" json:"inference_port"` PoCSegment string `koanf:"poc_segment" json:"poc_segment"` PoCPort int `koanf:"poc_port" json:"poc_port"` + BaseURL string `koanf:"base_url" json:"base_url"` + AuthToken string `koanf:"auth_token" json:"auth_token"` Models map[string]ModelConfig `koanf:"models" json:"models"` Id string `koanf:"id" json:"id"` MaxConcurrent int `koanf:"max_concurrent" json:"max_concurrent"` @@ -102,21 +105,46 @@ type InferenceNodeConfig struct { func ValidateInferenceNodeBasic(node InferenceNodeConfig) []string { var errors []string - // Validate required fields - if strings.TrimSpace(node.Id) == "" { - errors = append(errors, "node id is required and cannot be empty") + // Validate host/baseURL configuration + // ensures that a node configuration uses either the legacy host/port registration or the new baseURL registration, but not both. + // When baseURL is provided, it must be a valid HTTP(S) URL. AuthToken is always optional (no validation needed) + hasHostPorts := strings.TrimSpace(node.Host) != "" || node.InferencePort > 0 || node.PoCPort > 0 + hasBaseURL := strings.TrimSpace(node.BaseURL) != "" + + if hasHostPorts && hasBaseURL { + errors = append(errors, "node configuration error: cannot specify both (Host+Ports) and baseURL. 
Use either Host+InferencePort+PoCPort OR baseURL") } - if strings.TrimSpace(node.Host) == "" { - errors = append(errors, "host is required and cannot be empty") + if !hasHostPorts && !hasBaseURL { + errors = append(errors, "node configuration error: must specify either (Host+InferencePort+PoCPort) OR baseURL") } - if node.InferencePort <= 0 || node.InferencePort > 65535 { - errors = append(errors, fmt.Sprintf("inference_port must be between 1 and 65535, got %d", node.InferencePort)) + if hasHostPorts { + if strings.TrimSpace(node.Host) == "" { + errors = append(errors, "host is required and cannot be empty when using host+port registration") + } + + if node.InferencePort <= 0 || node.InferencePort > 65535 { + errors = append(errors, fmt.Sprintf("inference_port must be between 1 and 65535, got %d", node.InferencePort)) + } + + if node.PoCPort <= 0 || node.PoCPort > 65535 { + errors = append(errors, fmt.Sprintf("poc_port must be between 1 and 65535, got %d", node.PoCPort)) + } + } + + if hasBaseURL { + parsedURL, err := url.Parse(node.BaseURL) + if err != nil { + errors = append(errors, fmt.Sprintf("node configuration error: baseURL is not a valid URL: %v", err)) + } else if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + errors = append(errors, fmt.Sprintf("node configuration error: baseURL must use http:// or https:// scheme, got: %s", parsedURL.Scheme)) + } } - if node.PoCPort <= 0 || node.PoCPort > 65535 { - errors = append(errors, fmt.Sprintf("poc_port must be between 1 and 65535, got %d", node.PoCPort)) + // Validate required fields + if strings.TrimSpace(node.Id) == "" { + errors = append(errors, "node id is required and cannot be empty") } if node.MaxConcurrent <= 0 { diff --git a/decentralized-api/apiconfig/sqlite_store.go b/decentralized-api/apiconfig/sqlite_store.go index 3d0d2a60e..4a09f2ff8 100644 --- a/decentralized-api/apiconfig/sqlite_store.go +++ b/decentralized-api/apiconfig/sqlite_store.go @@ -83,6 +83,8 @@ CREATE TABLE IF NOT EXISTS 
inference_nodes ( max_concurrent INTEGER NOT NULL, models_json TEXT NOT NULL, hardware_json TEXT NOT NULL, + base_url TEXT NOT NULL DEFAULT '', + auth_token TEXT NOT NULL DEFAULT '', updated_at DATETIME NOT NULL DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f','now')), created_at DATETIME NOT NULL DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f','now')) ); @@ -104,8 +106,62 @@ CREATE TABLE IF NOT EXISTS seed_info ( is_active BOOLEAN NOT NULL DEFAULT 1, created_at DATETIME NOT NULL DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f','now')) );` - _, err := db.ExecContext(ctx, stmt) - return err + if _, err := db.ExecContext(ctx, stmt); err != nil { + return err + } + + // Migrate existing tables: add base_url and auth_token columns if they don't exist + return migrateInferenceNodesTable(ctx, db) +} + +// migrateInferenceNodesTable adds base_url and auth_token columns to existing inference_nodes table +func migrateInferenceNodesTable(ctx context.Context, db *sql.DB) error { + // Check if base_url column exists + rows, err := db.QueryContext(ctx, "PRAGMA table_info(inference_nodes)") + if err != nil { + return err + } + defer rows.Close() + + var hasBaseURL, hasAuthToken bool + for rows.Next() { + var ( + cid int + name string + declType string + notnull int + dflt sql.NullString + pk int + ) + if err := rows.Scan(&cid, &name, &declType, &notnull, &dflt, &pk); err != nil { + return err + } + if name == "base_url" { + hasBaseURL = true + } + if name == "auth_token" { + hasAuthToken = true + } + } + if err := rows.Err(); err != nil { + return err + } + + // Add base_url column if it doesn't exist + if !hasBaseURL { + if _, err := db.ExecContext(ctx, "ALTER TABLE inference_nodes ADD COLUMN base_url TEXT NOT NULL DEFAULT ''"); err != nil { + return fmt.Errorf("failed to add base_url column: %w", err) + } + } + + // Add auth_token column if it doesn't exist + if !hasAuthToken { + if _, err := db.ExecContext(ctx, "ALTER TABLE inference_nodes ADD COLUMN auth_token TEXT NOT NULL DEFAULT ''"); err != nil { + return 
fmt.Errorf("failed to add auth_token column: %w", err) + } + } + + return nil } // UpsertInferenceNodes replaces or inserts the given nodes by id. @@ -121,8 +177,8 @@ func UpsertInferenceNodes(ctx context.Context, db *sql.DB, nodes []InferenceNode q := ` INSERT INTO inference_nodes ( - id, host, inference_segment, inference_port, poc_segment, poc_port, max_concurrent, models_json, hardware_json -) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + id, host, inference_segment, inference_port, poc_segment, poc_port, max_concurrent, models_json, hardware_json, base_url, auth_token +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(id) DO UPDATE SET host = excluded.host, inference_segment = excluded.inference_segment, @@ -132,6 +188,8 @@ ON CONFLICT(id) DO UPDATE SET max_concurrent = excluded.max_concurrent, models_json = excluded.models_json, hardware_json = excluded.hardware_json, + base_url = excluded.base_url, + auth_token = excluded.auth_token, updated_at = (STRFTIME('%Y-%m-%d %H:%M:%f','now'))` stmt, err := tx.PrepareContext(ctx, q) @@ -160,6 +218,8 @@ ON CONFLICT(id) DO UPDATE SET n.MaxConcurrent, string(modelsJSON), string(hardwareJSON), + n.BaseURL, + n.AuthToken, ); err != nil { return err } @@ -175,7 +235,7 @@ func WriteNodes(ctx context.Context, db *sql.DB, nodes []InferenceNodeConfig) er // ReadNodes reads all nodes from the database and reconstructs InferenceNodeConfig entries. 
func ReadNodes(ctx context.Context, db *sql.DB) ([]InferenceNodeConfig, error) { rows, err := db.QueryContext(ctx, ` -SELECT id, host, inference_segment, inference_port, poc_segment, poc_port, max_concurrent, models_json, hardware_json +SELECT id, host, inference_segment, inference_port, poc_segment, poc_port, max_concurrent, models_json, hardware_json, base_url, auth_token FROM inference_nodes ORDER BY id`) if err != nil { return nil, err @@ -194,8 +254,10 @@ FROM inference_nodes ORDER BY id`) maxConc int modelsRaw []byte hardwareRaw []byte + baseURL string + authToken string ) - if err := rows.Scan(&id, &host, &infSeg, &infPort, &pocSeg, &pocPort, &maxConc, &modelsRaw, &hardwareRaw); err != nil { + if err := rows.Scan(&id, &host, &infSeg, &infPort, &pocSeg, &pocPort, &maxConc, &modelsRaw, &hardwareRaw, &baseURL, &authToken); err != nil { return nil, err } var models map[string]ModelConfig @@ -216,6 +278,8 @@ FROM inference_nodes ORDER BY id`) InferencePort: infPort, PoCSegment: pocSeg, PoCPort: pocPort, + BaseURL: baseURL, + AuthToken: authToken, Models: models, Id: id, MaxConcurrent: maxConc, @@ -246,8 +310,8 @@ func ReplaceInferenceNodes(ctx context.Context, db *sql.DB, nodes []InferenceNod q := ` INSERT INTO inference_nodes ( - id, host, inference_segment, inference_port, poc_segment, poc_port, max_concurrent, models_json, hardware_json -) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)` + id, host, inference_segment, inference_port, poc_segment, poc_port, max_concurrent, models_json, hardware_json, base_url, auth_token +) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)` stmt, err := tx.PrepareContext(ctx, q) if err != nil { @@ -275,6 +339,8 @@ INSERT INTO inference_nodes ( n.MaxConcurrent, string(modelsJSON), string(hardwareJSON), + n.BaseURL, + n.AuthToken, ); err != nil { return err } diff --git a/decentralized-api/broker/broker.go b/decentralized-api/broker/broker.go index 9b62841e6..1ac7a7b78 100644 --- a/decentralized-api/broker/broker.go +++ 
b/decentralized-api/broker/broker.go @@ -195,6 +195,8 @@ type Node struct { InferencePort int `json:"inference_port"` PoCSegment string `json:"poc_segment"` PoCPort int `json:"poc_port"` + BaseURL string `json:"base_url"` + AuthToken string `json:"auth_token"` Models map[string]ModelArgs `json:"models"` Id string `json:"id"` MaxConcurrent int `json:"max_concurrent"` @@ -202,15 +204,47 @@ type Node struct { Hardware []apiconfig.Hardware `json:"hardware"` } +type MlNodePathElements struct { + Host string `json:"host"` + Port int `json:"port"` + BaseURL string `json:"base_url"` + Version string `json:"version"` + Segment string `json:"segment"` +} + +func GetMlNodeUrl(elements MlNodePathElements) string { + // If BaseURL is provided, build on top of it + if strings.TrimSpace(elements.BaseURL) != "" { + base := strings.TrimRight(elements.BaseURL, "/") + if strings.TrimSpace(elements.Version) == "" { + return fmt.Sprintf("%s%s", base, elements.Segment) + } + return fmt.Sprintf("%s/%s%s", base, strings.TrimSpace(elements.Version), elements.Segment) + } + if strings.TrimSpace(elements.Version) == "" { + return fmt.Sprintf("http://%s:%d%s", elements.Host, elements.Port, elements.Segment) + } + return fmt.Sprintf("http://%s:%d/%s%s", elements.Host, elements.Port, strings.TrimSpace(elements.Version), elements.Segment) +} + func (n *Node) InferenceUrl() string { return fmt.Sprintf("http://%s:%d%s", n.Host, n.InferencePort, n.InferenceSegment) } func (n *Node) InferenceUrlWithVersion(version string) string { - if version == "" { + v := strings.TrimSpace(version) + // If BaseURL is provided, build on top of it + if n.BaseURL != "" { + base := strings.TrimRight(n.BaseURL, "/") + if v == "" { + return fmt.Sprintf("%s%s", base, n.InferenceSegment) + } + return fmt.Sprintf("%s/%s%s", base, v, n.InferenceSegment) + } + if v == "" { return n.InferenceUrl() } - return fmt.Sprintf("http://%s:%d/%s%s", n.Host, n.InferencePort, version, n.InferenceSegment) + return 
fmt.Sprintf("http://%s:%d/%s%s", n.Host, n.InferencePort, v, n.InferenceSegment) } func (n *Node) PoCUrl() string { @@ -218,10 +252,32 @@ func (n *Node) PoCUrl() string { } func (n *Node) PoCUrlWithVersion(version string) string { - if version == "" { + v := strings.TrimSpace(version) + // If BaseURL is provided, build on top of it + if n.BaseURL != "" { + base := strings.TrimRight(n.BaseURL, "/") + if v == "" { + return fmt.Sprintf("%s%s", base, n.PoCSegment) + } + return fmt.Sprintf("%s/%s%s", base, v, n.PoCSegment) + } + if v == "" { return n.PoCUrl() } - return fmt.Sprintf("http://%s:%d/%s%s", n.Host, n.PoCPort, version, n.PoCSegment) + return fmt.Sprintf("http://%s:%d/%s%s", n.Host, n.PoCPort, v, n.PoCSegment) +} + +// BaseUrlWithVersion constructs a base URL with version +func BaseUrlWithVersion(baseURL, version string) string { + base := strings.TrimRight(baseURL, "/") + if strings.TrimSpace(version) != "" { + return fmt.Sprintf("%s/%s", base, strings.TrimSpace(version)) + } + return base +} + +func (n *Node) BaseUrlWithVersion(version string) string { + return BaseUrlWithVersion(n.BaseURL, version) } type NodeWithState struct { @@ -500,7 +556,7 @@ func (b *Broker) QueueMessage(command Command) error { func (b *Broker) NewNodeClient(node *Node) mlnodeclient.MLNodeClient { version := b.configManager.GetCurrentNodeVersion() - return b.mlNodeClientFactory.CreateClient(node.PoCUrlWithVersion(version), node.InferenceUrlWithVersion(version)) + return b.mlNodeClientFactory.CreateClient(node.PoCUrlWithVersion(version), node.InferenceUrlWithVersion(version), node.AuthToken, node.BaseUrlWithVersion(version)) } func (b *Broker) lockAvailableNode(command LockAvailableNode) { @@ -664,7 +720,7 @@ func (b *Broker) GetNodes() ([]NodeResponse, error) { nodes := <-command.Response if nodes == nil { - return nil, errors.New("Error getting nodes") + return nil, errors.New("error getting nodes") } logging.Debug("Got nodes", types.Nodes, "size", len(nodes)) return nodes, nil diff 
--git a/decentralized-api/broker/broker_test.go b/decentralized-api/broker/broker_test.go index ce7c5092c..73cc7342c 100644 --- a/decentralized-api/broker/broker_test.go +++ b/decentralized-api/broker/broker_test.go @@ -197,7 +197,7 @@ func registerNodeAndSetInferenceStatus(t *testing.T, broker *Broker, node apicon mockClient := mockFactory.GetClientForNode(fmt.Sprintf("http://%s:%d", node.Host, node.PoCPort)) if mockClient == nil { // If it's not created yet, create it. - mockClient = mockFactory.CreateClient(fmt.Sprintf("http://%s:%d", node.Host, node.PoCPort), fmt.Sprintf("http://%s:%d", node.Host, node.InferencePort)).(*mlnodeclient.MockClient) + mockClient = mockFactory.CreateClient(fmt.Sprintf("http://%s:%d", node.Host, node.PoCPort), fmt.Sprintf("http://%s:%d", node.Host, node.InferencePort), "", "").(*mlnodeclient.MockClient) } mockClient.Mu.Lock() mockClient.CurrentState = mlnodeclient.MlNodeState_INFERENCE @@ -244,6 +244,74 @@ func registerNodeAndSetInferenceStatus(t *testing.T, broker *Broker, node apicon t.Fatalf("Node did not reach INFERENCE status in time") } +func TestBaseUrlWithVersion(t *testing.T) { + // Test cases for BaseUrlWithVersion function + tests := []struct { + name string + baseURL string + version string + expected string + }{ + { + name: "Base URL with version", + baseURL: "http://example.com", + version: "v1", + expected: "http://example.com/v1", + }, + { + name: "Base URL without version", + baseURL: "http://example.com", + version: "", + expected: "http://example.com", + }, + { + name: "Base URL with trailing slash and version", + baseURL: "http://example.com/", + version: "v1", + expected: "http://example.com/v1", + }, + { + name: "Base URL with trailing slash and no version", + baseURL: "http://example.com/", + version: "", + expected: "http://example.com", + }, + { + name: "Empty base URL with version", + baseURL: "", + version: "v1", + expected: "/v1", + }, + { + name: "Empty base URL without version", + baseURL: "", + version: 
"", + expected: "", + }, + { + name: "Version with whitespace", + baseURL: "http://example.com", + version: " v1 ", + expected: "http://example.com/v1", + }, + { + name: "Version_with_whitespace", + baseURL: "https://api.example.com", + version: " v2 ", + expected: "https://api.example.com/v2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := BaseUrlWithVersion(tt.baseURL, tt.version) + if result != tt.expected { + t.Errorf("BaseUrlWithVersion(%q, %q) = %q; expected %q", tt.baseURL, tt.version, result, tt.expected) + } + }) + } +} + func TestNodeRemoval(t *testing.T) { broker := NewTestBroker() node := apiconfig.InferenceNodeConfig{ @@ -469,6 +537,107 @@ func TestNodeShouldBeOperationalTest(t *testing.T) { require.False(t, ShouldBeOperational(adminState, 12, types.InferencePhase)) } +func TestGetMlNodeUrl(t *testing.T) { + tests := []struct { + name string + elements MlNodePathElements + expected string + }{ + { + name: "with BaseURL and Version", + elements: MlNodePathElements{ + BaseURL: "https://api.example.com", + Version: "v2", + Segment: "/endpoint", + }, + expected: "https://api.example.com/v2/endpoint", + }, + { + name: "with BaseURL without Version", + elements: MlNodePathElements{ + BaseURL: "https://api.example.com", + Version: "", + Segment: "/endpoint", + }, + expected: "https://api.example.com/endpoint", + }, + { + name: "without BaseURL with Version", + elements: MlNodePathElements{ + Host: "example.com", + Port: 8080, + Version: "v2", + Segment: "/endpoint", + }, + expected: "http://example.com:8080/v2/endpoint", + }, + { + name: "without BaseURL without Version", + elements: MlNodePathElements{ + Host: "example.com", + Port: 8080, + Version: "", + Segment: "/endpoint", + }, + expected: "http://example.com:8080/endpoint", + }, + { + name: "BaseURL with trailing slash", + elements: MlNodePathElements{ + BaseURL: "https://api.example.com/", + Version: "v2", + Segment: "/endpoint", + }, + expected: 
"https://api.example.com/v2/endpoint", + }, + { + name: "empty Segment", + elements: MlNodePathElements{ + Host: "example.com", + Port: 8080, + Version: "v2", + Segment: "", + }, + expected: "http://example.com:8080/v2", + }, + { + name: "BaseURL with empty segment", + elements: MlNodePathElements{ + BaseURL: "https://api.example.com", + Version: "v2", + Segment: "", + }, + expected: "https://api.example.com/v2", + }, + { + name: "version_with_whitespace", + elements: MlNodePathElements{ + BaseURL: "https://api.example.com", + Version: " v2 ", + Segment: "/endpoint", + }, + expected: "https://api.example.com/v2/endpoint", + }, + { + name: "version_with_whitespace_without_baseurl", + elements: MlNodePathElements{ + Host: "example.com", + Port: 8080, + Version: " v2 ", + Segment: "/endpoint", + }, + expected: "http://example.com:8080/v2/endpoint", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + actual := GetMlNodeUrl(tt.elements) + assert.Equal(t, tt.expected, actual) + }) + } +} + func TestVersionedUrls(t *testing.T) { node := Node{ Host: "example.com", diff --git a/decentralized-api/broker/node_admin_commands.go b/decentralized-api/broker/node_admin_commands.go index bf03e3f39..e0cf3bced 100644 --- a/decentralized-api/broker/node_admin_commands.go +++ b/decentralized-api/broker/node_admin_commands.go @@ -4,6 +4,8 @@ import ( "decentralized-api/apiconfig" "decentralized-api/logging" "fmt" + "net" + "net/url" "strings" "time" @@ -66,7 +68,71 @@ func (r RegisterNode) GetResponseChannelCapacity() int { return cap(r.Response) } +// validateInferenceNodeConfig validates node configuration: +// - Requires either (Host+Ports) OR baseURL, not both +// - baseURL must be valid HTTP(S) URL +// - AuthToken is always optional (no validation needed) +func validateInferenceNodeConfig(node apiconfig.InferenceNodeConfig) error { + hasHostPorts := strings.TrimSpace(node.Host) != "" && node.InferencePort > 0 && node.PoCPort > 0 + hasBaseURL := 
strings.TrimSpace(node.BaseURL) != "" + + if hasHostPorts && hasBaseURL { + return fmt.Errorf("node configuration error: cannot specify both (Host+Ports) and baseURL. Use either Host+InferencePort+PoCPort OR baseURL") + } + + if !hasHostPorts && !hasBaseURL { + return fmt.Errorf("node configuration error: must specify either (Host+InferencePort+PoCPort) OR baseURL") + } + + if hasBaseURL { + // Validate baseURL is a valid HTTP(S) URL + parsedURL, err := url.Parse(node.BaseURL) + if err != nil { + return fmt.Errorf("node configuration error: baseURL is not a valid URL: %w", err) + } + + if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return fmt.Errorf("node configuration error: baseURL must use http:// or https:// scheme, got: %s", parsedURL.Scheme) + } + + if parsedURL.Host == "" { + return fmt.Errorf("node configuration error: baseURL must include a valid host") + } + + // Validate host is either a valid IP address or a valid domain name + hostname := parsedURL.Hostname() + if hostname == "" { + return fmt.Errorf("node configuration error: baseURL must include a valid hostname") + } + + // Check if it's a valid IP address + if ip := net.ParseIP(hostname); ip != nil { + // Valid IP address, allow it + } else { + // Not an IP, check if it's a valid domain name format + // Basic validation: domain should contain at least one dot or be localhost + if hostname != "localhost" && !strings.Contains(hostname, ".") { + return fmt.Errorf("node configuration error: baseURL hostname '%s' is not a valid IP address or domain name", hostname) + } + // Additional check: domain should not start or end with dot or hyphen + if strings.HasPrefix(hostname, ".") || strings.HasSuffix(hostname, ".") || + strings.HasPrefix(hostname, "-") || strings.HasSuffix(hostname, "-") { + return fmt.Errorf("node configuration error: baseURL hostname '%s' has invalid format", hostname) + } + } + } + + return nil +} + func (c RegisterNode) Execute(b *Broker) { + // Validate node 
configuration (Host+Ports vs baseURL) + if err := validateInferenceNodeConfig(c.Node); err != nil { + logging.Error("RegisterNode. Invalid node configuration", types.Nodes, "error", err, "node_id", c.Node.Id) + c.Response <- NodeCommandResponse{Node: nil, Error: err} + return + } + // Enforce model if configured EnforceModel(&c.Node) @@ -112,6 +178,8 @@ func (c RegisterNode) Execute(b *Broker) { InferencePort: c.Node.InferencePort, PoCSegment: c.Node.PoCSegment, PoCPort: c.Node.PoCPort, + BaseURL: c.Node.BaseURL, + AuthToken: c.Node.AuthToken, Models: models, Id: c.Node.Id, MaxConcurrent: c.Node.MaxConcurrent, @@ -194,6 +262,13 @@ func (u UpdateNode) GetResponseChannelCapacity() int { } func (c UpdateNode) Execute(b *Broker) { + // Validate node configuration (Host+Ports vs baseURL) + if err := validateInferenceNodeConfig(c.Node); err != nil { + logging.Error("UpdateNode. Invalid node configuration", types.Nodes, "error", err, "node_id", c.Node.Id) + c.Response <- NodeCommandResponse{Node: nil, Error: err} + return + } + // Fetch existing node first to check if it exists b.mu.RLock() existing, exists := b.nodes[c.Node.Id] @@ -251,6 +326,8 @@ func (c UpdateNode) Execute(b *Broker) { InferencePort: c.Node.InferencePort, PoCSegment: c.Node.PoCSegment, PoCPort: c.Node.PoCPort, + BaseURL: c.Node.BaseURL, + AuthToken: c.Node.AuthToken, Models: models, Id: c.Node.Id, MaxConcurrent: c.Node.MaxConcurrent, diff --git a/decentralized-api/broker/node_worker.go b/decentralized-api/broker/node_worker.go index b57145fcf..c6b50e6c8 100644 --- a/decentralized-api/broker/node_worker.go +++ b/decentralized-api/broker/node_worker.go @@ -133,8 +133,9 @@ func (w *NodeWorker) CheckClientVersionAlive(version string, factory mlnodeclien node := w.node.Node pocUrl := node.PoCUrlWithVersion(version) inferenceUrl := node.InferenceUrlWithVersion(version) + baseUrl := node.BaseUrlWithVersion(version) - versionClient := factory.CreateClient(pocUrl, inferenceUrl) + versionClient := 
factory.CreateClient(pocUrl, inferenceUrl, node.AuthToken, baseUrl) _, err := versionClient.NodeState(context.Background()) w.versionsMu.Lock() diff --git a/decentralized-api/broker/node_worker_test.go b/decentralized-api/broker/node_worker_test.go index 830dbb2df..ef25a6360 100644 --- a/decentralized-api/broker/node_worker_test.go +++ b/decentralized-api/broker/node_worker_test.go @@ -322,7 +322,7 @@ func TestNodeWorker_CheckClientVersionAlive(t *testing.T) { versionedPocUrl2 := node.Node.PoCUrlWithVersion(version2) // Configure the mock client for this version to return an error - version2Client := mockFactory.CreateClient(versionedPocUrl2, "").(*mlnodeclient.MockClient) + version2Client := mockFactory.CreateClient(versionedPocUrl2, "", "", "").(*mlnodeclient.MockClient) testErr := errors.New("node not ready") version2Client.NodeStateError = testErr diff --git a/decentralized-api/internal/event_listener/integration_test.go b/decentralized-api/internal/event_listener/integration_test.go index a5e37cf2c..ed36bff93 100644 --- a/decentralized-api/internal/event_listener/integration_test.go +++ b/decentralized-api/internal/event_listener/integration_test.go @@ -444,7 +444,7 @@ func (setup *IntegrationTestSetup) getNodeClient(nodeId string, port int) *mlnod client := setup.MockClientFactory.GetClientForNode(pocUrl) if client == nil { // Create the client if it doesn't exist (should have been created by node registration) - setup.MockClientFactory.CreateClient(pocUrl, inferenceUrl) + setup.MockClientFactory.CreateClient(pocUrl, inferenceUrl, "", "") client = setup.MockClientFactory.GetClientForNode(pocUrl) if client == nil { panic(fmt.Sprintf("Mock client is still nil after creation for pocUrl: %s", pocUrl)) diff --git a/decentralized-api/internal/modelmanager/mlnode_background_manager.go b/decentralized-api/internal/modelmanager/mlnode_background_manager.go index 688faed1f..b0c3586da 100644 --- a/decentralized-api/internal/modelmanager/mlnode_background_manager.go +++ 
b/decentralized-api/internal/modelmanager/mlnode_background_manager.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "sort" + "strings" "time" "github.com/productscience/inference/x/inference/types" @@ -127,7 +128,8 @@ func (m *MLNodeBackgroundManager) checkNodeModels(node apiconfig.InferenceNodeCo version := m.configManager.GetCurrentNodeVersion() pocUrl := getPoCUrlWithVersion(node, version) inferenceUrl := getInferenceUrlWithVersion(node, version) - client := m.mlNodeClientFactory.CreateClient(pocUrl, inferenceUrl) + baseUrl := getBaseUrlWithVersion(node, version) + client := m.mlNodeClientFactory.CreateClient(pocUrl, inferenceUrl, node.AuthToken, baseUrl) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -207,18 +209,30 @@ func getInferenceUrlWithVersion(node apiconfig.InferenceNodeConfig, version stri } func getPoCUrl(node apiconfig.InferenceNodeConfig) string { + if node.BaseURL != "" { + return formatBaseURL(node.BaseURL, node.PoCSegment) + } return formatURL(node.Host, node.PoCPort, node.PoCSegment) } func getPoCUrlVersioned(node apiconfig.InferenceNodeConfig, version string) string { + if node.BaseURL != "" { + return formatBaseURLWithVersion(node.BaseURL, version, node.PoCSegment) + } return formatURLWithVersion(node.Host, node.PoCPort, version, node.PoCSegment) } func getInferenceUrl(node apiconfig.InferenceNodeConfig) string { + if node.BaseURL != "" { + return formatBaseURL(node.BaseURL, node.InferenceSegment) + } return formatURL(node.Host, node.InferencePort, node.InferenceSegment) } func getInferenceUrlVersioned(node apiconfig.InferenceNodeConfig, version string) string { + if node.BaseURL != "" { + return formatBaseURLWithVersion(node.BaseURL, version, node.InferenceSegment) + } return formatURLWithVersion(node.Host, node.InferencePort, version, node.InferenceSegment) } @@ -227,7 +241,33 @@ func formatURL(host string, port int, segment string) string { } func formatURLWithVersion(host string, port int, version 
string, segment string) string { - return fmt.Sprintf("http://%s:%d/%s%s", host, port, version, segment) + v := strings.TrimSpace(version) + if v == "" { + return fmt.Sprintf("http://%s:%d%s", host, port, segment) + } + return fmt.Sprintf("http://%s:%d/%s%s", host, port, v, segment) +} + +func formatBaseURL(baseURL string, segment string) string { + // seg := segment + // if seg == "" { + // seg = "/" + // } + base := strings.TrimRight(baseURL, "/") + return fmt.Sprintf("%s%s", base, segment) +} + +func formatBaseURLWithVersion(baseURL string, version string, segment string) string { + base := strings.TrimRight(baseURL, "/") + v := strings.TrimSpace(version) + if v == "" { + return fmt.Sprintf("%s%s", base, segment) + } + return fmt.Sprintf("%s/%s%s", base, v, segment) +} + +func getBaseUrlWithVersion(node apiconfig.InferenceNodeConfig, version string) string { + return broker.BaseUrlWithVersion(node.BaseURL, version) } // checkAndUpdateGPUs fetches GPU info from all nodes and updates hardware @@ -291,7 +331,8 @@ func (m *MLNodeBackgroundManager) fetchNodeGPUHardware(ctx context.Context, node version := m.configManager.GetCurrentNodeVersion() pocUrl := getPoCUrlWithVersion(*node, version) inferenceUrl := getInferenceUrlWithVersion(*node, version) - client := m.mlNodeClientFactory.CreateClient(pocUrl, inferenceUrl) + baseUrl := getBaseUrlWithVersion(*node, version) + client := m.mlNodeClientFactory.CreateClient(pocUrl, inferenceUrl, node.AuthToken, baseUrl) timeoutCtx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() diff --git a/decentralized-api/internal/modelmanager/mlnode_background_manager_test.go b/decentralized-api/internal/modelmanager/mlnode_background_manager_test.go index 3fe26a5d7..6ccbae07d 100644 --- a/decentralized-api/internal/modelmanager/mlnode_background_manager_test.go +++ b/decentralized-api/internal/modelmanager/mlnode_background_manager_test.go @@ -74,7 +74,7 @@ type mockClientFactory struct { client mlnodeclient.MLNodeClient 
} -func (m *mockClientFactory) CreateClient(pocUrl, inferenceUrl string) mlnodeclient.MLNodeClient { +func (m *mockClientFactory) CreateClient(pocUrl, inferenceUrl string, authToken string, baseURL string) mlnodeclient.MLNodeClient { return m.client } @@ -553,6 +553,168 @@ func TestCheckNodeModels(t *testing.T) { // Test URL formatting func TestURLFormatting(t *testing.T) { + // Test cases for BaseURL with version + tests := []struct { + name string + node apiconfig.InferenceNodeConfig + version string + isPoC bool + expected string + }{ + { + name: "PoC with BaseURL and Version", + node: apiconfig.InferenceNodeConfig{ + Host: "localhost", + PoCPort: 8080, + PoCSegment: "/api", + BaseURL: "https://api.example.com", + }, + version: "v2", + isPoC: true, + expected: "https://api.example.com/v2/api", + }, + { + name: "PoC with BaseURL without Version", + node: apiconfig.InferenceNodeConfig{ + Host: "localhost", + PoCPort: 8080, + PoCSegment: "/api", + BaseURL: "https://api.example.com", + }, + version: "", + isPoC: true, + expected: "https://api.example.com/api", + }, + { + name: "PoC without BaseURL with Version", + node: apiconfig.InferenceNodeConfig{ + Host: "localhost", + PoCPort: 8080, + PoCSegment: "/api", + }, + version: "v2", + isPoC: true, + expected: "http://localhost:8080/v2/api", + }, + { + name: "PoC without BaseURL without Version", + node: apiconfig.InferenceNodeConfig{ + Host: "localhost", + PoCPort: 8080, + PoCSegment: "/api", + }, + version: "", + isPoC: true, + expected: "http://localhost:8080/api", + }, + { + name: "Inference with BaseURL and Version", + node: apiconfig.InferenceNodeConfig{ + Host: "localhost", + InferencePort: 8081, + InferenceSegment: "/inference", + BaseURL: "https://api.example.com", + }, + version: "v2", + isPoC: false, + expected: "https://api.example.com/v2/inference", + }, + { + name: "Inference with BaseURL without Version", + node: apiconfig.InferenceNodeConfig{ + Host: "localhost", + InferencePort: 8081, + 
InferenceSegment: "/inference", + BaseURL: "https://api.example.com", + }, + version: "", + isPoC: false, + expected: "https://api.example.com/inference", + }, + { + name: "Inference without BaseURL with Version", + node: apiconfig.InferenceNodeConfig{ + Host: "localhost", + InferencePort: 8081, + InferenceSegment: "/inference", + }, + version: "v2", + isPoC: false, + expected: "http://localhost:8081/v2/inference", + }, + { + name: "Inference without BaseURL without Version", + node: apiconfig.InferenceNodeConfig{ + Host: "localhost", + InferencePort: 8081, + InferenceSegment: "/inference", + }, + version: "", + isPoC: false, + expected: "http://localhost:8081/inference", + }, + { + name: "PoC with BaseURL with trailing slash", + node: apiconfig.InferenceNodeConfig{ + Host: "localhost", + PoCPort: 8080, + PoCSegment: "/api", + BaseURL: "https://api.example.com/", + }, + version: "v2", + isPoC: true, + expected: "https://api.example.com/v2/api", + }, + { + name: "PoC with empty segment", + node: apiconfig.InferenceNodeConfig{ + Host: "localhost", + PoCPort: 8080, + PoCSegment: "", + }, + version: "v2", + isPoC: true, + expected: "http://localhost:8080/v2", + }, + { + name: "Version_with_whitespace", + node: apiconfig.InferenceNodeConfig{ + Host: "localhost", + InferencePort: 8081, + InferenceSegment: "/inference", + }, + version: " v2 ", + isPoC: false, + expected: "http://localhost:8081/v2/inference", + }, + { + name: "Version_with_whitespace_PoC", + node: apiconfig.InferenceNodeConfig{ + Host: "localhost", + PoCPort: 8080, + PoCSegment: "/api", + }, + version: " v2 ", + isPoC: true, + expected: "http://localhost:8080/v2/api", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var url string + if tt.isPoC { + url = getPoCUrlVersioned(tt.node, tt.version) + } else { + url = getInferenceUrlVersioned(tt.node, tt.version) + } + if url != tt.expected { + t.Errorf("expected %s, got %s", tt.expected, url) + } + }) + } + + // Legacy tests with 
simple node node := apiconfig.InferenceNodeConfig{ Host: "localhost", PoCPort: 8080, @@ -610,6 +772,82 @@ func TestURLFormatting(t *testing.T) { }) } +// Test getBaseUrlWithVersion +func TestGetBaseUrlWithVersion(t *testing.T) { + tests := []struct { + name string + node apiconfig.InferenceNodeConfig + version string + expected string + }{ + { + name: "Base URL with version", + node: apiconfig.InferenceNodeConfig{ + BaseURL: "https://api.example.com", + }, + version: "v1", + expected: "https://api.example.com/v1", + }, + { + name: "Base URL without version", + node: apiconfig.InferenceNodeConfig{ + BaseURL: "https://api.example.com", + }, + version: "", + expected: "https://api.example.com", + }, + { + name: "Base URL with trailing slash and version", + node: apiconfig.InferenceNodeConfig{ + BaseURL: "https://api.example.com/", + }, + version: "v1", + expected: "https://api.example.com/v1", + }, + { + name: "Base URL with trailing slash and no version", + node: apiconfig.InferenceNodeConfig{ + BaseURL: "https://api.example.com/", + }, + version: "", + expected: "https://api.example.com", + }, + { + name: "Empty base URL with version", + node: apiconfig.InferenceNodeConfig{ + BaseURL: "", + }, + version: "v1", + expected: "/v1", + }, + { + name: "Empty base URL without version", + node: apiconfig.InferenceNodeConfig{ + BaseURL: "", + }, + version: "", + expected: "", + }, + { + name: "Version with whitespace", + node: apiconfig.InferenceNodeConfig{ + BaseURL: "https://api.example.com", + }, + version: " v1 ", + expected: "https://api.example.com/v1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := getBaseUrlWithVersion(tt.node, tt.version) + if result != tt.expected { + t.Errorf("getBaseUrlWithVersion() = %q; expected %q", result, tt.expected) + } + }) + } +} + // Test GPU transformation func TestTransformGPUDevicesToHardware(t *testing.T) { t.Run("identical GPUs grouped", func(t *testing.T) { diff --git 
a/decentralized-api/internal/server/admin/server_test.go b/decentralized-api/internal/server/admin/server_test.go index 61bbe2d07..d7fa66e64 100644 --- a/decentralized-api/internal/server/admin/server_test.go +++ b/decentralized-api/internal/server/admin/server_test.go @@ -205,7 +205,7 @@ func TestPostVersionStatus(t *testing.T) { // Pre-configure the mock client to return an error pocURL := "http://localhost:8081/v1.2.4/api/v1" - mockClient := mockClientFactory.CreateClient(pocURL, "").(*mlnodeclient.MockClient) + mockClient := mockClientFactory.CreateClient(pocURL, "", "", "").(*mlnodeclient.MockClient) mockClient.NodeStateError = errors.New("connection failed") s.e.ServeHTTP(rec, req) diff --git a/decentralized-api/internal/server/admin/setup_report.go b/decentralized-api/internal/server/admin/setup_report.go index 1d6963808..811e5adff 100644 --- a/decentralized-api/internal/server/admin/setup_report.go +++ b/decentralized-api/internal/server/admin/setup_report.go @@ -784,10 +784,16 @@ func getPoCUrlWithVersion(node apiconfig.InferenceNodeConfig, version string) st } func getPoCUrl(node apiconfig.InferenceNodeConfig) string { + if node.BaseURL != "" { + return formatBaseURL(node.BaseURL, node.PoCSegment) + } return formatURL(node.Host, node.PoCPort, node.PoCSegment) } func getPoCUrlVersioned(node apiconfig.InferenceNodeConfig, version string) string { + if node.BaseURL != "" { + return formatBaseURLWithVersion(node.BaseURL, version, node.PoCSegment) + } return formatURLWithVersion(node.Host, node.PoCPort, version, node.PoCSegment) } @@ -798,3 +804,21 @@ func formatURL(host string, port int, segment string) string { func formatURLWithVersion(host string, port int, version string, segment string) string { return fmt.Sprintf("http://%s:%d/%s%s", host, port, version, segment) } + +func formatBaseURL(baseURL string, segment string) string { + // seg := segment + // if seg == "" { + // seg = "/" + // } + base := strings.TrimRight(baseURL, "/") + return 
fmt.Sprintf("%s%s", base, segment) +} + +func formatBaseURLWithVersion(baseURL string, version string, segment string) string { + base := strings.TrimRight(baseURL, "/") + v := strings.TrimSpace(version) + if v == "" { + return fmt.Sprintf("%s%s", base, segment) + } + return fmt.Sprintf("%s/%s%s", base, v, segment) +} diff --git a/decentralized-api/internal/server/public/post_chat_handler.go b/decentralized-api/internal/server/public/post_chat_handler.go index 638ce91d8..78f7c9489 100644 --- a/decentralized-api/internal/server/public/post_chat_handler.go +++ b/decentralized-api/internal/server/public/post_chat_handler.go @@ -15,6 +15,7 @@ import ( "net/http" "net/url" "strconv" + "strings" "sync" "time" @@ -469,16 +470,8 @@ func (s *Server) getPromptTokenCount(text string, model string) (int, error) { Model: model, Prompt: text, } - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, broker.NewApplicationActionError(err) - } - resp, postErr := s.httpClient.Post( - tokenizeUrl, - "application/json", - bytes.NewReader(jsonData), - ) + resp, postErr := utils.SendPostJsonRequestWithAuth(context.Background(), http.DefaultClient, tokenizeUrl, reqBody, node.AuthToken) if postErr != nil { return nil, broker.NewTransportActionError(postErr) } @@ -556,11 +549,15 @@ func (s *Server) handleExecutorRequest(ctx echo.Context, request *ChatRequest, w if err != nil { return nil, broker.NewApplicationActionError(err) } - resp, postErr := s.httpClient.Post( - completionsUrl, - request.Request.Header.Get("Content-Type"), - bytes.NewReader(modifiedRequestBody.NewBody), - ) + req, err := http.NewRequest(http.MethodPost, completionsUrl, bytes.NewBuffer(modifiedRequestBody.NewBody)) + if err != nil { + return nil, broker.NewApplicationActionError(fmt.Errorf("failed to create request: %w", err)) + } + req.Header.Set("Content-Type", "application/json") + if strings.TrimSpace(node.AuthToken) != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", 
strings.TrimSpace(node.AuthToken))) + } + resp, postErr := s.httpClient.Do(req) if postErr != nil { return nil, broker.NewTransportActionError(postErr) } diff --git a/decentralized-api/internal/server/public/post_chat_handler_test.go b/decentralized-api/internal/server/public/post_chat_handler_test.go index 8f89457ee..d43b5964d 100644 --- a/decentralized-api/internal/server/public/post_chat_handler_test.go +++ b/decentralized-api/internal/server/public/post_chat_handler_test.go @@ -6,8 +6,10 @@ import ( "encoding/json" "io" "net/http" + "strings" "testing" + "decentralized-api/broker" "decentralized-api/chainphase" "decentralized-api/payloadstorage" @@ -227,3 +229,63 @@ func TestMaxRequestBodySizeConstant(t *testing.T) { expectedSize := 10 * 1024 * 1024 require.Equal(t, expectedSize, MaxRequestBodySize, "MaxRequestBodySize should be 10 MB") } + +func TestHTTPRequestWithAuthToken(t *testing.T) { + tests := []struct { + name string + authToken string + expectAuth bool + }{ + { + name: "Empty token", + authToken: "", + expectAuth: false, + }, + { + name: "Valid token", + authToken: "valid-token", + expectAuth: true, + }, + { + name: "Whitespace token", + authToken: " ", + expectAuth: false, + }, + } + + endpoints := []string{"tokenize", "completions"} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup test node + node := &broker.Node{ + AuthToken: tt.authToken, + } + + for _, endpoint := range endpoints { + t.Run(endpoint+"Request", func(t *testing.T) { + req, err := http.NewRequest("POST", "http://example.com/"+endpoint, bytes.NewReader([]byte("{}"))) + if err != nil { + t.Fatal(err) + } + + // This would normally be in the actual handler + req.Header.Set("Content-Type", "application/json") + if strings.TrimSpace(node.AuthToken) != "" { + req.Header.Set("Authorization", "Bearer "+node.AuthToken) + } + + if tt.expectAuth { + if req.Header.Get("Authorization") == "" { + t.Error("Expected Authorization header but got none") + } + } else { + 
if req.Header.Get("Authorization") != "" { + t.Error("Expected no Authorization header but got one") + } + } + }) + } + }) + } +} diff --git a/decentralized-api/internal/validation/http_auth_test.go b/decentralized-api/internal/validation/http_auth_test.go new file mode 100644 index 000000000..af14288d4 --- /dev/null +++ b/decentralized-api/internal/validation/http_auth_test.go @@ -0,0 +1,67 @@ +package validation_test + +import ( + "bytes" + "net/http" + "strings" + "testing" +) + +// MockNode is a simplified version of broker.Node for testing +type MockNode struct { + AuthToken string +} + +func TestHTTPRequestAuthTokenHandling(t *testing.T) { + tests := []struct { + name string + authToken string + expectAuth bool + }{ + { + name: "Empty token", + authToken: "", + expectAuth: false, + }, + { + name: "Valid token", + authToken: "valid-token", + expectAuth: true, + }, + { + name: "Whitespace token", + authToken: " ", + expectAuth: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup test node + node := &MockNode{ + AuthToken: tt.authToken, + } + + req, err := http.NewRequest("POST", "http://example.com/validate", bytes.NewReader([]byte("{}"))) + if err != nil { + t.Fatal(err) + } + + // This would normally be in the actual handler + req.Header.Set("Content-Type", "application/json") + if strings.TrimSpace(node.AuthToken) != "" { + req.Header.Set("Authorization", "Bearer "+node.AuthToken) + } + + if tt.expectAuth { + if req.Header.Get("Authorization") == "" { + t.Error("Expected Authorization header but got none") + } + } else { + if req.Header.Get("Authorization") != "" { + t.Error("Expected no Authorization header but got one") + } + } + }) + } +} diff --git a/decentralized-api/internal/validation/inference_validation.go b/decentralized-api/internal/validation/inference_validation.go index fc949d751..af08b1d9a 100644 --- a/decentralized-api/internal/validation/inference_validation.go +++ 
b/decentralized-api/internal/validation/inference_validation.go @@ -1,15 +1,15 @@ package validation import ( - "bytes" "context" "decentralized-api/apiconfig" "decentralized-api/broker" "decentralized-api/chainphase" "decentralized-api/completionapi" "decentralized-api/cosmosclient" - "decentralized-api/internal/utils" + internalutils "decentralized-api/internal/utils" "decentralized-api/logging" + "decentralized-api/utils" "encoding/json" "errors" "fmt" @@ -884,22 +884,13 @@ func (s *InferenceValidator) validateWithPayloads(inference types.Inference, inf requestMap["skip_special_tokens"] = false delete(requestMap, "stream_options") - requestBody, err := json.Marshal(requestMap) - if err != nil { - return nil, err - } - completionsUrl, err := url.JoinPath(inferenceNode.InferenceUrlWithVersion(s.configManager.GetCurrentNodeVersion()), "v1/chat/completions") if err != nil { logging.Error("Failed to join url", types.Validation, "url", inferenceNode.InferenceUrlWithVersion(s.configManager.GetCurrentNodeVersion()), "error", err) return nil, err } - resp, err := http.Post( - completionsUrl, - "application/json", - bytes.NewReader(requestBody), - ) + resp, err := utils.SendPostJsonRequestWithAuth(context.Background(), http.DefaultClient, completionsUrl, requestMap, inferenceNode.AuthToken) if err != nil { return nil, err } @@ -1182,7 +1173,7 @@ func ToMsgValidation(result ValidationResult) (*inference.MsgValidation, error) return nil, errors.New("unknown validation result type") } - responseHash, _, err := utils.GetResponseHash(result.GetValidationResponseBytes()) + responseHash, _, err := internalutils.GetResponseHash(result.GetValidationResponseBytes()) if err != nil { logging.Error("Failed to get response hash", types.Validation, "error", err) return nil, err diff --git a/decentralized-api/mlnodeclient/client.go b/decentralized-api/mlnodeclient/client.go index 3aaa90b1a..f91307e74 100644 --- a/decentralized-api/mlnodeclient/client.go +++ 
b/decentralized-api/mlnodeclient/client.go @@ -27,18 +27,22 @@ const ( type Client struct { pocUrl string inferenceUrl string + baseURL string client http.Client mlGrpcCallbackAddress string + authToken string } -func NewNodeClient(pocUrl string, inferenceUrl string) *Client { +func NewNodeClient(pocUrl string, inferenceUrl string, authToken string, baseURL string) *Client { return &Client{ pocUrl: pocUrl, inferenceUrl: inferenceUrl, + baseURL: baseURL, client: http.Client{ Timeout: 15 * time.Minute, }, mlGrpcCallbackAddress: "api-private:9300", // TODO: PRTODO: make this configurable + authToken: authToken, } } @@ -167,7 +171,7 @@ func (api *Client) StartTraining(ctx context.Context, taskId uint64, participant } logging.Info("Starting training with", types.Training, "trainEnv", trainEnv) - _, err = utils.SendPostJsonRequest(ctx, &api.client, requestUrl, body) + _, err = utils.SendPostJsonRequestWithAuth(ctx, &api.client, requestUrl, body, api.authToken) if err != nil { return err } @@ -181,7 +185,7 @@ func (api *Client) GetTrainingStatus(ctx context.Context) error { return err } - _, err = utils.SendGetRequest(ctx, &api.client, requestUrl) + _, err = utils.SendGetRequestWithAuth(ctx, &api.client, requestUrl, api.authToken) if err != nil { return err } @@ -195,7 +199,7 @@ func (api *Client) Stop(ctx context.Context) error { return err } - _, err = utils.SendPostJsonRequest(ctx, &api.client, requestUrl, nil) + _, err = utils.SendPostJsonRequestWithAuth(ctx, &api.client, requestUrl, nil, api.authToken) if err != nil { return err } @@ -222,7 +226,7 @@ func (api *Client) NodeState(ctx context.Context) (*StateResponse, error) { return nil, err } - resp, err := utils.SendGetRequest(ctx, &api.client, requestURL) + resp, err := utils.SendGetRequestWithAuth(ctx, &api.client, requestURL, api.authToken) if err != nil { return nil, err } @@ -263,7 +267,7 @@ func (api *Client) GetPowStatus(ctx context.Context) (*PowStatusResponse, error) return nil, err } - resp, err := 
utils.SendGetRequest(ctx, &api.client, requestURL) + resp, err := utils.SendGetRequestWithAuth(ctx, &api.client, requestURL, api.authToken) if err != nil { return nil, err } @@ -282,12 +286,23 @@ func (api *Client) GetPowStatus(ctx context.Context) (*PowStatusResponse, error) } func (api *Client) InferenceHealth(ctx context.Context) (bool, error) { - requestURL, err := url.JoinPath(api.inferenceUrl, "/health") + var requestURL string + var err error + + // Determine health endpoint based on registration method + if api.baseURL != "" { + // New registration (baseURL): Check /readyz + requestURL, err = url.JoinPath(api.baseURL, "/readyz") + } else { + // Legacy registration (Host/Port/Segment): Check http://:/health + requestURL, err = url.JoinPath(api.inferenceUrl, "/health") + } + if err != nil { return false, err } - resp, err := utils.SendGetRequest(ctx, &api.client, requestURL) + resp, err := utils.SendGetRequestWithAuth(ctx, &api.client, requestURL, api.authToken) if err != nil { return false, err } @@ -320,7 +335,7 @@ func (api *Client) InferenceUp(ctx context.Context, model string, args []string) logging.Info("Sending inference/up request to node", types.PoC, "inferenceUpUrl", inferenceUpUrl, "body", dto) - _, err = utils.SendPostJsonRequest(ctx, &api.client, inferenceUpUrl, dto) + _, err = utils.SendPostJsonRequestWithAuth(ctx, &api.client, inferenceUpUrl, dto, api.authToken) if err != nil { logging.Error("Failed to send inference/up request", types.PoC, "error", err, "inferenceUpUrl", inferenceUpUrl, "inferenceUpDto", dto) } diff --git a/decentralized-api/mlnodeclient/client_factory.go b/decentralized-api/mlnodeclient/client_factory.go index 4ff228b5b..08ca7eeb3 100644 --- a/decentralized-api/mlnodeclient/client_factory.go +++ b/decentralized-api/mlnodeclient/client_factory.go @@ -3,13 +3,13 @@ package mlnodeclient import "sync" type ClientFactory interface { - CreateClient(pocUrl string, inferenceUrl string) MLNodeClient + CreateClient(pocUrl string, 
inferenceUrl string, authToken string, baseURL string) MLNodeClient } type HttpClientFactory struct{} -func (f *HttpClientFactory) CreateClient(pocUrl string, inferenceUrl string) MLNodeClient { - return NewNodeClient(pocUrl, inferenceUrl) +func (f *HttpClientFactory) CreateClient(pocUrl string, inferenceUrl string, authToken string, baseURL string) MLNodeClient { + return NewNodeClient(pocUrl, inferenceUrl, authToken, baseURL) } type MockClientFactory struct { @@ -23,7 +23,7 @@ func NewMockClientFactory() *MockClientFactory { } } -func (f *MockClientFactory) CreateClient(pocUrl string, inferenceUrl string) MLNodeClient { +func (f *MockClientFactory) CreateClient(pocUrl string, inferenceUrl string, authToken string, baseURL string) MLNodeClient { f.mu.Lock() defer f.mu.Unlock() diff --git a/decentralized-api/mlnodeclient/gpu.go b/decentralized-api/mlnodeclient/gpu.go index 5e4093aae..9cd0da4f2 100644 --- a/decentralized-api/mlnodeclient/gpu.go +++ b/decentralized-api/mlnodeclient/gpu.go @@ -2,12 +2,11 @@ package mlnodeclient import ( "context" + "decentralized-api/utils" "encoding/json" "fmt" "net/http" "net/url" - - "decentralized-api/utils" ) const ( @@ -24,7 +23,7 @@ func (api *Client) GetGPUDevices(ctx context.Context) (*GPUDevicesResponse, erro return nil, err } - resp, err := utils.SendGetRequest(ctx, &api.client, requestURL) + resp, err := utils.SendGetRequestWithAuth(ctx, &api.client, requestURL, api.authToken) if err != nil { return nil, err } @@ -55,7 +54,7 @@ func (api *Client) GetGPUDriver(ctx context.Context) (*DriverInfo, error) { return nil, err } - resp, err := utils.SendGetRequest(ctx, &api.client, requestURL) + resp, err := utils.SendGetRequestWithAuth(ctx, &api.client, requestURL, api.authToken) if err != nil { return nil, err } diff --git a/decentralized-api/mlnodeclient/gpu_test.go b/decentralized-api/mlnodeclient/gpu_test.go index 7b0247f66..1cb9a9a7f 100644 --- a/decentralized-api/mlnodeclient/gpu_test.go +++ 
b/decentralized-api/mlnodeclient/gpu_test.go @@ -36,7 +36,7 @@ func TestClient_GetGPUDevices(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") resp, err := client.GetGPUDevices(ctx) if err != nil { @@ -63,7 +63,7 @@ func TestClient_GetGPUDevices(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") resp, err := client.GetGPUDevices(ctx) if err != nil { @@ -80,7 +80,7 @@ func TestClient_GetGPUDevices(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") _, err := client.GetGPUDevices(ctx) if err == nil { @@ -104,7 +104,7 @@ func TestClient_GetGPUDevices(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") _, err := client.GetGPUDevices(ctx) if err == nil { @@ -121,7 +121,7 @@ func TestClient_GetGPUDevices(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") _, err := client.GetGPUDevices(ctx) if err == nil { @@ -156,7 +156,7 @@ func TestClient_GetGPUDriver(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") resp, err := client.GetGPUDriver(ctx) if err != nil { @@ -176,7 +176,7 @@ func TestClient_GetGPUDriver(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") _, err := client.GetGPUDriver(ctx) if err == nil { @@ -193,7 +193,7 @@ func TestClient_GetGPUDriver(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") _, err := client.GetGPUDriver(ctx) if err == nil { diff --git a/decentralized-api/mlnodeclient/models.go 
b/decentralized-api/mlnodeclient/models.go index f5b53fda4..95235145e 100644 --- a/decentralized-api/mlnodeclient/models.go +++ b/decentralized-api/mlnodeclient/models.go @@ -2,12 +2,11 @@ package mlnodeclient import ( "context" + "decentralized-api/utils" "encoding/json" "fmt" "net/http" "net/url" - - "decentralized-api/utils" ) const ( @@ -27,7 +26,7 @@ func (api *Client) CheckModelStatus(ctx context.Context, model Model) (*ModelSta return nil, err } - resp, err := utils.SendPostJsonRequest(ctx, &api.client, requestURL, model) + resp, err := utils.SendPostJsonRequestWithAuth(ctx, &api.client, requestURL, model, api.authToken) if err != nil { return nil, err } @@ -60,7 +59,7 @@ func (api *Client) DownloadModel(ctx context.Context, model Model) (*DownloadSta return nil, err } - resp, err := utils.SendPostJsonRequest(ctx, &api.client, requestURL, model) + resp, err := utils.SendPostJsonRequestWithAuth(ctx, &api.client, requestURL, model, api.authToken) if err != nil { return nil, err } @@ -101,7 +100,7 @@ func (api *Client) DeleteModel(ctx context.Context, model Model) (*DeleteRespons return nil, err } - resp, err := utils.SendDeleteJsonRequest(ctx, &api.client, requestURL, model) + resp, err := utils.SendDeleteJsonRequestWithAuth(ctx, &api.client, requestURL, model, api.authToken) if err != nil { return nil, err } @@ -132,7 +131,7 @@ func (api *Client) ListModels(ctx context.Context) (*ModelListResponse, error) { return nil, err } - resp, err := utils.SendGetRequest(ctx, &api.client, requestURL) + resp, err := utils.SendGetRequestWithAuth(ctx, &api.client, requestURL, api.authToken) if err != nil { return nil, err } @@ -163,7 +162,7 @@ func (api *Client) GetDiskSpace(ctx context.Context) (*DiskSpaceInfo, error) { return nil, err } - resp, err := utils.SendGetRequest(ctx, &api.client, requestURL) + resp, err := utils.SendGetRequestWithAuth(ctx, &api.client, requestURL, api.authToken) if err != nil { return nil, err } diff --git 
a/decentralized-api/mlnodeclient/models_test.go b/decentralized-api/mlnodeclient/models_test.go index 2a0f719e9..317df55a1 100644 --- a/decentralized-api/mlnodeclient/models_test.go +++ b/decentralized-api/mlnodeclient/models_test.go @@ -44,7 +44,7 @@ func TestClient_CheckModelStatus(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") resp, err := client.CheckModelStatus(ctx, model) if err != nil { @@ -76,7 +76,7 @@ func TestClient_CheckModelStatus(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") resp, err := client.CheckModelStatus(ctx, model) if err != nil { @@ -110,7 +110,7 @@ func TestClient_CheckModelStatus(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") resp, err := client.CheckModelStatus(ctx, model) if err != nil { @@ -127,7 +127,7 @@ func TestClient_CheckModelStatus(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") model := Model{HfRepo: "test/model"} _, err := client.CheckModelStatus(ctx, model) @@ -168,7 +168,7 @@ func TestClient_DownloadModel(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") resp, err := client.DownloadModel(ctx, model) if err != nil { @@ -188,7 +188,7 @@ func TestClient_DownloadModel(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") model := Model{HfRepo: "test/model"} _, err := client.DownloadModel(ctx, model) @@ -206,7 +206,7 @@ func TestClient_DownloadModel(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") model := Model{HfRepo: "test/model"} _, err := 
client.DownloadModel(ctx, model) @@ -224,7 +224,7 @@ func TestClient_DownloadModel(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") model := Model{HfRepo: "test/model"} _, err := client.DownloadModel(ctx, model) @@ -272,7 +272,7 @@ func TestClient_DeleteModel(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") resp, err := client.DeleteModel(ctx, model) if err != nil { @@ -300,7 +300,7 @@ func TestClient_DeleteModel(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") resp, err := client.DeleteModel(ctx, model) if err != nil { @@ -317,7 +317,7 @@ func TestClient_DeleteModel(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") model := Model{HfRepo: "test/model"} _, err := client.DeleteModel(ctx, model) @@ -367,7 +367,7 @@ func TestClient_ListModels(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") resp, err := client.ListModels(ctx) if err != nil { @@ -392,7 +392,7 @@ func TestClient_ListModels(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") resp, err := client.ListModels(ctx) if err != nil { @@ -409,7 +409,7 @@ func TestClient_ListModels(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") _, err := client.ListModels(ctx) if err == nil { @@ -444,7 +444,7 @@ func TestClient_GetDiskSpace(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") resp, err := client.GetDiskSpace(ctx) if err != nil { @@ -467,7 +467,7 @@ func 
TestClient_GetDiskSpace(t *testing.T) { })) defer server.Close() - client := NewNodeClient(server.URL, "") + client := NewNodeClient(server.URL, "", "", "") _, err := client.GetDiskSpace(ctx) if err == nil { diff --git a/decentralized-api/utils/http.go b/decentralized-api/utils/http.go index fdf1218b2..fc982f4ef 100644 --- a/decentralized-api/utils/http.go +++ b/decentralized-api/utils/http.go @@ -5,7 +5,9 @@ import ( "context" "decentralized-api/logging" "encoding/json" + "fmt" "net/http" + "strings" "time" "github.com/productscience/inference/x/inference/types" @@ -18,6 +20,11 @@ func NewHttpClient(timeout time.Duration) *http.Client { } func SendPostJsonRequest(ctx context.Context, client *http.Client, url string, payload any) (*http.Response, error) { + return SendPostJsonRequestWithAuth(ctx, client, url, payload, "") +} + +// SendPostJsonRequestWithAuth sends a POST request with JSON payload and adds Authorization header if authToken is set +func SendPostJsonRequestWithAuth(ctx context.Context, client *http.Client, url string, payload any, authToken string) (*http.Response, error) { var req *http.Request var err error @@ -26,34 +33,57 @@ func SendPostJsonRequest(ctx context.Context, client *http.Client, url string, p req, err = http.NewRequestWithContext(ctx, http.MethodPost, url, nil) } else { // Marshal the payload to JSON. - jsonData, err := json.Marshal(payload) + var jsonData []byte + jsonData, err = json.Marshal(payload) if err != nil { return nil, err } req, err = http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") } if err != nil { return nil, err } if req == nil { - logging.Error("SendPostJsonRequest. Failed to create HTTP request", types.Server, "url", url, "payload", payload) + logging.Error("SendPostJsonRequestWithAuth. 
Failed to create HTTP request", types.Server, "url", url, "payload", payload) return nil, err } + if strings.TrimSpace(authToken) != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", strings.TrimSpace(authToken))) + } + return client.Do(req) } func SendGetRequest(ctx context.Context, client *http.Client, url string) (*http.Response, error) { + return SendGetRequestWithAuth(ctx, client, url, "") +} + +// SendGetRequestWithAuth sends a GET request and adds Authorization header if authToken is set +func SendGetRequestWithAuth(ctx context.Context, client *http.Client, url string, authToken string) (*http.Response, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } + if strings.TrimSpace(authToken) != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", strings.TrimSpace(authToken))) + } + return client.Do(req) } func SendDeleteJsonRequest(ctx context.Context, client *http.Client, url string, payload any) (*http.Response, error) { + return SendDeleteJsonRequestWithAuth(ctx, client, url, payload, "") +} + +// SendDeleteJsonRequestWithAuth sends a DELETE request with JSON payload and adds Authorization header if authToken is set +func SendDeleteJsonRequestWithAuth(ctx context.Context, client *http.Client, url string, payload any, authToken string) (*http.Response, error) { var req *http.Request var err error @@ -62,7 +92,8 @@ func SendDeleteJsonRequest(ctx context.Context, client *http.Client, url string, req, err = http.NewRequestWithContext(ctx, http.MethodDelete, url, nil) } else { // Marshal the payload to JSON. - jsonData, err := json.Marshal(payload) + var jsonData []byte + jsonData, err = json.Marshal(payload) if err != nil { return nil, err } @@ -77,9 +108,13 @@ func SendDeleteJsonRequest(ctx context.Context, client *http.Client, url string, return nil, err } if req == nil { - logging.Error("SendDeleteJsonRequest. 
Failed to create HTTP request", types.Server, "url", url, "payload", payload) + logging.Error("SendDeleteJsonRequestWithAuth. Failed to create HTTP request", types.Server, "url", url, "payload", payload) return nil, err } + if strings.TrimSpace(authToken) != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", strings.TrimSpace(authToken))) + } + return client.Do(req) } diff --git a/decentralized-api/utils/http_test.go b/decentralized-api/utils/http_test.go new file mode 100644 index 000000000..260920c1c --- /dev/null +++ b/decentralized-api/utils/http_test.go @@ -0,0 +1,71 @@ +package utils_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "decentralized-api/utils" +) + +func TestSendPostJsonRequestWithAuth(t *testing.T) { + // Test cases + tests := []struct { + name string + payload interface{} + authToken string + expectedAuth string + wantErr bool + }{ + { + name: "valid request with auth token", + payload: map[string]string{"key": "value"}, + authToken: "test-token", + expectedAuth: "Bearer test-token", + wantErr: false, + }, + { + name: "empty auth token should not send header", + payload: map[string]string{"key": "value"}, + authToken: "", + expectedAuth: "", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Setup test server for each test case + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify content type + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("unexpected content type: %v", ct) + } + + // Verify auth header matches expected + auth := r.Header.Get("Authorization") + if auth != tt.expectedAuth { + t.Errorf("unexpected auth header: got %q, want %q", auth, tt.expectedAuth) + } + + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"status":"ok"}`)) + })) + defer ts.Close() + + _, err := utils.SendPostJsonRequestWithAuth( + context.Background(), + ts.Client(), + ts.URL, + tt.payload, 
+ tt.authToken, + ) + + if (err != nil) != tt.wantErr { + t.Errorf("SendPostJsonRequestWithAuth() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} diff --git a/proposals/mlnode-token-auth/README.md b/proposals/mlnode-token-auth/README.md new file mode 100644 index 000000000..cefe835a4 --- /dev/null +++ b/proposals/mlnode-token-auth/README.md @@ -0,0 +1,129 @@ +# Proposal: MLNode Token-Based Authentication and FQDN Support + +## Goal / Problem + +Current MLNode registration in the API service requires three static parameters: +- Static IP address +- PoC port (management API, default 8080) +- Inference port (vLLM inference API, default 5000) + +Both ports serve the same MLNode container through an nginx proxy for version management (see `deploy/join/docker-compose.mlnode.yml` and `deploy/join/nginx.conf`). + +Problems: +- Some cloud providers (e.g., Aliyun EAS) assign new IPs on container recreation, making static IP registration impractical +- Managing two ports adds operational complexity +- Cloud providers often assign stable FQDNs with token-based authentication (e.g., `http://<service-id>.ap-southeast-1.pai-eas.aliyuncs.com/api/predict/eas02/`) that remain consistent across deployments +- Current registration doesn't support using authentication tokens that managed services provide for access control + +**Note:** Segment fields (InferenceSegment, PoCSegment) are legacy parameters that are always empty in current deployments. + +## Proposal + +Support an additional registration method using full URLs: + +1. Use a single port (8080) since both endpoints proxy to the same container +2. 
Allow registration using stable baseURLs with authentication tokens instead of IP/ports + +The system will support both registration methods, allowing users to choose between IP/port configuration or baseURL-based registration. + +## Implementation + +### Single-Port Operation + +Current state (see `deploy/join/nginx.conf` for nginx setup and `mlnode/packages/api/src/api/proxy.py` for internal routing logic): + +Management API (port 8080) supports: +- `http://<host>:<poc_port>/api/v1/*` - management API endpoints +- `http://<host>:<poc_port>/v1/*` - proxies to vLLM endpoints +- `http://<host>:<poc_port>/readyz` - ready for inference +- `http://<host>:<poc_port>/health` - whole service health, *not only inference* + +Inference API (port 5000, backward compatible) supports: +- `http://<host>:<inference_port>/v1/*` - proxies to vLLM endpoints +- `http://<host>:<inference_port>/health` - vLLM health check (proxied from vLLM backend) + +`<poc_port>/health` checks whole service health while `<inference_port>/health` checks only vLLM backend health. The API node currently checks MLNode health via `client.InferenceHealth()` at `http://<host>:<inference_port>/health`. The new API binary must support both old MLNodes (port 5000) and the new single-port configuration. + +### Solution +Use the registration method to determine which health endpoint to check: +- Legacy registration (Host/Port/Segment): Check `http://<host>:<inference_port>/health` +- New registration (baseURL): Check `<baseURL>/readyz` (management API readiness endpoint on port 8080) + +### FQDN and Token Authentication + +Current structure (`decentralized-api/apiconfig/config.go`): +```go +type InferenceNodeConfig struct { + Host string + InferenceSegment string + InferencePort int + PoCSegment string + PoCPort int + // ...
other fields +} +``` + +Proposed structure for `InferenceNodeConfig` and `broker.Node`: +```go +type InferenceNodeConfig struct { + // Existing fields (preserved for backward compatibility) + Host string + InferenceSegment string // Legacy, always empty + InferencePort int + PoCSegment string // Legacy, always empty + PoCPort int + + // New optional fields (SQLite only, not stored on-chain) + BaseURL string // Optional: full URL to MLNode (e.g., "http://service.provider.com/path/") + AuthToken string // Optional: bearer token for authentication + + // ... other fields +} + +type Node struct { + // Existing fields + Host string + InferenceSegment string + InferencePort int + PoCSegment string + PoCPort int + + // New optional fields + BaseURL string + AuthToken string + + // ... other fields +} +``` + +URL construction uses baseURL when present, otherwise falls back to `http://<host>:<port>/`. Version insertion for rolling upgrades works identically for both approaches: `<baseURL>/<version>/` or `http://<host>:<port>/<version>/`. + +baseURL and AuthToken are stored in the local SQLite database only, not on-chain. This allows each API node to configure its own MLNode access methods independently. + +Required changes: + +1. Add `base_url` and `auth_token` columns to SQLite schema with empty defaults for automatic migration +2. Update URL construction methods: + - `broker.Node.InferenceUrlWithVersion()` and `broker.Node.PoCUrlWithVersion()` in `broker/broker.go` (main methods used for all inference and management calls including `/v1/chat/completions`) + - Helper functions in `mlnode_background_manager.go` + - Helper functions in `setup_report.go` for consistency +3. Add `Authorization: Bearer <token>` header to all MLNode requests when AuthToken is set +4. Validate registration: require either (Host+Ports) OR baseURL, not both. baseURL must be a valid HTTP(S) URL. AuthToken is always optional. + + +### Testing + +1. 
Covered by unit tests and they pass: +- local `make local-build` +- [CICD](https://github.com/gonka-ai/gonka/actions/workflows/verify.yml) +2. Existing testermint tests pass: +- local `make build-docker && ./local-test-net/stop.sh && make run-tests` +- [Recommended] [CICD](https://github.com/gonka-ai/gonka/actions/workflows/integration.yml) +3. New node joins testnet and works with new MLNode registration + +### Backward Compatibility + +- Existing nodes using Host/Port configuration are unaffected +- baseURL and AuthToken are local SQLite configuration, not stored on-chain +- No migration needed - old and new registration methods coexist + diff --git a/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/Application.kt b/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/Application.kt index 6b7919224..dbc29deaf 100644 --- a/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/Application.kt +++ b/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/Application.kt @@ -10,10 +10,12 @@ import com.productscience.mockserver.routes.stateRoutes import com.productscience.mockserver.routes.stopRoutes import com.productscience.mockserver.routes.tokenizationRoutes import com.productscience.mockserver.routes.trainRoutes +import com.productscience.mockserver.routes.authRoutes import com.productscience.mockserver.service.ResponseService import com.productscience.mockserver.service.HostHeaderService import com.productscience.mockserver.service.TokenizationService import com.productscience.mockserver.service.WebhookService +import com.productscience.mockserver.service.AuthTokenService import io.ktor.serialization.jackson.jackson import io.ktor.server.engine.embeddedServer import io.ktor.server.netty.Netty @@ -35,6 +37,7 @@ val WebhookServiceKey = AttributeKey("WebhookService") val ResponseServiceKey = AttributeKey("ResponseService") val TokenizationServiceKey = AttributeKey("TokenizationService") val 
HostHeaderServiceKey = AttributeKey("HostHeaderService") +val AuthTokenServiceKey = AttributeKey("AuthTokenService") fun main() { embeddedServer(Netty, port = 8080, host = "0.0.0.0", module = Application::module) @@ -68,12 +71,14 @@ fun Application.configureServices() { val webhookService = WebhookService(responseService) val tokenizationService = TokenizationService() val hostHeaderService = HostHeaderService() + val authTokenService = AuthTokenService() // Register the services in the application's attributes attributes.put(WebhookServiceKey, webhookService) attributes.put(ResponseServiceKey, responseService) attributes.put(TokenizationServiceKey, tokenizationService) attributes.put(HostHeaderServiceKey, hostHeaderService) + attributes.put(AuthTokenServiceKey, authTokenService) } val HostHeaderRecorder = createApplicationPlugin(name = "HostHeaderRecorder") { @@ -91,6 +96,7 @@ fun Application.configureRouting() { val webhookService = attributes[WebhookServiceKey] val responseService = attributes[ResponseServiceKey] val tokenizationService = attributes[TokenizationServiceKey] + val authTokenService = attributes[AuthTokenServiceKey] routing { // Server status endpoint @@ -108,13 +114,14 @@ fun Application.configureRouting() { stateRoutes() powRoutes(webhookService) powV2Routes(webhookService) // PoC v2 (artifact-based) routes - inferenceRoutes(responseService) + inferenceRoutes(responseService, authTokenService) trainRoutes() stopRoutes() healthRoutes() responseRoutes(responseService) tokenizationRoutes(tokenizationService) fileRoutes() // Route for serving files + authRoutes(authTokenService) // Route for auth token management } } diff --git a/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/routes/AuthRoutes.kt b/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/routes/AuthRoutes.kt new file mode 100644 index 000000000..370222ec1 --- /dev/null +++ 
b/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/routes/AuthRoutes.kt @@ -0,0 +1,41 @@ +package com.productscience.mockserver.routes + +import com.fasterxml.jackson.annotation.JsonProperty +import com.productscience.mockserver.service.AuthTokenService +import com.productscience.mockserver.service.HostName +import io.ktor.http.* +import io.ktor.server.application.* +import io.ktor.server.request.* +import io.ktor.server.response.* +import io.ktor.server.routing.* + +data class SetAuthHeaderRequest( + @JsonProperty("expected_header") + val expectedHeader: String, + @JsonProperty("host_name") + val hostName: String? = null +) + +fun Route.authRoutes(authTokenService: AuthTokenService) { + post("/api/v1/auth/token") { + try { + val req = call.receive() + val host = req.hostName?.let { HostName(it) } + authTokenService.setExpectedHeader(host, req.expectedHeader) + call.respond(HttpStatusCode.OK, mapOf("status" to "success")) + } catch (e: Exception) { + call.respond(HttpStatusCode.BadRequest, mapOf("status" to "error", "message" to e.message)) + } + } + + post("/api/v1/auth/clear") { + try { + val req = call.receive() + val host = req.hostName?.let { HostName(it) } + authTokenService.clearExpectedHeader(host) + call.respond(HttpStatusCode.OK, mapOf("status" to "success")) + } catch (e: Exception) { + call.respond(HttpStatusCode.BadRequest, mapOf("status" to "error", "message" to e.message)) + } + } +} \ No newline at end of file diff --git a/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/routes/InferenceRoutes.kt b/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/routes/InferenceRoutes.kt index c098e7fbb..0b1ba24b6 100644 --- a/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/routes/InferenceRoutes.kt +++ b/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/routes/InferenceRoutes.kt @@ -14,13 +14,15 @@ import com.productscience.mockserver.model.setModelState import 
com.productscience.mockserver.service.ResponseService import com.productscience.mockserver.service.SSEService import com.productscience.mockserver.service.HostName +import com.productscience.mockserver.service.AuthTokenService +import io.ktor.http.HttpHeaders import kotlinx.coroutines.delay import org.slf4j.LoggerFactory /** * Configures routes for inference-related endpoints. */ -fun Route.inferenceRoutes(responseService: ResponseService, sseService: SSEService = SSEService()) { +fun Route.inferenceRoutes(responseService: ResponseService, authTokenService: AuthTokenService, sseService: SSEService = SSEService()) { // POST /api/v1/inference/up - Transitions to INFERENCE state post("/api/v1/inference/up") { val logger = LoggerFactory.getLogger("InferenceRoutes") @@ -63,18 +65,18 @@ fun Route.inferenceRoutes(responseService: ResponseService, sseService: SSEServi // Handle the exact path /v1/chat/completions post("/v1/chat/completions") { - handleChatCompletions(call, responseService, sseService) + handleChatCompletions(call, responseService, authTokenService, sseService) } // Handle all versioned chat completions endpoints post("/{...segments}/v1/chat/completions") { - handleChatCompletions(call, responseService, sseService) + handleChatCompletions(call, responseService, authTokenService, sseService) } } /** * Handles chat completions requests. 
*/ -private suspend fun handleChatCompletions(call: ApplicationCall, responseService: ResponseService, sseService: SSEService) { +private suspend fun handleChatCompletions(call: ApplicationCall, responseService: ResponseService, authTokenService: AuthTokenService, sseService: SSEService) { val logger = LoggerFactory.getLogger("InferenceRoutes") val objectMapper = ObjectMapper() .registerKotlinModule() @@ -86,6 +88,29 @@ private suspend fun handleChatCompletions(call: ApplicationCall, responseService // return // } + // Enforce Authorization header if configured for this host + val hostName = HostName(call.getHost()) + val expectedAuth = authTokenService.getExpectedHeader(hostName) + val authHeader = call.request.headers[HttpHeaders.Authorization] + if (expectedAuth != null) { + val requireOnly = expectedAuth == "*" + val valid = if (requireOnly) { + authHeader != null && authHeader.isNotBlank() + } else { + authHeader == expectedAuth + } + if (!valid) { + logger.warn("Unauthorized request for host {}. 
Authorization header mismatch (expected={}, got={})", hostName.name, expectedAuth != null, authHeader != null) + call.response.header("Content-Type", "application/json") + call.respondText( + """{"error":"unauthorized","message":"Authorization header missing or invalid"}""", + ContentType.Application.Json, + HttpStatusCode.Unauthorized + ) + return + } + } + // Get the request body val requestBody = call.receiveText() logger.info("Received chat completion request for path: ${call.request.path()}") @@ -112,7 +137,6 @@ private suspend fun handleChatCompletions(call: ApplicationCall, responseService val path = call.request.path() // Get the response configuration from the ResponseService (per-host) - val hostName = HostName(call.getHost()) val responseConfig = responseService.getInferenceResponseConfig(path, model, hostName) logger.info("Retrieved response config for path $path: ${responseConfig != null}") diff --git a/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/service/AuthTokenService.kt b/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/service/AuthTokenService.kt new file mode 100644 index 000000000..027838a4e --- /dev/null +++ b/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/service/AuthTokenService.kt @@ -0,0 +1,25 @@ +package com.productscience.mockserver.service + +import org.slf4j.LoggerFactory +import java.util.concurrent.ConcurrentHashMap +import com.productscience.mockserver.service.HostName + +class AuthTokenService { + private val logger = LoggerFactory.getLogger(AuthTokenService::class.java) + + private val expectedHeaders = ConcurrentHashMap() + + fun setExpectedHeader(host: HostName?, header: String) { + val key = host ?: HostName("localhost") + expectedHeaders[key] = header + logger.info("Auth expected header set for host {}", key.name) + } + + fun clearExpectedHeader(host: HostName?) 
{ + val key = host ?: HostName("localhost") + expectedHeaders.remove(key) + logger.info("Auth expected header cleared for host {}", key.name) + } + + fun getExpectedHeader(host: HostName): String? = expectedHeaders[host] +} \ No newline at end of file diff --git a/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/service/HostName.kt b/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/service/HostName.kt new file mode 100644 index 000000000..6ffdc18ce --- /dev/null +++ b/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/service/HostName.kt @@ -0,0 +1,4 @@ +package com.productscience.mockserver.service + +@JvmInline +value class HostName(val name: String) \ No newline at end of file diff --git a/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/service/ResponseService.kt b/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/service/ResponseService.kt index c691653e7..ea5c4f2ed 100644 --- a/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/service/ResponseService.kt +++ b/testermint/mock_server/src/main/kotlin/com/productscience/mockserver/service/ResponseService.kt @@ -33,9 +33,6 @@ sealed class ResponseConfig { @JvmInline value class ScenarioName(val name: String) -@JvmInline -value class HostName(val name: String) - @JvmInline value class Endpoint(val path: String) diff --git a/testermint/src/main/kotlin/MockServerInferenceMock.kt b/testermint/src/main/kotlin/MockServerInferenceMock.kt index 92141e701..0ed8096c8 100644 --- a/testermint/src/main/kotlin/MockServerInferenceMock.kt +++ b/testermint/src/main/kotlin/MockServerInferenceMock.kt @@ -102,6 +102,27 @@ class MockServerInferenceMock(private val baseUrl: String, val name: String) : I } } + fun setExpectedAuthorizationHeader(expectedHeader: String, hostName: String? = null) { + data class SetAuthHeaderRequest( + val expected_header: String, + val host_name: String? 
= null + ) + + val request = SetAuthHeaderRequest(expectedHeader, hostName) + try { + val (_, response, _) = Fuel.post("$baseUrl/api/v1/auth/token") + .jsonBody(cosmosJson.toJson(request)) + .responseString() + if (response.statusCode != 200) { + Logger.error("Failed to set expected Authorization header: ${response.statusCode} ${response.responseMessage}") + } else { + Logger.debug("Set expected Authorization header: $expectedHeader for host=$hostName") + } + } catch (e: Exception) { + Logger.error("Failed to set expected Authorization header: ${e.message}") + } + } + /** * Sets the response for the inference endpoint using an OpenAIResponse object. diff --git a/testermint/src/main/kotlin/data/inferenceNodes.kt b/testermint/src/main/kotlin/data/inferenceNodes.kt index 103bab444..706fa794f 100644 --- a/testermint/src/main/kotlin/data/inferenceNodes.kt +++ b/testermint/src/main/kotlin/data/inferenceNodes.kt @@ -14,6 +14,8 @@ data class InferenceNode( val nodeNum: Long? = null, val hardware: List? = null, val version: String? = null, + val baseUrl: String? = null, + val authToken: String? 
= null, ) { val pocHost: String get() = "$host:$pocPort" @@ -76,4 +78,4 @@ data class MlNodeVersionQueryResponse( data class MlNodeVersion( val currentVersion: String, -) \ No newline at end of file +) diff --git a/testermint/src/test/kotlin/AuthTokenFlowTests.kt b/testermint/src/test/kotlin/AuthTokenFlowTests.kt new file mode 100644 index 000000000..58941cde2 --- /dev/null +++ b/testermint/src/test/kotlin/AuthTokenFlowTests.kt @@ -0,0 +1,77 @@ +import com.productscience.* +import com.productscience.data.InferenceStatus +import com.productscience.data.InferencePayload +import org.assertj.core.api.Assertions.assertThat +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.Timeout +import org.tinylog.kotlin.Logger +import java.time.Duration +import java.util.UUID +import java.util.concurrent.TimeUnit + +@Timeout(value = 15, unit = TimeUnit.MINUTES) +class AuthTokenFlowTests : TestermintTest() { + + @Test + fun `full flow with auth tokens`() { + logSection("Setup cluster with auth tokens") + val (cluster, genesis) = initCluster(reboot = true) + val authToken = "test-auth-token-${UUID.randomUUID()}" + + logSection("Re-register MLNodes with auth token") + cluster.allPairs.forEach { pair -> + val existingNodes = pair.api.getNodes() + if (existingNodes.isNotEmpty()) { + val existingNode = existingNodes.first().node + val nodeWithAuth = existingNode.copy( + authToken = authToken + ) + pair.waitForNextInferenceWindow(windowSizeInBlocks = 5) + pair.api.setNodesTo(nodeWithAuth) + pair.waitForMlNodesToLoad() + } + } + + logSection("Configure mock servers to require Authorization header") + cluster.allPairs.forEach { pair -> + val nodes = pair.api.getNodes() + val hostName = nodes.firstOrNull()?.node?.inferenceHost + (pair.mock as? 
MockServerInferenceMock)?.setExpectedAuthorizationHeader("Bearer $authToken", hostName) + } + + logSection("Wait for first PoC and verify validators") + genesis.waitForStage(EpochStage.SET_NEW_VALIDATORS) + val validators = genesis.node.getValidators().validators + assertThat(validators).isNotEmpty() + validators.forEach { v -> + assertThat(v.tokens).isGreaterThan(0) + Logger.info("Validator ${v.consensusPubkey.value} has tokens ${v.tokens}") + } + + logSection("Run inference in inference phase") + genesis.waitForNextInferenceWindow() + + // Ensure mocks respond successfully (optional; default is OK) + cluster.allPairs.forEach { it.mock?.setInferenceResponse(defaultInferenceResponseObject, Duration.ofMillis(0)) } + + // Simplified inference test - just check that we can make requests with auth tokens + logSection("Verify auth token functionality") + val nodes = genesis.api.getNodes() + assertThat(nodes).isNotEmpty() + + // Check that at least one node has the auth token configured + val nodeWithAuth = nodes.firstOrNull { it.node.authToken != null } + assertThat(nodeWithAuth).isNotNull() + assertThat(nodeWithAuth?.node?.authToken).isEqualTo(authToken) + + logSection("Wait for PoC to end and claim rewards") + genesis.waitForStage(EpochStage.CLAIM_REWARDS) + + logSection("Verify rewards were received") + val updatedParticipants = genesis.api.getParticipants() + updatedParticipants.forEach { participant -> + assertThat(participant.coinsOwed).isGreaterThanOrEqualTo(0) + Logger.info("Participant ${participant.id} earned ${participant.coinsOwed}") + } + } +} \ No newline at end of file