Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions src/model-proxy/src/proxy/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"log"
"net/http"
"strings"
"sync"
"time"

"modelproxy/types"
)
Expand All @@ -19,23 +21,33 @@ func obfuscateToken(token string) string {
return token[:3] + "***" + token[len(token)-3:]
}

// RefreshIntervalSeconds is how long (in seconds) a token's cached model
// mapping stays valid before it must be refreshed from the rest server.
const RefreshIntervalSeconds = 600 // 10 minutes

// RestServerAuthenticator caches, per rest-server token, the mapping from
// model names to the model services that can serve them, together with the
// time the mapping was last refreshed.
type RestServerAuthenticator struct {
	// rest-server token => model names => model service list
	tokenToModels map[string]map[string][]*types.BaseSpec
	// rest-server token => unix seconds of the last refresh
	tokenUpdatedTime map[string]int64
	// mu guards tokenToModels and tokenUpdatedTime.
	mu sync.RWMutex
}

// NewRestServerAuthenticator returns an authenticator with empty,
// ready-to-use caches.
func NewRestServerAuthenticator() *RestServerAuthenticator {
	return &RestServerAuthenticator{
		tokenToModels:    make(map[string]map[string][]*types.BaseSpec),
		tokenUpdatedTime: make(map[string]int64),
	}
}

// UpdateTokenModels replaces the cached model mapping for the given token
// and records the refresh time, so stale entries can be detected against
// RefreshIntervalSeconds.
func (ra *RestServerAuthenticator) UpdateTokenModels(token string, model2Service map[string][]*types.BaseSpec) {
	ra.mu.Lock()
	defer ra.mu.Unlock()
	// Defensive init: callers may have constructed the struct without
	// NewRestServerAuthenticator (writing to a nil map would panic).
	if ra.tokenToModels == nil {
		ra.tokenToModels = make(map[string]map[string][]*types.BaseSpec)
	}
	if ra.tokenUpdatedTime == nil {
		ra.tokenUpdatedTime = make(map[string]int64)
	}
	ra.tokenToModels[token] = model2Service
	ra.tokenUpdatedTime[token] = time.Now().Unix()
}

// Check if the request is authenticated and return the available model urls
Expand All @@ -48,16 +60,31 @@ func (ra *RestServerAuthenticator) AuthenticateReq(req *http.Request, reqBody ma
log.Printf("[-] Error: 'model' field missing or not a string in request body")
return false, nil
}

ra.mu.RLock()
availableModels, ok := ra.tokenToModels[token]
tokenLastUpdated, timeOk := ra.tokenUpdatedTime[token]
ra.mu.RUnlock()

if !ok {
// request to RestServer to get the models
log.Printf("[-] Error: token %s not found in the authenticator\n", obfuscateToken(token))
availableModels, err := GetJobModelsMapping(req)
}
if !timeOk || time.Now().Unix()-tokenLastUpdated > RefreshIntervalSeconds {
log.Printf("[-] Error: token %s info is outdated in the authenticator\n", obfuscateToken(token))
}
// If the token is not found or the token info is older than RefreshIntervalSeconds, refresh it

if !ok || !timeOk || time.Now().Unix()-tokenLastUpdated > RefreshIntervalSeconds {
// request to RestServer to get the models
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to make the `!timeOk || time.Now().Unix()-tokenLastUpdated > RefreshIntervalSeconds` judgement twice? Is it only for logging?

freshModels, err := ListJobModelsMapping(req)
if err != nil {
log.Printf("[-] Error: failed to get models for token %s: %v\n", obfuscateToken(token), err)
return false, nil
}
ra.tokenToModels[token] = availableModels
ra.UpdateTokenModels(token, freshModels)

availableModels = freshModels
log.Printf("[*] Refreshed models for token %s: %v\n", obfuscateToken(token), availableModels)
}
if len(availableModels) == 0 {
log.Printf("[-] Error: no models found")
Expand Down
44 changes: 27 additions & 17 deletions src/model-proxy/src/proxy/model_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ func ListInferenceJobs(restServerUrl string, restServerToken string) ([]string,

result := make([]string, 0, len(jobs))
for _, j := range jobs {

jobId := fmt.Sprintf("%s~%s", j.Username, j.Name)
result = append(result, jobId)
}
Expand Down Expand Up @@ -213,29 +212,29 @@ func listModels(jobServerUrl string, modelApiKey string) ([]string, error) {
}

// return JobURL => models
func GetJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, error) {
func ListJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, error) {
// modelName to job server url list
mapping := make(map[string][]*types.BaseSpec)
modelJobMapping := make(map[string][]*types.BaseSpec)

if req == nil || req.Host == "" {
return mapping, fmt.Errorf("invalid request or empty host")
return modelJobMapping, fmt.Errorf("invalid request or empty host")
}
// get rest server base url from the os environment variable
restServerUrl := os.Getenv("REST_SERVER_URI")
if restServerUrl == "" {
return mapping, fmt.Errorf("REST_SERVER_URI environment variable is not set")
return modelJobMapping, fmt.Errorf("REST_SERVER_URI environment variable is not set")
}
// Ensure restServerUrl doesn't end with slash
restServerUrl = strings.TrimRight(restServerUrl, "/")

restServerToken := req.Header.Get("Authorization")
jobIDs, err := ListInferenceJobs(restServerUrl, restServerToken)
if err != nil {
return mapping, fmt.Errorf("failed to list model serving jobs: %w", err)
return modelJobMapping, fmt.Errorf("failed to list model serving jobs: %w", err)
}

// Channel to collect results
type modelMapping struct {
type ModelEndpoint struct {
modelName string
modelService *types.BaseSpec
}
Expand All @@ -244,7 +243,7 @@ func GetJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, error
log.Printf("[-] Error: invalid FETCH_JOB_CONCURRENCY value: %s\n", err)
concurrency = 10 // default value
}
results := make(chan modelMapping, concurrency) // Buffer for potential models
allModelEndpoints := make(chan ModelEndpoint, concurrency) // Buffer for potential models

// Use a wait group to run jobs in parallel
var wg sync.WaitGroup
Expand Down Expand Up @@ -287,13 +286,17 @@ func GetJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, error
return
}

userName := strings.Split(jobId, "~")[0]
jobName := strings.Split(jobId, "~")[1]
// Send results to channel
for _, model := range models {
results <- modelMapping{
allModelEndpoints <- ModelEndpoint{
modelName: model,
modelService: &types.BaseSpec{
URL: jobServerUrl,
Key: apiKey,
URL: jobServerUrl,
Key: apiKey,
JobName: jobName,
UserName: userName,
},
}
}
Expand All @@ -303,16 +306,23 @@ func GetJobModelsMapping(req *http.Request) (map[string][]*types.BaseSpec, error
// Close the results channel when all goroutines are done
go func() {
wg.Wait()
close(results)
close(allModelEndpoints)
}()

// Collect results from channel
for result := range results {
if _, ok := mapping[result.modelName]; !ok {
mapping[result.modelName] = make([]*types.BaseSpec, 0)
for result := range allModelEndpoints {
if _, ok := modelJobMapping[result.modelName]; !ok {
modelJobMapping[result.modelName] = make([]*types.BaseSpec, 0)
}
modelJobMapping[result.modelName] = append(modelJobMapping[result.modelName], result.modelService)
jobModelName := fmt.Sprintf("%s@%s", result.modelService.JobName, result.modelName)

// Also map jobName@modelName to the model service which allows users to specify the job name in the model field
if _, ok := modelJobMapping[jobModelName]; !ok {
modelJobMapping[jobModelName] = make([]*types.BaseSpec, 0)
}
mapping[result.modelName] = append(mapping[result.modelName], result.modelService)
modelJobMapping[jobModelName] = append(modelJobMapping[jobModelName], result.modelService)
}

return mapping, nil
return modelJobMapping, nil
}
84 changes: 48 additions & 36 deletions src/model-proxy/src/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
Expand Down Expand Up @@ -73,6 +74,50 @@ func NewProxyHandler(config *types.Config) *ProxyHandler {
}
}

// listAllModels lists all models reachable with the caller's token and
// returns them as an OpenAI-style "list" JSON payload. As a side effect it
// refreshes the authenticator's cached token => models mapping.
func (ph *ProxyHandler) listAllModels(r *http.Request) ([]byte, error) {
	log.Printf("[*] receive a models list request from %s\n", r.RemoteAddr)
	model2Service, err := ListJobModelsMapping(r)
	if err != nil {
		// Wrap instead of building the error with Sprintf + errors.New:
		// the caller surfaces err.Error() to the HTTP client, so keep it
		// free of log prefixes and trailing newlines.
		return nil, fmt.Errorf("failed to list models: %w", err)
	}
	// Refresh the authenticator cache for this token.
	token := r.Header.Get("Authorization")
	token = strings.Replace(token, "Bearer ", "", 1)
	ph.authenticator.UpdateTokenModels(token, model2Service)

	// Sort the ids so the response is deterministic (map iteration order
	// is random in Go).
	ids := make([]string, 0, len(model2Service))
	for id := range model2Service {
		ids = append(ids, id)
	}
	sort.Strings(ids)

	// Convert the model list to an OpenAI-style list response.
	list := map[string]interface{}{
		"object": "list",
		"data":   make([]map[string]interface{}, 0, len(ids)),
	}
	for _, id := range ids {
		item := map[string]interface{}{
			"id":     id,
			"object": "model",
			// intentionally not including created or owned_by
		}
		list["data"] = append(list["data"].([]map[string]interface{}), item)
	}

	out, err := json.Marshal(list)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal models list: %w", err)
	}

	return out, nil
}

// ReverseProxyHandler act as a reverse proxy, it will redirect the request to the destination website and return the response
func (ph *ProxyHandler) ReverseProxyHandler(w http.ResponseWriter, r *http.Request) (string, []string, bool) {
log.Printf("[*] receive a request: %s %s\n", r.Method, r.URL.String())
Expand All @@ -88,48 +133,15 @@ func (ph *ProxyHandler) ReverseProxyHandler(w http.ResponseWriter, r *http.Reque

// handle /v1/models
if r.URL.Path == "/v1/models" {
log.Printf("[*] receive a models list request from %s\n", r.RemoteAddr)
model2Service, err := GetJobModelsMapping(r)
if err != nil {
log.Printf("[-] Error: failed to get models mapping: %v\n", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return "", nil, false
}
// Update the ph.authenticator
token := r.Header.Get("Authorization")
token = strings.Replace(token, "Bearer ", "", 1)
ph.authenticator.UpdateTokenModels(token, model2Service)

// convert models list to OpenAI style list and write it to w
ids := make([]string, 0, len(model2Service))
for id := range model2Service {
ids = append(ids, id)
}
sort.Strings(ids)

list := map[string]interface{}{
"object": "list",
"data": make([]map[string]interface{}, 0, len(ids)),
}
for _, id := range ids {
item := map[string]interface{}{
"id": id,
"object": "model",
// intentionally not including created or owned_by
}
list["data"] = append(list["data"].([]map[string]interface{}), item)
}

out, err := json.Marshal(list)
output, err := ph.listAllModels(r)
if err != nil {
log.Printf("[-] Error: failed to marshal models list: %v\n", err)
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
http.Error(w, err.Error(), http.StatusInternalServerError)
return "", nil, false
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if _, err := w.Write(out); err != nil {
if _, err := w.Write(output); err != nil {
log.Printf("[-] Error: failed to write response: %v\n", err)
}
// We've handled the response, do not continue proxying this request
Expand Down
8 changes: 5 additions & 3 deletions src/model-proxy/src/types/config_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@ type AzureStorage struct {

// BaseSpec is the base spec for azure and openai
type BaseSpec struct {
URL string `json:"url"`
Key string `json:"key"`
Version string `json:"version,omitempty"`
URL string `json:"url"`
Key string `json:"key"`
Version string `json:"version,omitempty"`
JobName string `json:"job_name,omitempty"`
UserName string `json:"user_name,omitempty"`
}

// ParseConfig parse the config file into Config struct
Expand Down