Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions decentralized-api/apiconfig/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package apiconfig

import (
"fmt"
"net/url"
"strings"
)

Expand Down Expand Up @@ -90,6 +91,8 @@ type InferenceNodeConfig struct {
InferencePort int `koanf:"inference_port" json:"inference_port"`
PoCSegment string `koanf:"poc_segment" json:"poc_segment"`
PoCPort int `koanf:"poc_port" json:"poc_port"`
BaseURL string `koanf:"base_url" json:"base_url"`
AuthToken string `koanf:"auth_token" json:"auth_token"`
Models map[string]ModelConfig `koanf:"models" json:"models"`
Id string `koanf:"id" json:"id"`
MaxConcurrent int `koanf:"max_concurrent" json:"max_concurrent"`
Expand All @@ -102,21 +105,46 @@ type InferenceNodeConfig struct {
func ValidateInferenceNodeBasic(node InferenceNodeConfig) []string {
var errors []string

// Validate required fields
if strings.TrimSpace(node.Id) == "" {
errors = append(errors, "node id is required and cannot be empty")
// Validate host/baseURL configuration
// ensures that a node configuration uses either the legacy host/port registration or the new baseURL registration, but not both.
// When baseURL is provided, it must be a valid HTTP(S) URL. AuthToken is always optional (no validation needed)
hasHostPorts := strings.TrimSpace(node.Host) != "" || node.InferencePort > 0 || node.PoCPort > 0
hasBaseURL := strings.TrimSpace(node.BaseURL) != ""

if hasHostPorts && hasBaseURL {
errors = append(errors, "node configuration error: cannot specify both (Host+Ports) and baseURL. Use either Host+InferencePort+PoCPort OR baseURL")
}

if strings.TrimSpace(node.Host) == "" {
errors = append(errors, "host is required and cannot be empty")
if !hasHostPorts && !hasBaseURL {
errors = append(errors, "node configuration error: must specify either (Host+InferencePort+PoCPort) OR baseURL")
}

if node.InferencePort <= 0 || node.InferencePort > 65535 {
errors = append(errors, fmt.Sprintf("inference_port must be between 1 and 65535, got %d", node.InferencePort))
if hasHostPorts {
if strings.TrimSpace(node.Host) == "" {
errors = append(errors, "host is required and cannot be empty when using host+port registration")
}

if node.InferencePort <= 0 || node.InferencePort > 65535 {
errors = append(errors, fmt.Sprintf("inference_port must be between 1 and 65535, got %d", node.InferencePort))
}

if node.PoCPort <= 0 || node.PoCPort > 65535 {
errors = append(errors, fmt.Sprintf("poc_port must be between 1 and 65535, got %d", node.PoCPort))
}
}

if hasBaseURL {
parsedURL, err := url.Parse(node.BaseURL)
if err != nil {
errors = append(errors, fmt.Sprintf("node configuration error: baseURL is not a valid URL: %v", err))
} else if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
errors = append(errors, fmt.Sprintf("node configuration error: baseURL must use http:// or https:// scheme, got: %s", parsedURL.Scheme))
}
}

if node.PoCPort <= 0 || node.PoCPort > 65535 {
errors = append(errors, fmt.Sprintf("poc_port must be between 1 and 65535, got %d", node.PoCPort))
// Validate required fields
if strings.TrimSpace(node.Id) == "" {
errors = append(errors, "node id is required and cannot be empty")
}

if node.MaxConcurrent <= 0 {
Expand Down
82 changes: 74 additions & 8 deletions decentralized-api/apiconfig/sqlite_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ CREATE TABLE IF NOT EXISTS inference_nodes (
max_concurrent INTEGER NOT NULL,
models_json TEXT NOT NULL,
hardware_json TEXT NOT NULL,
base_url TEXT NOT NULL DEFAULT '',
auth_token TEXT NOT NULL DEFAULT '',
updated_at DATETIME NOT NULL DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f','now')),
created_at DATETIME NOT NULL DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f','now'))
);
Expand All @@ -104,8 +106,62 @@ CREATE TABLE IF NOT EXISTS seed_info (
is_active BOOLEAN NOT NULL DEFAULT 1,
created_at DATETIME NOT NULL DEFAULT (STRFTIME('%Y-%m-%d %H:%M:%f','now'))
);`
_, err := db.ExecContext(ctx, stmt)
return err
if _, err := db.ExecContext(ctx, stmt); err != nil {
return err
}

// Migrate existing tables: add base_url and auth_token columns if they don't exist
return migrateInferenceNodesTable(ctx, db)
}

// migrateInferenceNodesTable upgrades an already-existing inference_nodes
// table in place. It lists the table's columns via PRAGMA table_info and
// issues an ALTER TABLE ... ADD COLUMN for base_url and/or auth_token when
// the listing does not report them. It returns nil when both columns are
// present or were added successfully.
func migrateInferenceNodesTable(ctx context.Context, db *sql.DB) error {
	cols, err := db.QueryContext(ctx, "PRAGMA table_info(inference_nodes)")
	if err != nil {
		return err
	}
	defer cols.Close()

	// Only these two columns are tracked; every other column is ignored.
	present := map[string]bool{"base_url": false, "auth_token": false}
	for cols.Next() {
		// Each table_info row is scanned into six values; only the
		// column name is actually used.
		var (
			cid, notNull, pk int
			colName, colType string
			defaultValue     sql.NullString
		)
		if scanErr := cols.Scan(&cid, &colName, &colType, &notNull, &defaultValue, &pk); scanErr != nil {
			return scanErr
		}
		if _, tracked := present[colName]; tracked {
			present[colName] = true
		}
	}
	if iterErr := cols.Err(); iterErr != nil {
		return iterErr
	}

	// Applied in a fixed order: base_url first, then auth_token.
	additions := []struct {
		column string
		ddl    string
	}{
		{"base_url", "ALTER TABLE inference_nodes ADD COLUMN base_url TEXT NOT NULL DEFAULT ''"},
		{"auth_token", "ALTER TABLE inference_nodes ADD COLUMN auth_token TEXT NOT NULL DEFAULT ''"},
	}
	for _, add := range additions {
		if present[add.column] {
			continue
		}
		if _, execErr := db.ExecContext(ctx, add.ddl); execErr != nil {
			return fmt.Errorf("failed to add %s column: %w", add.column, execErr)
		}
	}

	return nil
}

// UpsertInferenceNodes replaces or inserts the given nodes by id.
Expand All @@ -121,8 +177,8 @@ func UpsertInferenceNodes(ctx context.Context, db *sql.DB, nodes []InferenceNode

q := `
INSERT INTO inference_nodes (
id, host, inference_segment, inference_port, poc_segment, poc_port, max_concurrent, models_json, hardware_json
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
id, host, inference_segment, inference_port, poc_segment, poc_port, max_concurrent, models_json, hardware_json, base_url, auth_token
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
host = excluded.host,
inference_segment = excluded.inference_segment,
Expand All @@ -132,6 +188,8 @@ ON CONFLICT(id) DO UPDATE SET
max_concurrent = excluded.max_concurrent,
models_json = excluded.models_json,
hardware_json = excluded.hardware_json,
base_url = excluded.base_url,
auth_token = excluded.auth_token,
updated_at = (STRFTIME('%Y-%m-%d %H:%M:%f','now'))`

stmt, err := tx.PrepareContext(ctx, q)
Expand Down Expand Up @@ -160,6 +218,8 @@ ON CONFLICT(id) DO UPDATE SET
n.MaxConcurrent,
string(modelsJSON),
string(hardwareJSON),
n.BaseURL,
n.AuthToken,
); err != nil {
return err
}
Expand All @@ -175,7 +235,7 @@ func WriteNodes(ctx context.Context, db *sql.DB, nodes []InferenceNodeConfig) er
// ReadNodes reads all nodes from the database and reconstructs InferenceNodeConfig entries.
func ReadNodes(ctx context.Context, db *sql.DB) ([]InferenceNodeConfig, error) {
rows, err := db.QueryContext(ctx, `
SELECT id, host, inference_segment, inference_port, poc_segment, poc_port, max_concurrent, models_json, hardware_json
SELECT id, host, inference_segment, inference_port, poc_segment, poc_port, max_concurrent, models_json, hardware_json, base_url, auth_token
FROM inference_nodes ORDER BY id`)
if err != nil {
return nil, err
Expand All @@ -194,8 +254,10 @@ FROM inference_nodes ORDER BY id`)
maxConc int
modelsRaw []byte
hardwareRaw []byte
baseURL string
authToken string
)
if err := rows.Scan(&id, &host, &infSeg, &infPort, &pocSeg, &pocPort, &maxConc, &modelsRaw, &hardwareRaw); err != nil {
if err := rows.Scan(&id, &host, &infSeg, &infPort, &pocSeg, &pocPort, &maxConc, &modelsRaw, &hardwareRaw, &baseURL, &authToken); err != nil {
return nil, err
}
var models map[string]ModelConfig
Expand All @@ -216,6 +278,8 @@ FROM inference_nodes ORDER BY id`)
InferencePort: infPort,
PoCSegment: pocSeg,
PoCPort: pocPort,
BaseURL: baseURL,
AuthToken: authToken,
Models: models,
Id: id,
MaxConcurrent: maxConc,
Expand Down Expand Up @@ -246,8 +310,8 @@ func ReplaceInferenceNodes(ctx context.Context, db *sql.DB, nodes []InferenceNod

q := `
INSERT INTO inference_nodes (
id, host, inference_segment, inference_port, poc_segment, poc_port, max_concurrent, models_json, hardware_json
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`
id, host, inference_segment, inference_port, poc_segment, poc_port, max_concurrent, models_json, hardware_json, base_url, auth_token
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

stmt, err := tx.PrepareContext(ctx, q)
if err != nil {
Expand Down Expand Up @@ -275,6 +339,8 @@ INSERT INTO inference_nodes (
n.MaxConcurrent,
string(modelsJSON),
string(hardwareJSON),
n.BaseURL,
n.AuthToken,
); err != nil {
return err
}
Expand Down
68 changes: 62 additions & 6 deletions decentralized-api/broker/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,33 +195,89 @@ type Node struct {
InferencePort int `json:"inference_port"`
PoCSegment string `json:"poc_segment"`
PoCPort int `json:"poc_port"`
BaseURL string `json:"base_url"`
AuthToken string `json:"auth_token"`
Models map[string]ModelArgs `json:"models"`
Id string `json:"id"`
MaxConcurrent int `json:"max_concurrent"`
NodeNum uint64 `json:"node_num"`
Hardware []apiconfig.Hardware `json:"hardware"`
}

// MlNodePathElements carries the pieces GetMlNodeUrl assembles into an ML node
// URL. Either BaseURL or the Host/Port pair identifies the node; Version and
// Segment are appended after it.
type MlNodePathElements struct {
// Host and Port form "http://Host:Port" when BaseURL is blank.
Host string `json:"host"`
Port int `json:"port"`
// BaseURL, when non-blank, is used instead of Host and Port.
BaseURL string `json:"base_url"`
// Version is an optional path component placed between the base and Segment.
Version string `json:"version"`
// Segment is appended verbatim at the end of the URL.
Segment string `json:"segment"`
}

// GetMlNodeUrl assembles the URL of an ML node endpoint from its parts.
//
// When elements.BaseURL is non-blank it is used as the base, with surrounding
// whitespace and trailing slashes removed; otherwise the base is
// "http://Host:Port". A non-blank Version (whitespace-trimmed) is appended as
// its own path component, and Segment is appended verbatim at the end.
func GetMlNodeUrl(elements MlNodePathElements) string {
	var base string
	if baseURL := strings.TrimSpace(elements.BaseURL); baseURL != "" {
		// Trim whitespace before stripping slashes: a padded value such as
		// " https://host/ " would otherwise keep both the padding and the
		// trailing slash and produce a malformed URL.
		base = strings.TrimRight(baseURL, "/")
	} else {
		base = fmt.Sprintf("http://%s:%d", elements.Host, elements.Port)
	}

	if version := strings.TrimSpace(elements.Version); version != "" {
		base += "/" + version
	}
	return base + elements.Segment
}

// InferenceUrl returns the node's unversioned inference endpoint.
//
// When the node has a non-blank BaseURL, the result is that base (whitespace
// and trailing slashes removed) followed by InferenceSegment. Otherwise it is
// "http://Host:InferencePort" followed by InferenceSegment.
func (n *Node) InferenceUrl() string {
	// Without this branch a baseURL-registered node, whose Host and port are
	// unset, would yield an unusable "http://:0/..." address.
	if baseURL := strings.TrimSpace(n.BaseURL); baseURL != "" {
		return strings.TrimRight(baseURL, "/") + n.InferenceSegment
	}
	return fmt.Sprintf("http://%s:%d%s", n.Host, n.InferencePort, n.InferenceSegment)
}

func (n *Node) InferenceUrlWithVersion(version string) string {
if version == "" {
v := strings.TrimSpace(version)
// If BaseURL is provided, build on top of it
if n.BaseURL != "" {
base := strings.TrimRight(n.BaseURL, "/")
if v == "" {
return fmt.Sprintf("%s%s", base, n.InferenceSegment)
}
return fmt.Sprintf("%s/%s%s", base, v, n.InferenceSegment)
}
if v == "" {
return n.InferenceUrl()
}
return fmt.Sprintf("http://%s:%d/%s%s", n.Host, n.InferencePort, version, n.InferenceSegment)
return fmt.Sprintf("http://%s:%d/%s%s", n.Host, n.InferencePort, v, n.InferenceSegment)
}

// PoCUrl returns the node's unversioned PoC endpoint.
//
// When the node has a non-blank BaseURL, the result is that base (whitespace
// and trailing slashes removed) followed by PoCSegment. Otherwise it is
// "http://Host:PoCPort" followed by PoCSegment.
func (n *Node) PoCUrl() string {
	// Without this branch a baseURL-registered node, whose Host and port are
	// unset, would yield an unusable "http://:0/..." address.
	if baseURL := strings.TrimSpace(n.BaseURL); baseURL != "" {
		return strings.TrimRight(baseURL, "/") + n.PoCSegment
	}
	return fmt.Sprintf("http://%s:%d%s", n.Host, n.PoCPort, n.PoCSegment)
}

func (n *Node) PoCUrlWithVersion(version string) string {
if version == "" {
v := strings.TrimSpace(version)
// If BaseURL is provided, build on top of it
if n.BaseURL != "" {
base := strings.TrimRight(n.BaseURL, "/")
if v == "" {
return fmt.Sprintf("%s%s", base, n.PoCSegment)
}
return fmt.Sprintf("%s/%s%s", base, v, n.PoCSegment)
}
if v == "" {
return n.PoCUrl()
}
return fmt.Sprintf("http://%s:%d/%s%s", n.Host, n.PoCPort, version, n.PoCSegment)
return fmt.Sprintf("http://%s:%d/%s%s", n.Host, n.PoCPort, v, n.PoCSegment)
}

// BaseUrlWithVersion joins baseURL and version into a single URL prefix.
// Trailing slashes are stripped from baseURL. When version is non-blank after
// trimming whitespace, it is appended as "/version"; otherwise the stripped
// base is returned unchanged.
func BaseUrlWithVersion(baseURL, version string) string {
	trimmedBase := strings.TrimRight(baseURL, "/")
	v := strings.TrimSpace(version)
	if v == "" {
		return trimmedBase
	}
	return fmt.Sprintf("%s/%s", trimmedBase, v)
}

// BaseUrlWithVersion returns the result of the package-level
// BaseUrlWithVersion applied to this node's BaseURL and the given version.
func (n *Node) BaseUrlWithVersion(version string) string {
	nodeBase := n.BaseURL
	return BaseUrlWithVersion(nodeBase, version)
}

type NodeWithState struct {
Expand Down Expand Up @@ -500,7 +556,7 @@ func (b *Broker) QueueMessage(command Command) error {

// NewNodeClient creates an ML node client for node through the broker's
// client factory. The version reported by the config manager is applied to
// the node's PoC, inference, and base URLs, and the node's AuthToken is
// passed to the factory alongside them.
func (b *Broker) NewNodeClient(node *Node) mlnodeclient.MLNodeClient {
	version := b.configManager.GetCurrentNodeVersion()
	// Single four-argument call: the earlier two-argument return that preceded
	// it made this call unreachable, dropping the auth token and base URL.
	return b.mlNodeClientFactory.CreateClient(
		node.PoCUrlWithVersion(version),
		node.InferenceUrlWithVersion(version),
		node.AuthToken,
		node.BaseUrlWithVersion(version),
	)
}

func (b *Broker) lockAvailableNode(command LockAvailableNode) {
Expand Down Expand Up @@ -664,7 +720,7 @@ func (b *Broker) GetNodes() ([]NodeResponse, error) {
nodes := <-command.Response

if nodes == nil {
return nil, errors.New("Error getting nodes")
return nil, errors.New("error getting nodes")
}
logging.Debug("Got nodes", types.Nodes, "size", len(nodes))
return nodes, nil
Expand Down
Loading
Loading