diff --git a/controller/grpc/authenticate.go b/controller/grpc/authenticate.go index 3dace03..c15cc24 100644 --- a/controller/grpc/authenticate.go +++ b/controller/grpc/authenticate.go @@ -7,6 +7,7 @@ import ( "evolve/proto" "evolve/util/auth" dbutil "evolve/util/db/user" + "os" ) type GRPCServer struct { @@ -15,6 +16,7 @@ type GRPCServer struct { func (*GRPCServer) Auth(ctx context.Context, req *proto.TokenValidateRequest) (*proto.TokenValidateResponse, error) { user, err := auth.ValidateToken(req.GetToken()) + if err != nil { return &proto.TokenValidateResponse{ Valid: false, @@ -28,6 +30,21 @@ func (*GRPCServer) Auth(ctx context.Context, req *proto.TokenValidateRequest) (* }, err } + //checking for CSRF token + if os.Getenv("CSRF_PROTECTION") == "true" && req.GetCsrfToken() != "" && user["csrf_token"] != "" { + + csrfToken := req.GetCsrfToken() + if csrfToken != user["csrf_token"] { + return &proto.TokenValidateResponse{ + Valid: false, + }, nil + } + } else { + return &proto.TokenValidateResponse{ + Valid: false, + }, nil + } + userData, err := dbutil.UserById(ctx, user["id"], db) if err != nil { return &proto.TokenValidateResponse{ diff --git a/controller/login.go b/controller/login.go index 51a44f2..df5d39b 100644 --- a/controller/login.go +++ b/controller/login.go @@ -1,12 +1,21 @@ package controller import ( + "crypto/rand" + "encoding/base64" "evolve/modules" "evolve/util" "net/http" + "os" "time" ) +func generateCSRFToken() string { + b := make([]byte, 32) + rand.Read(b) + return base64.StdEncoding.EncodeToString(b) +} + func Login(res http.ResponseWriter, req *http.Request) { logger := util.SharedLogger @@ -42,5 +51,18 @@ func Login(res http.ResponseWriter, req *http.Request) { delete(user, "token") + if os.Getenv("CSRF_PROTECTION") == "true" { + csrfToken := generateCSRFToken() + http.SetCookie(res, &http.Cookie{ + Name: "csrf_token", + Value: csrfToken, + Path: "/", + HttpOnly: false, + SameSite: http.SameSiteLaxMode, + }) + + res.Header().Set("X-CSRF-Token", csrfToken) + } + util.JSONResponse(res, http.StatusOK, "Success", user) } diff --git a/go.mod b/go.mod index 27f172c..138d174 100644 --- a/go.mod +++ b/go.mod @@ -27,5 +27,5 @@ require ( golang.org/x/sync v0.17.0 // indirect golang.org/x/sys v0.37.0 // indirect golang.org/x/text v0.30.0 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda // indirect ) diff --git a/go.sum b/go.sum index bc1e9fd..91e6568 100644 --- a/go.sum +++ b/go.sum @@ -72,8 +72,8 @@ golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8 h1:M1rk8KBnUsBDg1oPGHNCxG4vc1f49epmTO7xscSajMk= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251022142026-3a174f9686a8/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda h1:i/Q+bfisr7gq6feoJnS/DlpdwEL4ihp41fvRiM3Ork0= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= google.golang.org/grpc v1.76.0 h1:UnVkv1+uMLYXoIz6o7chp59WfQUYA2ex/BXQ9rHZu7A= google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94ULZ6c= google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= diff --git a/main.go b/main.go index 901a193..ea50b18 100644 --- a/main.go +++ b/main.go @@ -15,6 +15,8 @@ import ( "os" "runtime" + "reflect" + "aidanwoods.dev/go-paseto" "github.com/rs/cors" "google.golang.org/grpc" @@ -27,6 +29,8 @@ var ( func serveHTTP() { logger := util.SharedLogger + + fmt.Println("type of logger", reflect.TypeOf(&logger)) // Register routes. http.HandleFunc(routes.TEST, controller.Test) http.HandleFunc(routes.REGISTER, controller.Register) @@ -46,11 +50,18 @@ func serveHTTP() { AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowedHeaders: []string{"*"}, AllowCredentials: true, + ExposedHeaders: []string{"X-CSRF-Token"}, }).Handler(http.DefaultServeMux) handler := util.SharedLogger.LogMiddleware(corsHandler) + var finalHandler http.Handler + if os.Getenv("CSRF_PROTECTION") == "true" { + finalHandler = util.CSRFMiddleware(handler) + } else { + finalHandler = handler + } - if err := http.ListenAndServe(HTTP_PORT, handler); err != nil { + if err := http.ListenAndServe(HTTP_PORT, finalHandler); err != nil { logger.Error(fmt.Sprintf("Failed to start server: %v", err), err) return } @@ -75,15 +86,15 @@ func serveGRPC() { func main() { + logger, err := util.InitLogger(os.Getenv("ENV")) + util.SharedLogger = logger HTTP_PORT = fmt.Sprintf(":%v", os.Getenv("HTTP_PORT")) GRPC_PORT = fmt.Sprintf(":%v", os.Getenv("GRPC_PORT")) - logger, err := util.InitLogger(os.Getenv("ENV")) if err != nil { fmt.Println("failed to init logger:", err) return } - util.SharedLogger = logger // Initialize db with schema. if err := db.InitDb(context.Background()); err != nil { diff --git a/proto/authenticate.pb.go b/proto/authenticate.pb.go index ffd679a..cf51760 100644 --- a/proto/authenticate.pb.go +++ b/proto/authenticate.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.5 -// protoc v5.29.3 +// protoc-gen-go v1.36.10 +// protoc v6.33.0 // source: proto/authenticate.proto package proto @@ -25,6 +25,7 @@ const ( type TokenValidateRequest struct { state protoimpl.MessageState `protogen:"open.v1"` Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"` + CsrfToken *string `protobuf:"bytes,2,opt,name=csrf_token,json=csrfToken,proto3,oneof" json:"csrf_token,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -66,6 +67,13 @@ func (x *TokenValidateRequest) GetToken() string { return "" } +func (x *TokenValidateRequest) GetCsrfToken() string { + if x != nil && x.CsrfToken != nil { + return *x.CsrfToken + } + return "" +} + // TokenValidateResponse is the response message for the Authenticate method. type TokenValidateResponse struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -153,30 +161,23 @@ func (x *TokenValidateResponse) GetFullName() string { var File_proto_authenticate_proto protoreflect.FileDescriptor -var file_proto_authenticate_proto_rawDesc = string([]byte{ - 0x0a, 0x18, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2f, 0x61, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, - 0x63, 0x61, 0x74, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x22, 0x2c, 0x0a, 0x14, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, - 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x74, 0x6f, 0x6b, - 0x65, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x74, 0x6f, 0x6b, 0x65, 0x6e, 0x22, - 0x9f, 0x01, 0x0a, 0x15, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, - 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, - 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x12, - 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, - 0x12, 0x0a, 0x04, 0x72, 0x6f, 0x6c, 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x72, - 0x6f, 0x6c, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x04, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, - 0x72, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, - 0x72, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x66, 0x75, 0x6c, 0x6c, 0x4e, 0x61, 0x6d, - 0x65, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x66, 0x75, 0x6c, 0x6c, 0x4e, 0x61, 0x6d, - 0x65, 0x32, 0x53, 0x0a, 0x0c, 0x41, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, 0x69, 0x63, 0x61, 0x74, - 0x65, 0x12, 0x43, 0x0a, 0x04, 0x41, 0x75, 0x74, 0x68, 0x12, 0x1b, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x2e, 0x54, 0x6f, 0x6b, 0x65, 0x6e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x54, - 0x6f, 0x6b, 0x65, 0x6e, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -}) +const file_proto_authenticate_proto_rawDesc = "" + + "\n" + + "\x18proto/authenticate.proto\x12\x05proto\"_\n" + + "\x14TokenValidateRequest\x12\x14\n" + + "\x05token\x18\x01 \x01(\tR\x05token\x12\"\n" + + "\n" + + "csrf_token\x18\x02 \x01(\tH\x00R\tcsrfToken\x88\x01\x01B\r\n" + + "\v_csrf_token\"\x9f\x01\n" + + "\x15TokenValidateResponse\x12\x14\n" + + "\x05valid\x18\x01 \x01(\bR\x05valid\x12\x0e\n" + + "\x02id\x18\x02 \x01(\tR\x02id\x12\x12\n" + + "\x04role\x18\x03 \x01(\tR\x04role\x12\x14\n" + + "\x05email\x18\x04 \x01(\tR\x05email\x12\x1a\n" + + "\buserName\x18\x05 \x01(\tR\buserName\x12\x1a\n" + + "\bfullName\x18\x06 \x01(\tR\bfullName2S\n" + + "\fAuthenticate\x12C\n" + + "\x04Auth\x12\x1b.proto.TokenValidateRequest\x1a\x1c.proto.TokenValidateResponse\"\x00B\bZ\x06/protob\x06proto3" var ( file_proto_authenticate_proto_rawDescOnce sync.Once @@ -210,6 +211,7 @@ func file_proto_authenticate_proto_init() { if File_proto_authenticate_proto != nil { return } + file_proto_authenticate_proto_msgTypes[0].OneofWrappers = []any{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ diff --git a/proto/authenticate.proto b/proto/authenticate.proto index 12f6987..922ef6c 100644 --- a/proto/authenticate.proto +++ b/proto/authenticate.proto @@ -12,6 +12,7 @@ service Authenticate { // TokenValidateRequest is the request message for the Authenticate method. message TokenValidateRequest { string token = 1; + optional string csrf_token = 2; } // TokenValidateResponse is the response message for the Authenticate method. @@ -22,4 +23,5 @@ message TokenValidateResponse { string email = 4; string userName = 5; string fullName = 6; + } \ No newline at end of file diff --git a/proto/authenticate_grpc.pb.go b/proto/authenticate_grpc.pb.go index 9c1105b..66ac57c 100644 --- a/proto/authenticate_grpc.pb.go +++ b/proto/authenticate_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.5.1 -// - protoc v5.29.3 +// - protoc v6.33.0 // source: proto/authenticate.proto package proto diff --git a/util/request.go b/util/request.go index d990140..c058bfb 100644 --- a/util/request.go +++ b/util/request.go @@ -35,3 +35,28 @@ func FromJson[T any](data map[string]any) (*T, error) { return result, nil } + +func CSRFMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + //skipping CSRF authentication for routes where user is not authenticated + if r.URL.Path == "/api/login" || r.URL.Path == "/api/register" || r.URL.Path == "/api/password/reset" || r.URL.Path == "/api/password/verify" || r.URL.Path == "/api/verify" { + next.ServeHTTP(w, r) + return + } + + cookie, err := r.Cookie("csrf_token") + if err != nil { + http.Error(w, "CSRF cookie not found", http.StatusForbidden) + return + } + + csrfToken := r.Header.Get("X-CSRF-Token") + if csrfToken == "" || csrfToken != cookie.Value { + http.Error(w, "Invalid CSRF token", http.StatusForbidden) + return + } + + next.ServeHTTP(w, r) + }) +}