From aa69c042ff2b4966c76ff89792a6f89de8c93312 Mon Sep 17 00:00:00 2001 From: Zhongxin Guo Date: Wed, 17 Dec 2025 02:26:42 -0800 Subject: [PATCH 1/2] support assign job name --- src/model-proxy/src/proxy/authenticator.go | 17 +++- src/model-proxy/src/proxy/model_server.go | 45 +++++++---- src/model-proxy/src/proxy/proxy.go | 94 +++++++++++----------- src/model-proxy/src/types/config_types.go | 8 +- 4 files changed, 96 insertions(+), 68 deletions(-) diff --git a/src/model-proxy/src/proxy/authenticator.go b/src/model-proxy/src/proxy/authenticator.go index bacc1874..c6eb4877 100644 --- a/src/model-proxy/src/proxy/authenticator.go +++ b/src/model-proxy/src/proxy/authenticator.go @@ -7,6 +7,7 @@ import ( "log" "net/http" "strings" + "time" "modelproxy/types" ) @@ -19,9 +20,13 @@ func obfuscateToken(token string) string { return token[:3] + "***" + token[len(token)-3:] } +const RefreshIntervalSeconds = 600 // 10 minutes + type RestServerAuthenticator struct { // rest-server token => model names => model service list - tokenToModels map[string]map[string][]*types.BaseSpec + tokenToModels map[string]map[string][]*types.BaseSpec + tokenToJobModels map[string]map[string][]*types.BaseSpec + tokenUpdatedTime map[string]int64 } func NewRestServerAuthenticator() *RestServerAuthenticator { @@ -36,6 +41,7 @@ func (ra *RestServerAuthenticator) UpdateTokenModels(token string, model2Service ra.tokenToModels = make(map[string]map[string][]*types.BaseSpec) } ra.tokenToModels[token] = model2Service + ra.tokenUpdatedTime[token] = time.Now().Unix() } // Check if the request is authenticated and return the available model urls @@ -49,15 +55,20 @@ func (ra *RestServerAuthenticator) AuthenticateReq(req *http.Request, reqBody ma return false, nil } availableModels, ok := ra.tokenToModels[token] - if !ok { + tokenLastUpdated, timeOk := ra.tokenUpdatedTime[token] + + // If the token is not found or the token info is older than RefreshIntervalSeconds, refresh it + if !ok || !timeOk || time.Now().Unix()-tokenLastUpdated > RefreshIntervalSeconds { // request to RestServer to get the models log.Printf("[-] Error: token %s not found in the authenticator\n", obfuscateToken(token)) - availableModels, err := GetJobModelsMapping(req) + availableModels, err := ListJobModelsMapping(req) if err != nil { log.Printf("[-] Error: failed to get models for token %s: %v\n", obfuscateToken(token), err) return false, nil } ra.tokenToModels[token] = availableModels + ra.tokenUpdatedTime[token] = time.Now().Unix() + log.Printf("[*] Refreshed models for token %s: %v\n", obfuscateToken(token), availableModels) } if len(availableModels) == 0 { log.Printf("[-] Error: no models found") diff --git a/src/model-proxy/src/proxy/model_server.go b/src/model-proxy/src/proxy/model_server.go index 6e5f6c44..3a4d42ef 100644 --- a/src/model-proxy/src/proxy/model_server.go +++ b/src/model-proxy/src/proxy/model_server.go @@ -74,7 +74,6 @@ func ListInferenceJobs(restServerUrl string, restServerToken string) ([]string, result := make([]string, 0, len(jobs)) for _, j := range jobs { - jobId := fmt.Sprintf("%s~%s", j.Username, j.Name) result = append(result, jobId) } @@ -213,17 +212,17 @@ func listModels(jobServerUrl string, modelApiKey string) ([]string, error) { } // return JobURL => models -func GetJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, error) { +func ListJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, error) { // modelName to job server url list - mapping := make(map[string][]*types.BaseSpec) + modelJobMapping := make(map[string][]*types.BaseSpec) if req == nil || req.Host == "" { - return mapping, fmt.Errorf("invalid request or empty host") + return modelJobMapping, fmt.Errorf("invalid request or empty host") } // get rest server base url from the os environment variable restServerUrl := os.Getenv("REST_SERVER_URI") if restServerUrl == "" { - return mapping, fmt.Errorf("REST_SERVER_URI environment variable is not set") + return modelJobMapping, fmt.Errorf("REST_SERVER_URI environment variable is not set") } // Ensure restServerUrl doesn't end with slash restServerUrl = strings.TrimRight(restServerUrl, "/") @@ -231,12 +230,13 @@ func GetJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, error restServerToken := req.Header.Get("Authorization") jobIDs, err := ListInferenceJobs(restServerUrl, restServerToken) if err != nil { - return mapping, fmt.Errorf("failed to list model serving jobs: %w", err) + return modelJobMapping, fmt.Errorf("failed to list model serving jobs: %w", err) } // Channel to collect results - type modelMapping struct { + type ModelEndpoint struct { modelName string + jobID string modelService *types.BaseSpec } concurrency, err := strconv.Atoi(os.Getenv("FETCH_JOB_CONCURRENCY")) @@ -244,7 +244,7 @@ func GetJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, error log.Printf("[-] Error: invalid FETCH_JOB_CONCURRENCY value: %s\n", err) concurrency = 10 // default value } - results := make(chan modelMapping, concurrency) // Buffer for potential models + allModelEndpoints := make(chan ModelEndpoint, concurrency) // Buffer for potential models // Use a wait group to run jobs in parallel var wg sync.WaitGroup @@ -287,13 +287,17 @@ func GetJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, error return } + userName := strings.Split(jobId, "~")[0] + jobName := strings.Split(jobId, "~")[1] // Send results to channel for _, model := range models { - results <- modelMapping{ + allModelEndpoints <- ModelEndpoint{ modelName: model, modelService: &types.BaseSpec{ - URL: jobServerUrl, - Key: apiKey, + URL: jobServerUrl, + Key: apiKey, + JobName: jobName, + UserName: userName, }, } } @@ -303,16 +307,23 @@ func GetJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, error // Close the results channel when all goroutines are done go func() { wg.Wait() - close(results) + close(allModelEndpoints) }() // Collect results from channel - for result := range results { - if _, ok := mapping[result.modelName]; !ok { - mapping[result.modelName] = make([]*types.BaseSpec, 0) + for result := range allModelEndpoints { + if _, ok := modelJobMapping[result.modelName]; !ok { + modelJobMapping[result.modelName] = make([]*types.BaseSpec, 0) + } + modelJobMapping[result.modelName] = append(modelJobMapping[result.modelName], result.modelService) + jobModelName := fmt.Sprintf("%s@%s", result.modelService.JobName, result.modelName) + + // Also map jobName@modelName to the model service which allows users to specify the job name in the model field + if _, ok := modelJobMapping[jobModelName]; !ok { + modelJobMapping[jobModelName] = make([]*types.BaseSpec, 0) } - mapping[result.modelName] = append(mapping[result.modelName], result.modelService) + modelJobMapping[jobModelName] = append(modelJobMapping[jobModelName], result.modelService) } - return mapping, nil + return modelJobMapping, nil } diff --git a/src/model-proxy/src/proxy/proxy.go b/src/model-proxy/src/proxy/proxy.go index a520beac..21b5c512 100644 --- a/src/model-proxy/src/proxy/proxy.go +++ b/src/model-proxy/src/proxy/proxy.go @@ -73,6 +73,54 @@ func NewProxyHandler(config *types.Config) *ProxyHandler { } } +// listAllModels list all models from the rest server and return the models in OpenAI style response +func (ph *ProxyHandler) listAllModels(w http.ResponseWriter, r *http.Request) string { + log.Printf("[*] receive a models list request from %s\n", r.RemoteAddr) + model2Service, err := ListJobModelsMapping(r) + if err != nil { + log.Printf("[-] Error: failed to get models mapping: %v\n", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + // Update the ph.authenticator + token := r.Header.Get("Authorization") + token = strings.Replace(token, "Bearer ", "", 1) + ph.authenticator.UpdateTokenModels(token, model2Service) + + // convert models list to OpenAI style list and write it to w + ids := make([]string, 0, len(model2Service)) + for id := range model2Service { + ids = append(ids, id) + } + sort.Strings(ids) + + list := map[string]interface{}{ + "object": "list", + "data": make([]map[string]interface{}, 0, len(ids)), + } + for _, id := range ids { + item := map[string]interface{}{ + "id": id, + "object": "model", + // intentionally not including created or owned_by + } + list["data"] = append(list["data"].([]map[string]interface{}), item) + } + + out, err := json.Marshal(list) + if err != nil { + log.Printf("[-] Error: failed to marshal models list: %v\n", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if _, err := w.Write(out); err != nil { + log.Printf("[-] Error: failed to write response: %v\n", err) + } + // We've handled the response, do not continue proxying this request + return "" +} + // ReverseProxyHandler act as a reverse proxy, it will redirect the request to the destination website and return the response func (ph *ProxyHandler) ReverseProxyHandler(w http.ResponseWriter, r *http.Request) (string, []string, bool) { log.Printf("[*] receive a request: %s %s\n", r.Method, r.URL.String()) @@ -88,51 +136,7 @@ func (ph *ProxyHandler) ReverseProxyHandler(w http.ResponseWriter, r *http.Reque // handle /v1/models if r.URL.Path == "/v1/models" { - log.Printf("[*] receive a models list request from %s\n", r.RemoteAddr) - model2Service, err := GetJobModelsMapping(r) - if err != nil { - log.Printf("[-] Error: failed to get models mapping: %v\n", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return "", nil, false - } - // Update the ph.authenticator - token := r.Header.Get("Authorization") - token = strings.Replace(token, "Bearer ", "", 1) - ph.authenticator.UpdateTokenModels(token, model2Service) - - // convert models list to OpenAI style list and write it to w - ids := make([]string, 0, len(model2Service)) - for id := range model2Service { - ids = append(ids, id) - } - sort.Strings(ids) - - list := map[string]interface{}{ - "object": "list", - "data": make([]map[string]interface{}, 0, len(ids)), - } - for _, id := range ids { - item := map[string]interface{}{ - "id": id, - "object": "model", - // intentionally not including created or owned_by - } - list["data"] = append(list["data"].([]map[string]interface{}), item) - } - - out, err := json.Marshal(list) - if err != nil { - log.Printf("[-] Error: failed to marshal models list: %v\n", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return "", nil, false - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if _, err := w.Write(out); err != nil { - log.Printf("[-] Error: failed to write response: %v\n", err) - } - // We've handled the response, do not continue proxying this request + ph.listAllModels(w, r) return "", nil, false } log.Printf("[*] receive a request from %s\n", r.RemoteAddr) diff --git a/src/model-proxy/src/types/config_types.go b/src/model-proxy/src/types/config_types.go index b8b7f07f..3409d7a3 100644 --- a/src/model-proxy/src/types/config_types.go +++ b/src/model-proxy/src/types/config_types.go @@ -48,9 +48,11 @@ type AzureStorage struct { // BaseSpec is the base spec for azure and openai type BaseSpec struct { - URL string `json:"url"` - Key string `json:"key"` - Version string `json:"version,omitempty"` + URL string `json:"url"` + Key string `json:"key"` + Version string `json:"version,omitempty"` + JobName string `json:"job_name,omitempty"` + UserName string `json:"user_name,omitempty"` } // ParseConfig parse the config file into Config struct From a351680f56ae47d34019c6779f02edef282d37ef Mon Sep 17 00:00:00 2001 From: Zhongxin Guo Date: Wed, 17 Dec 2025 18:55:32 -0800 Subject: [PATCH 2/2] fix copilot comments --- src/model-proxy/src/proxy/authenticator.go | 30 ++++++++++++---- src/model-proxy/src/proxy/model_server.go | 1 - src/model-proxy/src/proxy/proxy.go | 42 +++++++++++++--------- 3 files changed, 48 insertions(+), 25 deletions(-) diff --git a/src/model-proxy/src/proxy/authenticator.go b/src/model-proxy/src/proxy/authenticator.go index c6eb4877..fa3ed8f7 100644 --- a/src/model-proxy/src/proxy/authenticator.go +++ b/src/model-proxy/src/proxy/authenticator.go @@ -7,6 +7,7 @@ import ( "log" "net/http" "strings" + "sync" "time" "modelproxy/types" @@ -21,25 +22,30 @@ func obfuscateToken(token string) string { } const RefreshIntervalSeconds = 600 // 10 minutes - type RestServerAuthenticator struct { // rest-server token => model names => model service list tokenToModels map[string]map[string][]*types.BaseSpec - tokenToJobModels map[string]map[string][]*types.BaseSpec tokenUpdatedTime map[string]int64 + mu sync.RWMutex } func NewRestServerAuthenticator() *RestServerAuthenticator { return &RestServerAuthenticator{ - tokenToModels: make(map[string]map[string][]*types.BaseSpec), + tokenToModels: make(map[string]map[string][]*types.BaseSpec), + tokenUpdatedTime: make(map[string]int64), } } // UpdateTokenModels updates the model mapping for a given token func (ra *RestServerAuthenticator) UpdateTokenModels(token string, model2Service map[string][]*types.BaseSpec) { + ra.mu.Lock() + defer ra.mu.Unlock() if ra.tokenToModels == nil { ra.tokenToModels = make(map[string]map[string][]*types.BaseSpec) } + if ra.tokenUpdatedTime == nil { + ra.tokenUpdatedTime = make(map[string]int64) + } ra.tokenToModels[token] = model2Service ra.tokenUpdatedTime[token] = time.Now().Unix() } @@ -54,20 +60,30 @@ func (ra *RestServerAuthenticator) AuthenticateReq(req *http.Request, reqBody ma log.Printf("[-] Error: 'model' field missing or not a string in request body") return false, nil } + + ra.mu.RLock() availableModels, ok := ra.tokenToModels[token] tokenLastUpdated, timeOk := ra.tokenUpdatedTime[token] + ra.mu.RUnlock() + if !ok { + log.Printf("[-] Error: token %s not found in the authenticator\n", obfuscateToken(token)) + } + if !timeOk || time.Now().Unix()-tokenLastUpdated > RefreshIntervalSeconds { + log.Printf("[-] Error: token %s info is outdated in the authenticator\n", obfuscateToken(token)) + } // If the token is not found or the token info is older than RefreshIntervalSeconds, refresh it + if !ok || !timeOk || time.Now().Unix()-tokenLastUpdated > RefreshIntervalSeconds { // request to RestServer to get the models - log.Printf("[-] Error: token %s not found in the authenticator\n", obfuscateToken(token)) - availableModels, err := ListJobModelsMapping(req) + freshModels, err := ListJobModelsMapping(req) if err != nil { log.Printf("[-] Error: failed to get models for token %s: %v\n", obfuscateToken(token), err) return false, nil } - ra.tokenToModels[token] = availableModels - ra.tokenUpdatedTime[token] = time.Now().Unix() + ra.UpdateTokenModels(token, freshModels) + + availableModels = freshModels log.Printf("[*] Refreshed models for token %s: %v\n", obfuscateToken(token), availableModels) } if len(availableModels) == 0 { diff --git a/src/model-proxy/src/proxy/model_server.go b/src/model-proxy/src/proxy/model_server.go index 3a4d42ef..b9071b69 100644 --- a/src/model-proxy/src/proxy/model_server.go +++ b/src/model-proxy/src/proxy/model_server.go @@ -236,7 +236,6 @@ func ListJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, erro // Channel to collect results type ModelEndpoint struct { modelName string - jobID string modelService *types.BaseSpec } concurrency, err := strconv.Atoi(os.Getenv("FETCH_JOB_CONCURRENCY")) diff --git a/src/model-proxy/src/proxy/proxy.go b/src/model-proxy/src/proxy/proxy.go index 21b5c512..619eb451 100644 --- a/src/model-proxy/src/proxy/proxy.go +++ b/src/model-proxy/src/proxy/proxy.go @@ -7,6 +7,7 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "io" "log" "net/http" @@ -74,12 +75,13 @@ func NewProxyHandler(config *types.Config) *ProxyHandler { } // listAllModels list all models from the rest server and return the models in OpenAI style response -func (ph *ProxyHandler) listAllModels(w http.ResponseWriter, r *http.Request) string { +func (ph *ProxyHandler) listAllModels(r *http.Request) ([]byte, error) { log.Printf("[*] receive a models list request from %s\n", r.RemoteAddr) - model2Service, err := ListJobModelsMapping(r) - if err != nil { - log.Printf("[-] Error: failed to get models mapping: %v\n", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) + model2Service, internalErr := ListJobModelsMapping(r) + if internalErr != nil { + errorMsg := fmt.Sprintf("[-] Error: failed to list models: %v\n", internalErr) + internalErr = errors.New(errorMsg) + return nil, internalErr } // Update the ph.authenticator token := r.Header.Get("Authorization") @@ -106,19 +108,14 @@ func (ph *ProxyHandler) listAllModels(w http.ResponseWriter, r *http.Request) st list["data"] = append(list["data"].([]map[string]interface{}), item) } - out, err := json.Marshal(list) - if err != nil { - log.Printf("[-] Error: failed to marshal models list: %v\n", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) + out, internalErr := json.Marshal(list) + if internalErr != nil { + errorMsg := fmt.Sprintf("[-] Error: failed to marshal models list: %v\n", internalErr) + internalErr = errors.New(errorMsg) + return nil, internalErr } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - if _, err := w.Write(out); err != nil { - log.Printf("[-] Error: failed to write response: %v\n", err) - } - // We've handled the response, do not continue proxying this request - return "" + return out, nil } // ReverseProxyHandler act as a reverse proxy, it will redirect the request to the destination website and return the response @@ -136,7 +133,18 @@ func (ph *ProxyHandler) ReverseProxyHandler(w http.ResponseWriter, r *http.Reque // handle /v1/models if r.URL.Path == "/v1/models" { - ph.listAllModels(w, r) + output, err := ph.listAllModels(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return "", nil, false + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + if _, err := w.Write(output); err != nil { + log.Printf("[-] Error: failed to write response: %v\n", err) + } + // We've handled the response, do not continue proxying this request return "", nil, false } log.Printf("[*] receive a request from %s\n", r.RemoteAddr)