Skip to content

Commit c298104

Browse files
committed
auth middleware integrated
1 parent 7bf2461 commit c298104

11 files changed

Lines changed: 229 additions & 112 deletions

controllers/llm_controller.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"io"
55
"net/http"
66

7+
"github.com/Thanus-Kumaar/controller_microservice_v2/middleware"
78
"github.com/Thanus-Kumaar/controller_microservice_v2/modules"
89
"github.com/rs/zerolog"
910
"maps"
@@ -26,8 +27,13 @@ func NewLlmController(module *modules.LlmModule, logger zerolog.Logger) *LlmCont
2627
// GenerateNotebookHandler handles POST /api/v1/llm/generate
2728
func (c *LlmController) GenerateNotebookHandler(w http.ResponseWriter, r *http.Request) {
2829
ctx := r.Context()
30+
user, ok := ctx.Value(middleware.UserContextKey).(*middleware.User)
31+
if !ok {
32+
http.Error(w, "user not found in context", http.StatusUnauthorized)
33+
return
34+
}
2935

30-
resp, err := c.Module.GenerateNotebook(ctx, r.Body)
36+
resp, err := c.Module.GenerateNotebook(ctx, r.Body, user.ID)
3137
if err != nil {
3238
c.Logger.Error().Err(err).Msg("Failed to proxy generate request")
3339
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -51,8 +57,13 @@ func (c *LlmController) GenerateNotebookHandler(w http.ResponseWriter, r *http.R
5157
func (c *LlmController) ModifyNotebookHandler(w http.ResponseWriter, r *http.Request) {
5258
sessionID := r.PathValue("session_id")
5359
ctx := r.Context()
60+
user, ok := ctx.Value(middleware.UserContextKey).(*middleware.User)
61+
if !ok {
62+
http.Error(w, "user not found in context", http.StatusUnauthorized)
63+
return
64+
}
5465

55-
resp, err := c.Module.ModifyNotebook(ctx, sessionID, r.Body)
66+
resp, err := c.Module.ModifyNotebook(ctx, sessionID, r.Body, user.ID)
5667
if err != nil {
5768
c.Logger.Error().Err(err).Msgf("Failed to proxy modify request for session %s", sessionID)
5869
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -72,8 +83,13 @@ func (c *LlmController) ModifyNotebookHandler(w http.ResponseWriter, r *http.Req
7283
func (c *LlmController) FixNotebookHandler(w http.ResponseWriter, r *http.Request) {
7384
sessionID := r.PathValue("session_id")
7485
ctx := r.Context()
86+
user, ok := ctx.Value(middleware.UserContextKey).(*middleware.User)
87+
if !ok {
88+
http.Error(w, "user not found in context", http.StatusUnauthorized)
89+
return
90+
}
7591

76-
resp, err := c.Module.FixNotebook(ctx, sessionID, r.Body)
92+
resp, err := c.Module.FixNotebook(ctx, sessionID, r.Body, user.ID)
7793
if err != nil {
7894
c.Logger.Error().Err(err).Msgf("Failed to proxy fix request for session %s", sessionID)
7995
http.Error(w, err.Error(), http.StatusInternalServerError)

controllers/notebook_controller.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,23 @@ func (c *NotebookController) ListNotebooksHandler(w http.ResponseWriter, r *http
5252
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
5353
defer cancel()
5454

55-
// TODO: parse limit/offset/created_by/problem_statement_id from r.URL.Query()
56-
nbs, err := c.NotebookModule.ListNotebooks(ctx, nil)
55+
filters := make(map[string]string)
56+
query := r.URL.Query()
57+
58+
if createdBy := query.Get("created_by"); createdBy != "" {
59+
filters["created_by"] = createdBy
60+
}
61+
if problemStatementID := query.Get("problem_statement_id"); problemStatementID != "" {
62+
filters["problem_statement_id"] = problemStatementID
63+
}
64+
if limit := query.Get("limit"); limit != "" {
65+
filters["limit"] = limit
66+
}
67+
if offset := query.Get("offset"); offset != "" {
68+
filters["offset"] = offset
69+
}
70+
71+
nbs, err := c.NotebookModule.ListNotebooks(ctx, filters)
5772
if err != nil {
5873
c.Logger.Error().Err(err).Msg("failed to list notebooks")
5974
http.Error(w, "error listing notebooks", http.StatusInternalServerError)

controllers/problem_controller.go

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http"
99
"time"
1010

11+
"github.com/Thanus-Kumaar/controller_microservice_v2/middleware"
1112
"github.com/Thanus-Kumaar/controller_microservice_v2/modules"
1213
"github.com/Thanus-Kumaar/controller_microservice_v2/pkg"
1314
"github.com/Thanus-Kumaar/controller_microservice_v2/pkg/models"
@@ -30,18 +31,21 @@ func NewProblemController(problemModule *modules.ProblemModule, logger zerolog.L
3031

3132
// CreateProblemHandler handles POST /api/v1/problems
3233
func (c *ProblemController) CreateProblemHandler(w http.ResponseWriter, r *http.Request) {
33-
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
34+
ctx, cancel := context.WithTimeout(r.Context(), 10*time.Second)
3435
defer cancel()
35-
// For now, we'll use a hardcoded user ID for testing.
36-
// TODO: Replace with actual user_id from auth middleware.
37-
const hardcodedUserID = "123e4567-e89b-12d3-a456-426614174000"
36+
37+
user, ok := ctx.Value(middleware.UserContextKey).(*middleware.User)
38+
if !ok {
39+
http.Error(w, "user not found in context", http.StatusUnauthorized)
40+
return
41+
}
3842

3943
var req models.CreateProblemRequest
4044
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
4145
http.Error(w, "invalid request body", http.StatusBadRequest)
4246
return
4347
}
44-
problemStatement, err := c.ProblemModule.CreateProblem(ctx, &req, hardcodedUserID)
48+
problemStatement, err := c.ProblemModule.CreateProblem(ctx, &req, user.ID)
4549
if err != nil {
4650
c.Logger.Error().Err(err).Msg("failed to create problem statement")
4751
http.Error(w, fmt.Sprintf("error creating problem statemnet: %v", err), http.StatusInternalServerError)
@@ -52,13 +56,16 @@ func (c *ProblemController) CreateProblemHandler(w http.ResponseWriter, r *http.
5256

5357
// ListProblemsHandler handles GET /api/v1/problems
5458
func (c *ProblemController) ListProblemsHandler(w http.ResponseWriter, r *http.Request) {
55-
// For now, we'll use a hardcoded user ID for testing.
56-
// TODO: Replace with actual user_id from auth middleware.
57-
const hardcodedUserID = "123e4567-e89b-12d3-a456-426614174000"
5859
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
5960
defer cancel()
6061

61-
problems, err := c.ProblemModule.GetProblemsByUserID(ctx, hardcodedUserID)
62+
user, ok := ctx.Value(middleware.UserContextKey).(*middleware.User)
63+
if !ok {
64+
http.Error(w, "user not found in context", http.StatusUnauthorized)
65+
return
66+
}
67+
68+
problems, err := c.ProblemModule.GetProblemsByUserID(ctx, user.ID)
6269
if err != nil {
6370
c.Logger.Error().Err(err).Msg("failed to list problems by user id")
6471
http.Error(w, "failed to retrieve problems", http.StatusInternalServerError)
@@ -95,17 +102,19 @@ func (c *ProblemController) UpdateProblemByIDHandler(w http.ResponseWriter, r *h
95102
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
96103
defer cancel()
97104

98-
// For now, we'll use a hardcoded user ID for testing.
99-
// TODO: Replace with actual user_id from auth middleware.
100-
const hardcodedUserID = "123e4567-e89b-12d3-a456-426614174000"
105+
user, ok := ctx.Value(middleware.UserContextKey).(*middleware.User)
106+
if !ok {
107+
http.Error(w, "user not found in context", http.StatusUnauthorized)
108+
return
109+
}
101110

102111
var req models.UpdateProblemRequest
103112
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
104113
http.Error(w, "invalid request body", http.StatusBadRequest)
105114
return
106115
}
107116

108-
updatedProblem, err := c.ProblemModule.UpdateProblem(ctx, problemID, &req, hardcodedUserID)
117+
updatedProblem, err := c.ProblemModule.UpdateProblem(ctx, problemID, &req, user.ID)
109118
if err != nil {
110119
c.Logger.Error().Err(err).Msg("failed to update problem")
111120
// This could be a not found error or an authorization error.
@@ -122,11 +131,13 @@ func (c *ProblemController) DeleteProblemByIDHandler(w http.ResponseWriter, r *h
122131
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
123132
defer cancel()
124133

125-
// For now, we'll use a hardcoded user ID for testing.
126-
// TODO: Replace with actual user_id from auth middleware.
127-
const hardcodedUserID = "123e4567-e89b-12d3-a456-426614174000"
134+
user, ok := ctx.Value(middleware.UserContextKey).(*middleware.User)
135+
if !ok {
136+
http.Error(w, "user not found in context", http.StatusUnauthorized)
137+
return
138+
}
128139

129-
err := c.ProblemModule.DeleteProblem(ctx, problemID, hardcodedUserID)
140+
err := c.ProblemModule.DeleteProblem(ctx, problemID, user.ID)
130141
if err != nil {
131142
c.Logger.Error().Err(err).Msg("failed to delete problem")
132143
// This could be a not found error or an authorization error.

docker-compose.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ services:
1515
JUPYTER_AUTH_TOKEN: "YOUR_SECRET_TOKEN"
1616
CULL_INTERVAL_MINUTES: 10
1717
IDLE_THRESHOLD_MINUTES: 30
18-
AUTH_GRPC_SERVER_ADDRESS: "localhost:5001"
18+
AUTH_GRPC_ADDRESS: "localhost:5001"
1919
LLM_MICROSERVICE_URL: "http://host.docker.internal:8000"
2020
networks:
2121
- evoc-net

middleware/auth_middleware.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package middleware
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"os"
7+
8+
"github.com/Thanus-Kumaar/controller_microservice_v2/proto"
9+
"google.golang.org/grpc"
10+
"google.golang.org/grpc/credentials/insecure"
11+
)
12+
13+
type contextKey string
14+
15+
const (
16+
UserContextKey = contextKey("user")
17+
)
18+
19+
type User struct {
20+
ID string
21+
Role string
22+
Email string
23+
UserName string
24+
FullName string
25+
}
26+
27+
func AuthMiddleware(next http.Handler) http.Handler {
28+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
29+
cookie, err := r.Cookie("t")
30+
if err != nil {
31+
http.Error(w, "Authorization cookie required", http.StatusUnauthorized)
32+
return
33+
}
34+
35+
tokenString := cookie.Value
36+
if tokenString == "" {
37+
http.Error(w, "Authorization cookie is empty", http.StatusUnauthorized)
38+
return
39+
}
40+
41+
conn, err := grpc.NewClient(os.Getenv("AUTH_GRPC_ADDRESS"), grpc.WithTransportCredentials(insecure.NewCredentials()))
42+
if err != nil {
43+
http.Error(w, "Failed to connect to auth service", http.StatusInternalServerError)
44+
return
45+
}
46+
defer conn.Close()
47+
48+
client := proto.NewAuthenticateClient(conn)
49+
res, err := client.Auth(context.Background(), &proto.TokenValidateRequest{Token: tokenString})
50+
if err != nil {
51+
http.Error(w, "Failed to validate token", http.StatusInternalServerError)
52+
return
53+
}
54+
55+
if !res.Valid {
56+
http.Error(w, "Invalid token", http.StatusUnauthorized)
57+
return
58+
}
59+
60+
user := &User{
61+
ID: res.Id,
62+
Role: res.Role,
63+
Email: res.Email,
64+
UserName: res.UserName,
65+
FullName: res.FullName,
66+
}
67+
68+
ctx := context.WithValue(r.Context(), UserContextKey, user)
69+
next.ServeHTTP(w, r.WithContext(ctx))
70+
})
71+
}

modules/llm_module.go

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ func NewLlmModule(repo repository.LlmRepository) *LlmModule {
2424
}
2525

2626
// GenerateNotebook validates and proxies the generate request.
27-
func (m *LlmModule) GenerateNotebook(ctx context.Context, body io.Reader) (*http.Response, error) {
27+
func (m *LlmModule) GenerateNotebook(ctx context.Context, body io.Reader, userID string) (*http.Response, error) {
2828
bodyBytes, err := io.ReadAll(body)
2929
if err != nil {
3030
return nil, fmt.Errorf("failed to read request body: %w", err)
@@ -35,10 +35,12 @@ func (m *LlmModule) GenerateNotebook(ctx context.Context, body io.Reader) (*http
3535
return nil, fmt.Errorf("failed to decode request body as JSON: %w", err)
3636
}
3737

38-
if err := IsUserIDandNotebookIDPresent(requestData); err != nil {
38+
if err := IsNotebookIDPresent(requestData); err != nil {
3939
return nil, err
4040
}
4141

42+
requestData["user_id"] = userID
43+
4244
finalBodyBytes, err := json.Marshal(requestData)
4345
if err != nil {
4446
return nil, fmt.Errorf("failed to re-encode request body: %w", err)
@@ -48,7 +50,7 @@ func (m *LlmModule) GenerateNotebook(ctx context.Context, body io.Reader) (*http
4850
}
4951

5052
// ModifyNotebook validates and proxies the modify request.
51-
func (m *LlmModule) ModifyNotebook(ctx context.Context, sessionID string, body io.Reader) (*http.Response, error) {
53+
func (m *LlmModule) ModifyNotebook(ctx context.Context, sessionID string, body io.Reader, userID string) (*http.Response, error) {
5254
bodyBytes, err := io.ReadAll(body)
5355
if err != nil {
5456
return nil, fmt.Errorf("failed to read request body: %w", err)
@@ -59,10 +61,12 @@ func (m *LlmModule) ModifyNotebook(ctx context.Context, sessionID string, body i
5961
return nil, fmt.Errorf("failed to decode request body as JSON: %w", err)
6062
}
6163

62-
if err := IsUserIDandNotebookIDPresent(requestData); err != nil {
64+
if err := IsNotebookIDPresent(requestData); err != nil {
6365
return nil, err
6466
}
6567

68+
requestData["user_id"] = userID
69+
6670
if instruction, ok := requestData["instruction"].(string); !ok || instruction == "" {
6771
return nil, fmt.Errorf("request body must contain a non-empty 'instruction' string")
6872
}
@@ -75,11 +79,16 @@ func (m *LlmModule) ModifyNotebook(ctx context.Context, sessionID string, body i
7579
return nil, fmt.Errorf("'notebook' object must contain a 'cells' array")
7680
}
7781

78-
return m.Repo.ModifyNotebook(ctx, bytes.NewBuffer(bodyBytes))
82+
finalBodyBytes, err := json.Marshal(requestData)
83+
if err != nil {
84+
return nil, fmt.Errorf("failed to re-encode request body: %w", err)
85+
}
86+
87+
return m.Repo.ModifyNotebook(ctx, bytes.NewBuffer(finalBodyBytes))
7988
}
8089

8190
// FixNotebook validates and proxies the fix request.
82-
func (m *LlmModule) FixNotebook(ctx context.Context, sessionID string, body io.Reader) (*http.Response, error) {
91+
func (m *LlmModule) FixNotebook(ctx context.Context, sessionID string, body io.Reader, userID string) (*http.Response, error) {
8392
bodyBytes, err := io.ReadAll(body)
8493
if err != nil {
8594
return nil, fmt.Errorf("failed to read request body: %w", err)
@@ -90,10 +99,12 @@ func (m *LlmModule) FixNotebook(ctx context.Context, sessionID string, body io.R
9099
return nil, fmt.Errorf("failed to decode request body as JSON: %w", err)
91100
}
92101

93-
if err := IsUserIDandNotebookIDPresent(requestData); err != nil {
102+
if err := IsNotebookIDPresent(requestData); err != nil {
94103
return nil, err
95104
}
96105

106+
requestData["user_id"] = userID
107+
97108
if traceback, ok := requestData["traceback"].(string); !ok || traceback == "" {
98109
return nil, fmt.Errorf("request body must contain a non-empty 'traceback' string")
99110
}
@@ -106,26 +117,22 @@ func (m *LlmModule) FixNotebook(ctx context.Context, sessionID string, body io.R
106117
return nil, fmt.Errorf("'notebook' object must contain a 'cells' array")
107118
}
108119

109-
return m.Repo.FixNotebook(ctx, bytes.NewBuffer(bodyBytes))
120+
finalBodyBytes, err := json.Marshal(requestData)
121+
if err != nil {
122+
return nil, fmt.Errorf("failed to re-encode request body: %w", err)
123+
}
124+
125+
return m.Repo.FixNotebook(ctx, bytes.NewBuffer(finalBodyBytes))
110126
}
111127

112-
func IsUserIDandNotebookIDPresent(requestData map[string]any) error {
113-
// TODO: User ID should not be passed in the body.
114-
// TODO: It should be extracted from the auth context, which i am not going to do now :)
115-
// making sure user_id and notebook_id are present
116-
if _, hasUserID := requestData["user_id"]; !hasUserID {
117-
return fmt.Errorf("request body must contain 'user_id'")
118-
}
128+
func IsNotebookIDPresent(requestData map[string]any) error {
129+
// making sure notebook_id is present
119130
if _, hasNotebookID := requestData["notebook_id"]; !hasNotebookID {
120131
return fmt.Errorf("request body must contain 'notebook_id'")
121132
}
122133
notebookIDStr, isString := requestData["notebook_id"].(string)
123134
if !isString || notebookIDStr == "" {
124135
return fmt.Errorf("'notebook_id' must be a non-empty string")
125136
}
126-
userIDStr, isString := requestData["user_id"].(string)
127-
if !isString || userIDStr == "" {
128-
return fmt.Errorf("'user_id' must be a non-empty string")
129-
}
130137
return nil
131138
}

modules/problem_module.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ func (m *ProblemModule) CreateProblem(ctx context.Context, problem *models.Creat
2626
return nil, errors.New("invalid problem creation request")
2727
}
2828

29-
// TODO: The createdBy string should be a valid UUID.
30-
// We should probably parse it here. For now, we'll assume it's valid.
3129
creatorID, err := uuid.Parse(createdBy)
3230
if err != nil {
3331
// If createdBy is an empty string or not a valid UUID, handle it.

0 commit comments

Comments
 (0)