diff --git a/src/model-proxy/src/proxy/authenticator.go b/src/model-proxy/src/proxy/authenticator.go index bacc1874..fa3ed8f7 100644 --- a/src/model-proxy/src/proxy/authenticator.go +++ b/src/model-proxy/src/proxy/authenticator.go @@ -7,6 +7,8 @@ import ( "log" "net/http" "strings" + "sync" + "time" "modelproxy/types" ) @@ -19,23 +21,33 @@ func obfuscateToken(token string) string { return token[:3] + "***" + token[len(token)-3:] } +const RefreshIntervalSeconds = 600 // 10 minutes type RestServerAuthenticator struct { // rest-server token => model names => model service list - tokenToModels map[string]map[string][]*types.BaseSpec + tokenToModels map[string]map[string][]*types.BaseSpec + tokenUpdatedTime map[string]int64 + mu sync.RWMutex } func NewRestServerAuthenticator() *RestServerAuthenticator { return &RestServerAuthenticator{ - tokenToModels: make(map[string]map[string][]*types.BaseSpec), + tokenToModels: make(map[string]map[string][]*types.BaseSpec), + tokenUpdatedTime: make(map[string]int64), } } // UpdateTokenModels updates the model mapping for a given token func (ra *RestServerAuthenticator) UpdateTokenModels(token string, model2Service map[string][]*types.BaseSpec) { + ra.mu.Lock() + defer ra.mu.Unlock() if ra.tokenToModels == nil { ra.tokenToModels = make(map[string]map[string][]*types.BaseSpec) } + if ra.tokenUpdatedTime == nil { + ra.tokenUpdatedTime = make(map[string]int64) + } ra.tokenToModels[token] = model2Service + ra.tokenUpdatedTime[token] = time.Now().Unix() } // Check if the request is authenticated and return the available model urls @@ -48,16 +60,31 @@ func (ra *RestServerAuthenticator) AuthenticateReq(req *http.Request, reqBody ma log.Printf("[-] Error: 'model' field missing or not a string in request body") return false, nil } + + ra.mu.RLock() availableModels, ok := ra.tokenToModels[token] + tokenLastUpdated, timeOk := ra.tokenUpdatedTime[token] + ra.mu.RUnlock() + if !ok { - // request to RestServer to get the models log.Printf("[-] 
Error: token %s not found in the authenticator\n", obfuscateToken(token)) - availableModels, err := GetJobModelsMapping(req) + } + if !timeOk || time.Now().Unix()-tokenLastUpdated > RefreshIntervalSeconds { + log.Printf("[-] Error: token %s info is outdated in the authenticator\n", obfuscateToken(token)) + } + // If the token is not found or the token info is older than RefreshIntervalSeconds, refresh it + + if !ok || !timeOk || time.Now().Unix()-tokenLastUpdated > RefreshIntervalSeconds { + // request to RestServer to get the models + freshModels, err := ListJobModelsMapping(req) if err != nil { log.Printf("[-] Error: failed to get models for token %s: %v\n", obfuscateToken(token), err) return false, nil } - ra.tokenToModels[token] = availableModels + ra.UpdateTokenModels(token, freshModels) + + availableModels = freshModels + log.Printf("[*] Refreshed models for token %s: %v\n", obfuscateToken(token), availableModels) } if len(availableModels) == 0 { log.Printf("[-] Error: no models found") diff --git a/src/model-proxy/src/proxy/model_server.go b/src/model-proxy/src/proxy/model_server.go index 6e5f6c44..b9071b69 100644 --- a/src/model-proxy/src/proxy/model_server.go +++ b/src/model-proxy/src/proxy/model_server.go @@ -74,7 +74,7 @@ func ListInferenceJobs(restServerUrl string, restServerToken string) ([]string, result := make([]string, 0, len(jobs)) for _, j := range jobs { jobId := fmt.Sprintf("%s~%s", j.Username, j.Name) result = append(result, jobId) } @@ -213,17 +212,17 @@ func listModels(jobServerUrl string, modelApiKey string) ([]string, error) { } // return JobURL => models -func GetJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, error) { +func ListJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, error) { // modelName to job server url list - mapping := make(map[string][]*types.BaseSpec) + modelJobMapping := make(map[string][]*types.BaseSpec) if req == nil || req.Host == "" { - return mapping, fmt.Errorf("invalid 
request or empty host") + return modelJobMapping, fmt.Errorf("invalid request or empty host") } // get rest server base url from the os environment variable restServerUrl := os.Getenv("REST_SERVER_URI") if restServerUrl == "" { - return mapping, fmt.Errorf("REST_SERVER_URI environment variable is not set") + return modelJobMapping, fmt.Errorf("REST_SERVER_URI environment variable is not set") } // Ensure restServerUrl doesn't end with slash restServerUrl = strings.TrimRight(restServerUrl, "/") @@ -231,11 +230,11 @@ func GetJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, error restServerToken := req.Header.Get("Authorization") jobIDs, err := ListInferenceJobs(restServerUrl, restServerToken) if err != nil { - return mapping, fmt.Errorf("failed to list model serving jobs: %w", err) + return modelJobMapping, fmt.Errorf("failed to list model serving jobs: %w", err) } // Channel to collect results - type modelMapping struct { + type ModelEndpoint struct { modelName string modelService *types.BaseSpec } @@ -244,7 +243,7 @@ func GetJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, error log.Printf("[-] Error: invalid FETCH_JOB_CONCURRENCY value: %s\n", err) concurrency = 10 // default value } - results := make(chan modelMapping, concurrency) // Buffer for potential models + allModelEndpoints := make(chan ModelEndpoint, concurrency) // Buffer for potential models // Use a wait group to run jobs in parallel var wg sync.WaitGroup @@ -287,13 +286,17 @@ func GetJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, error return } + userName := strings.Split(jobId, "~")[0] + jobName := strings.Split(jobId, "~")[1] // Send results to channel for _, model := range models { - results <- modelMapping{ + allModelEndpoints <- ModelEndpoint{ modelName: model, modelService: &types.BaseSpec{ - URL: jobServerUrl, - Key: apiKey, + URL: jobServerUrl, + Key: apiKey, + JobName: jobName, + UserName: userName, }, } } @@ -303,16 +306,23 @@ func 
GetJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, error // Close the results channel when all goroutines are done go func() { wg.Wait() - close(results) + close(allModelEndpoints) }() // Collect results from channel - for result := range results { - if _, ok := mapping[result.modelName]; !ok { - mapping[result.modelName] = make([]*types.BaseSpec, 0) + for result := range allModelEndpoints { + if _, ok := modelJobMapping[result.modelName]; !ok { + modelJobMapping[result.modelName] = make([]*types.BaseSpec, 0) + } + modelJobMapping[result.modelName] = append(modelJobMapping[result.modelName], result.modelService) + jobModelName := fmt.Sprintf("%s@%s", result.modelService.JobName, result.modelName) + + // Also map jobName@modelName to the model service which allows users to specify the job name in the model field + if _, ok := modelJobMapping[jobModelName]; !ok { + modelJobMapping[jobModelName] = make([]*types.BaseSpec, 0) } - mapping[result.modelName] = append(mapping[result.modelName], result.modelService) + modelJobMapping[jobModelName] = append(modelJobMapping[jobModelName], result.modelService) } - return mapping, nil + return modelJobMapping, nil } diff --git a/src/model-proxy/src/proxy/proxy.go b/src/model-proxy/src/proxy/proxy.go index a520beac..619eb451 100644 --- a/src/model-proxy/src/proxy/proxy.go +++ b/src/model-proxy/src/proxy/proxy.go @@ -7,6 +7,7 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "io" "log" "net/http" @@ -73,6 +74,50 @@ func NewProxyHandler(config *types.Config) *ProxyHandler { } } +// listAllModels list all models from the rest server and return the models in OpenAI style response +func (ph *ProxyHandler) listAllModels(r *http.Request) ([]byte, error) { + log.Printf("[*] receive a models list request from %s\n", r.RemoteAddr) + model2Service, internalErr := ListJobModelsMapping(r) + if internalErr != nil { + errorMsg := fmt.Sprintf("[-] Error: failed to list models: %v\n", internalErr) + internalErr = 
errors.New(errorMsg) + return nil, internalErr + } + // Update the ph.authenticator + token := r.Header.Get("Authorization") + token = strings.Replace(token, "Bearer ", "", 1) + ph.authenticator.UpdateTokenModels(token, model2Service) + + // convert models list to OpenAI style list and write it to w + ids := make([]string, 0, len(model2Service)) + for id := range model2Service { + ids = append(ids, id) + } + sort.Strings(ids) + + list := map[string]interface{}{ + "object": "list", + "data": make([]map[string]interface{}, 0, len(ids)), + } + for _, id := range ids { + item := map[string]interface{}{ + "id": id, + "object": "model", + // intentionally not including created or owned_by + } + list["data"] = append(list["data"].([]map[string]interface{}), item) + } + + out, internalErr := json.Marshal(list) + if internalErr != nil { + errorMsg := fmt.Sprintf("[-] Error: failed to marshal models list: %v\n", internalErr) + internalErr = errors.New(errorMsg) + return nil, internalErr + } + + return out, nil +} + // ReverseProxyHandler act as a reverse proxy, it will redirect the request to the destination website and return the response func (ph *ProxyHandler) ReverseProxyHandler(w http.ResponseWriter, r *http.Request) (string, []string, bool) { log.Printf("[*] receive a request: %s %s\n", r.Method, r.URL.String()) @@ -88,48 +133,15 @@ func (ph *ProxyHandler) ReverseProxyHandler(w http.ResponseWriter, r *http.Reque // handle /v1/models if r.URL.Path == "/v1/models" { - log.Printf("[*] receive a models list request from %s\n", r.RemoteAddr) - model2Service, err := GetJobModelsMapping(r) - if err != nil { - log.Printf("[-] Error: failed to get models mapping: %v\n", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) - return "", nil, false - } - // Update the ph.authenticator - token := r.Header.Get("Authorization") - token = strings.Replace(token, "Bearer ", "", 1) - ph.authenticator.UpdateTokenModels(token, model2Service) - - // convert models 
list to OpenAI style list and write it to w - ids := make([]string, 0, len(model2Service)) - for id := range model2Service { - ids = append(ids, id) - } - sort.Strings(ids) - - list := map[string]interface{}{ - "object": "list", - "data": make([]map[string]interface{}, 0, len(ids)), - } - for _, id := range ids { - item := map[string]interface{}{ - "id": id, - "object": "model", - // intentionally not including created or owned_by - } - list["data"] = append(list["data"].([]map[string]interface{}), item) - } - - out, err := json.Marshal(list) + output, err := ph.listAllModels(r) if err != nil { - log.Printf("[-] Error: failed to marshal models list: %v\n", err) - http.Error(w, "Internal Server Error", http.StatusInternalServerError) + http.Error(w, err.Error(), http.StatusInternalServerError) return "", nil, false } w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - if _, err := w.Write(out); err != nil { + if _, err := w.Write(output); err != nil { log.Printf("[-] Error: failed to write response: %v\n", err) } // We've handled the response, do not continue proxying this request diff --git a/src/model-proxy/src/types/config_types.go b/src/model-proxy/src/types/config_types.go index b8b7f07f..3409d7a3 100644 --- a/src/model-proxy/src/types/config_types.go +++ b/src/model-proxy/src/types/config_types.go @@ -48,9 +48,11 @@ type AzureStorage struct { // BaseSpec is the base spec for azure and openai type BaseSpec struct { - URL string `json:"url"` - Key string `json:"key"` - Version string `json:"version,omitempty"` + URL string `json:"url"` + Key string `json:"key"` + Version string `json:"version,omitempty"` + JobName string `json:"job_name,omitempty"` + UserName string `json:"user_name,omitempty"` } // ParseConfig parse the config file into Config struct