diff --git a/.env.example b/.env.example index 538aca3..3b4cfec 100644 --- a/.env.example +++ b/.env.example @@ -6,3 +6,5 @@ PSQL_HOST=host-name PSQL_USER=db-user PSQL_PASSWORD=db-password PSQL_NAME=db-name + +JWT_SECRET=AllYourBaseAreBelongToUs diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..b4d4b37 --- /dev/null +++ b/auth.go @@ -0,0 +1,86 @@ +package server + +import ( + "encoding/json" + "net/http" + "time" + + "github.com/benchttp/server/benchttp" + "github.com/benchttp/server/jwt" +) + +func (s *Server) handleSignin(w http.ResponseWriter, r *http.Request) { + var body struct { + Code string `json:"code"` + } + + err := json.NewDecoder(r.Body).Decode(&body) + if err != nil { + writeError(w, &ErrBadRequest) + return + } + + ghToken, err := s.OAuthClient.ExchangeForAccessToken(body.Code) + if err != nil { + writeError(w, &ErrBadRequest) + return + } + + name, email, err := s.OAuthClient.GetUser(ghToken) + if err != nil { + writeError(w, &ErrInternal) + return + } + + user := benchttp.User{} + + switch s.UserService.Exists(email) { + case true: + user, err = s.UserService.GetByEmail(email) + default: + user, err = s.UserService.Create(name, email) + } + + if err != nil { + writeError(w, &ErrUnauthorized) + return + } + + // webToken authenticates the user from the webapp. + webToken, err := createToken(user.Name, user.Email) + if err != nil { + writeError(w, &ErrInternal) + return + } + + writeJSON(w, struct { + WebToken string `json:"jwt"` + }{ + WebToken: webToken, + }, 201) +} + +func (s *Server) handleCreateAccessToken(w http.ResponseWriter, r *http.Request) { + user := userFromContext(r.Context()) + if user == nil { + writeError(w, &ErrInternal) + return + } + + accessToken, err := createToken(user.Name, user.Email) + if err != nil { + writeError(w, &ErrInternal) + return + } + + writeJSON(w, struct { + AccessToken string `json:"accessToken"` + }{ + AccessToken: accessToken, + }, 201) +} + +func createToken(name, email string) (string, error) { + claims := jwt.NewClaims(name, email, time.Now().Add(24*time.Hour)) + return jwt.Sign(claims) +} diff --git a/auth_middleware.go b/auth_middleware.go new file mode 100644 index 0000000..8e24690 --- /dev/null +++ b/auth_middleware.go @@ -0,0 +1,77 @@ +package server + +import ( + "context" + "errors" + "net/http" + "strings" + + "github.com/benchttp/server/benchttp" + "github.com/benchttp/server/jwt" +) + +type contextKey string + +const ( + userKey contextKey = "user" +) + +// mustAuth is a authentication middleware. It looks for +// a JWT inside the request headers and validates it. +// It the token is valid, mustAuth retreives the user +// associated in the claims and attaches it, if found, +// on request.Context. +func (s *Server) mustAuth(hf http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + signed, err := bearerToken(r) + if err != nil { + writeError(w, &ErrUnauthorized) + return + } + + token, err := jwt.Verify(signed) + if err != nil { + writeError(w, &ErrUnauthorized) + return + } + + claims, err := jwt.ClaimsOf(*token) + if err != nil { + writeError(w, &ErrUnauthorized) + return + } + + user, err := s.UserService.GetByEmail(claims.Email) + if err != nil { + writeError(w, &ErrUnauthorized) + return + } + + ctx := context.WithValue(r.Context(), userKey, &user) + hf(w, r.WithContext(ctx)) + } +} + +// userFromContext returns the user that was set on request.Context +// by mustAuth middleware. Returns nil if no user is set in the context. +// A nil User must be treated as an internal error, as userFromContext +// must be called where we expect user to exist. +func userFromContext(ctx context.Context) *benchttp.User { + if u := ctx.Value(userKey); u != nil { + return u.(*benchttp.User) + } + return nil +} + +// bearerScheme is the string prefixing the key or token +// in the authorization headers: "Bearer ". +const bearerScheme = "Bearer " + +func bearerToken(r *http.Request) (string, error) { + ah := r.Header.Get("Authorization") + if !strings.HasPrefix(ah, bearerScheme) { + return "", errors.New("invalid authorization headers") + } + + return strings.TrimPrefix(ah, bearerScheme), nil +} diff --git a/benchttp/repository.go b/benchttp/repository.go index c4cd22c..04b785e 100644 --- a/benchttp/repository.go +++ b/benchttp/repository.go @@ -23,3 +23,17 @@ type StatsService interface { // when provided with a StatsDescriptor id. GetByID(id string) (Stats, error) } + +type UserService interface { + // Create creates and stores a User in the data layer + // and returns its ID. + Create(name, email string) (User, error) + + // GetByEmail retrieves a User by their email from + // the data layer. + GetByEmail(email string) (User, error) + + // Exists returns true if a user with email already exists + // in the data layer. By spec, a user email is unique. + Exists(email string) bool +} diff --git a/benchttp/user.go b/benchttp/user.go new file mode 100644 index 0000000..33522b1 --- /dev/null +++ b/benchttp/user.go @@ -0,0 +1,7 @@ +package benchttp + +type User struct { + ID int64 `json:"-"` + Name string `json:"name"` + Email string `json:"email"` +} diff --git a/cmd/main.go b/cmd/main.go index 57a20e8..cbf8b25 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -11,7 +11,9 @@ import ( "github.com/joho/godotenv" "github.com/benchttp/server" + "github.com/benchttp/server/jwt" "github.com/benchttp/server/services/firestore" + "github.com/benchttp/server/services/github" "github.com/benchttp/server/services/postgresql" ) @@ -74,11 +76,20 @@ func run() error { return err } - s, err := postgresql.NewStatsService(psqlConfig) + dbConn, err := postgresql.Connect(psqlConfig) if err != nil { return err } - srv := server.New(addr, rs, s) + ss := postgresql.NewStatsService(dbConn) + + us := postgresql.NewUserService(dbConn) + + secretKey := os.Getenv("JWT_SECRET") + jwt.SetSecretKey([]byte(secretKey)) + + oc := github.NewOauth("1234", "hi mom") + + srv := server.New(addr, rs, ss, us, oc) return srv.Start() } diff --git a/error.go b/error.go index d39e062..499c070 100644 --- a/error.go +++ b/error.go @@ -19,6 +19,10 @@ var ( Code: http.StatusInternalServerError, Message: http.StatusText(http.StatusInternalServerError), } + ErrUnauthorized = httpError{ + Code: http.StatusUnauthorized, + Message: http.StatusText(http.StatusUnauthorized), + } ) type httpError struct { diff --git a/go.mod b/go.mod index 8ff801b..6b81698 100644 --- a/go.mod +++ b/go.mod @@ -4,19 +4,20 @@ go 1.17 require ( cloud.google.com/go/firestore v1.6.1 + github.com/GoogleCloudPlatform/cloudsql-proxy v1.29.0 + github.com/golang-jwt/jwt v3.2.2+incompatible github.com/gorilla/mux v1.8.0 github.com/joho/godotenv v1.4.0 + github.com/lib/pq v1.10.4 ) require ( cloud.google.com/go v0.100.2 // indirect cloud.google.com/go/compute v1.3.0 // indirect - github.com/GoogleCloudPlatform/cloudsql-proxy v1.29.0 // indirect github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/google/go-cmp v0.5.7 // indirect github.com/googleapis/gax-go/v2 v2.1.1 // indirect - github.com/lib/pq v1.10.4 // indirect go.opencensus.io v0.23.0 // indirect go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.6.0 // indirect diff --git a/go.sum b/go.sum index 9e4a05c..e01fc3d 100644 --- a/go.sum +++ b/go.sum @@ -24,7 +24,6 @@ cloud.google.com/go v0.87.0/go.mod h1:TpDYlFy7vuLzZMMZ+B6iRiELaY7z/gJPaqbMx6mlWc cloud.google.com/go v0.90.0/go.mod h1:kRX0mNRHe0e2rC6oNakvwQqzyDmg57xJ+SZU1eT2aDQ= cloud.google.com/go v0.93.3/go.mod h1:8utlLll2EF5XMAV15woO4lSbWQlk8rer9aLOfLh7+YI= cloud.google.com/go v0.94.1/go.mod h1:qAlAugsXlC+JWO+Bke5vCtc9ONxjQT3drlTTnAplMW4= -cloud.google.com/go v0.97.0 h1:3DXvAyifywvq64LfkKaMOmkWPS1CikIQdMe2lY9vxU8= cloud.google.com/go v0.97.0/go.mod h1:GF7l59pYBVlXQIBLx3a761cZ41F9bBH3JUlihCt2Udc= cloud.google.com/go v0.99.0/go.mod h1:w0Xx2nLzqWJPuozYQX+hFfCSI8WioryfRDzkoI/Y2ZA= cloud.google.com/go v0.100.2 h1:t9Iw5QH5v4XtlEQaCtUY7x6sCABps8sW0acw7e2WQ6Y= @@ -63,6 +62,7 @@ github.com/GoogleCloudPlatform/cloudsql-proxy v1.29.0/go.mod h1:spvB9eLJH9dutlbP github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= +github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= @@ -85,6 +85,7 @@ github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7 github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/denisenkom/go-mssqldb v0.12.0/go.mod h1:iiK0YP1ZeepvmBQk/QpLEhhTNJgfzrpArPY/aFvc9yU= github.com/dnaeon/go-vcr v1.2.0/go.mod h1:R4UdLID7HZT3taECzJs4YgbbH6PIGXB6W/sc5OLb6RQ= @@ -107,6 +108,8 @@ github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LB github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/sqlexp v0.0.0-20170517235910-f1bb20e5a188/go.mod h1:vXjM/+wXQnTPR4KqTKDgJukSZ6amVRtWMPEjE6sQoK8= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -155,7 +158,6 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.7 h1:81/ik6ipDQS2aGcBfIN5dHDB36BwrStyeAQquSYCV4o= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= @@ -256,7 +258,9 @@ github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= @@ -278,6 +282,7 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -300,6 +305,7 @@ go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= @@ -393,7 +399,6 @@ golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= -golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420 h1:a8jGStKg0XqKDlKqjLrXn0ioF5MH36pT7Z0BRTqLhbk= golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210610132358-84b48f89b13b/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211020060615-d418f374d309 h1:A0lJIi+hcTR6aajJH4YqKWwohY4aW9RO7oRMcdv+HKI= @@ -413,7 +418,6 @@ golang.org/x/oauth2 v0.0.0-20210514164344-f6687ab2804c/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210628180205-a41e5a781914/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210805134026-6f1e6394065a/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20211005180243-6b3c2da341f1 h1:B333XXssMuKQeBwiNODx4TupZy7bf4sxFZnN2ZOcvUE= golang.org/x/oauth2 v0.0.0-20211005180243-6b3c2da341f1/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8 h1:RerP+noqYHUQ8CMRcPlC2nvTa4dcBIjegkuWdcUDuqg= golang.org/x/oauth2 v0.0.0-20211104180415-d3ed0bb246c8/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= @@ -481,7 +485,6 @@ golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210823070655-63515b42dcdf/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210908233432-aa78b53d3365/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac h1:oN6lz7iLW/YC7un8pq+9bOLyXrprv2+DKfkJY+2LJJw= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211124211545-fe61309f8881/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211210111614-af8b64212486/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -598,7 +601,6 @@ google.golang.org/api v0.54.0/go.mod h1:7C4bFFOvVDGXjfDTAsgGwDgAxRDeQ4X8NvUedIt6 google.golang.org/api v0.55.0/go.mod h1:38yMfeP1kfjsl8isn0tliTjIb1rJXcQi4UXlbqivdVE= google.golang.org/api v0.56.0/go.mod h1:38yMfeP1kfjsl8isn0tliTjIb1rJXcQi4UXlbqivdVE= google.golang.org/api v0.57.0/go.mod h1:dVPlbZyBo2/OjBpmvNdpn2GRm6rPy75jyU7bmhdrMgI= -google.golang.org/api v0.59.0 h1:fPfFO7gttlXYo2ALuD3HxJzh8vaF++4youI0BkFL6GE= google.golang.org/api v0.59.0/go.mod h1:sT2boj7M9YJxZzgeZqXogmhfmRWDtPzT31xkieUbuZU= google.golang.org/api v0.61.0/go.mod h1:xQRti5UdCmoCEqFxcz93fTl338AVqDgyaDRuOZ3hg9I= google.golang.org/api v0.63.0/go.mod h1:gs4ij2ffTRXwuzzgJl/56BdwJaA194ijkfn++9tDuPo= @@ -671,7 +673,6 @@ google.golang.org/genproto v0.0.0-20210903162649-d08c68adba83/go.mod h1:eFjDcFEc google.golang.org/genproto v0.0.0-20210909211513-a8c4777a87af/go.mod h1:eFjDcFEctNawg4eG61bRv87N7iHBWyVhJu7u1kqDUXY= google.golang.org/genproto v0.0.0-20210924002016-3dee208752a0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20211008145708-270636b82663/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= -google.golang.org/genproto v0.0.0-20211028162531-8db9c33dc351 h1:uf3hR4mj3fn7tjJL1f0kkRqFE7GDPoBiyvLxvu1Gt/g= google.golang.org/genproto v0.0.0-20211028162531-8db9c33dc351/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20211118181313-81c1377c94b1/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= google.golang.org/genproto v0.0.0-20211206160659-862468c7d6e0/go.mod h1:5CzLGKJ67TSI2B9POpiiyGha0AjJvZIUgRMt1dSmuhc= @@ -707,7 +708,6 @@ google.golang.org/grpc v1.37.1/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQ google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= google.golang.org/grpc v1.39.0/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE= google.golang.org/grpc v1.39.1/go.mod h1:PImNr+rS9TWYb2O4/emRugxiyHZ5JyHW5F+RPnDzfrE= -google.golang.org/grpc v1.40.0 h1:AGJ0Ih4mHjSeibYkFGh1dD9KJ/eOtZ93I6hoHhukQ5Q= google.golang.org/grpc v1.40.0/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= google.golang.org/grpc v1.40.1/go.mod h1:ogyxbiOoUXAkP+4+xa6PZSE9DZgIHtSpzjDTB9KAK34= google.golang.org/grpc v1.44.0 h1:weqSxi/TMs1SqFRMHCtBgXRs8k3X39QIDEZ0pRcttUg= @@ -734,8 +734,10 @@ gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:a gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/jwt/jwt.go b/jwt/jwt.go new file mode 100644 index 0000000..99bee0a --- /dev/null +++ b/jwt/jwt.go @@ -0,0 +1,93 @@ +package jwt + +import ( + "encoding/json" + "errors" + "time" + + "github.com/golang-jwt/jwt" +) + +var ( + ErrInvalidClaims = errors.New("invalid token claims") + ErrInvalidToken = errors.New("invalid token") +) + +// Claims is the custom claims used to signed JSON web token. +type Claims struct { + Name string `json:"name,omitempty"` + Email string `json:"email,omitempty"` + jwt.StandardClaims +} + +// NewClaims returns Claims to provide to Sign. +func NewClaims(name, email string, exp time.Time) Claims { + return Claims{ + Name: name, + Email: email, + StandardClaims: jwt.StandardClaims{ + ExpiresAt: exp.Unix(), + }, + } +} + +// Sign returns a new token as a signed a string. +func Sign(claims Claims) (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + + signed, err := token.SignedString(key()) + if err != nil { + return "", ErrInvalidClaims + } + + return signed, nil +} + +// Verify parses and validates signed and returns a valid token. +// The token can then be used to access the Claims it holds. +func Verify(signed string) (*jwt.Token, error) { + // Parse returns an error if the token is invalid + // or if the signature does not match. + token, err := jwt.Parse(signed, func(token *jwt.Token) (interface{}, error) { + return key(), nil + }) + if err != nil { + return nil, ErrInvalidToken + } + + if !token.Valid { + return nil, ErrInvalidClaims + } + + return token, nil +} + +// ClaimsOf returns the underlying claims from token. +func ClaimsOf(token jwt.Token) (Claims, error) { + m, ok := token.Claims.(jwt.MapClaims) + if !ok { + return Claims{}, ErrInvalidClaims + } + + claims := Claims{} + + err := decodeMapClaims(m, &claims) + if err != nil { + return Claims{}, err + } + + return claims, nil +} + +func decodeMapClaims(m jwt.MapClaims, dst interface{}) error { + e, err := json.Marshal(m) + if err != nil { + return err + } + + if err := json.Unmarshal(e, dst); err != nil { + return err + } + + return nil +} diff --git a/jwt/secret.go b/jwt/secret.go new file mode 100644 index 0000000..0b419c2 --- /dev/null +++ b/jwt/secret.go @@ -0,0 +1,18 @@ +package jwt + +var secretKey []byte + +// SetSecretKey sets the secret key used for hashing JWT. +// It must be used before any usage of this package. +func SetSecretKey(key []byte) { + secretKey = key +} + +// key is a safe way to access secretKey. It panics at runtime +// if secretKey is not set. +func key() []byte { + if len(secretKey) == 0 { + panic("secretKey must be set before the use of package jwt") + } + return secretKey +} diff --git a/routes.go b/routes.go index 4ffa420..75c26da 100644 --- a/routes.go +++ b/routes.go @@ -38,17 +38,24 @@ func (s *Server) registerRoutes() { v1 := s.router.PathPrefix("/v1").Subrouter() - v1.HandleFunc("/reports", s.createReport).Methods("POST") + // Auth + v1.HandleFunc("/signin", s.handleSignin).Methods("POST") + v1.HandleFunc("/token", s.mustAuth(s.handleCreateAccessToken)).Methods("GET") - v1.HandleFunc("/reports/"+idPathVar, s.retrieveReport).Methods("GET") + // Users + v1.HandleFunc("/user", s.mustAuth(s.selfUser)).Methods("GET") - v1.HandleFunc("/stats", s.retrieveAllStats).Methods("GET") + // Reports + v1.HandleFunc("/reports", s.mustAuth(s.createReport)).Methods("POST") + v1.HandleFunc("/reports/"+idPathVar, s.mustAuth(s.retrieveReport)).Methods("GET") - v1.HandleFunc("/stats/"+idPathVar, s.retrieveStatsByID).Methods("GET") + // Stats + v1.HandleFunc("/stats", s.mustAuth(s.retrieveAllStats)).Methods("GET") + v1.HandleFunc("/stats/"+idPathVar, s.mustAuth(s.retrieveStatsByID)).Methods("GET") } -func handleRoot(rw http.ResponseWriter, _ *http.Request) { - rw.Header().Set("Content-Type", "text/html; charset=utf-8") - rw.WriteHeader(200) - rw.Write([]byte("⚡")) //nolint +func handleRoot(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(200) + w.Write([]byte("⚡")) //nolint } diff --git a/server.go b/server.go index 0ad3b34..6323fb7 100644 --- a/server.go +++ b/server.go @@ -9,6 +9,7 @@ import ( "github.com/benchttp/server/benchttp" "github.com/benchttp/server/httplog" + "github.com/benchttp/server/services/github" ) const maxBytesRead = 1 << 20 // 1 MiB @@ -19,14 +20,23 @@ type Server struct { ReportService benchttp.ReportService StatsService benchttp.StatsService + UserService benchttp.UserService + + OAuthClient github.OAuthClient } // New returns a Server with specified configuration parameters. -func New(addr string, rs benchttp.ReportService, s benchttp.StatsService) *Server { +func New(addr string, + rs benchttp.ReportService, ss benchttp.StatsService, us benchttp.UserService, + oauthClient github.OAuthClient, +) *Server { + // return &Server{ Server: &http.Server{Addr: addr}, ReportService: rs, - StatsService: s, + StatsService: ss, + UserService: us, + OAuthClient: oauthClient, } } diff --git a/services/github/error.go b/services/github/error.go new file mode 100644 index 0000000..1d56901 --- /dev/null +++ b/services/github/error.go @@ -0,0 +1,9 @@ +package github + +import "errors" + +var ( + errRequest = errors.New("error sending request") + errStatusCode = errors.New("unexpected status (expected 200)") + errParse = errors.New("error parsing response") +) diff --git a/services/github/oauth.go b/services/github/oauth.go new file mode 100644 index 0000000..b1c26cc --- /dev/null +++ b/services/github/oauth.go @@ -0,0 +1,94 @@ +package github + +import ( + "encoding/json" + "fmt" + "net/http" +) + +const ( + exchangeURL = "https://github.com/login/oauth/access_token" + userURL = "https://api.github.com/user" +) + +type OAuthClient struct { + id string + secret string +} + +func NewOauth(id, secret string) OAuthClient { + return OAuthClient{ + id: id, + secret: secret, + } +} + +func (c OAuthClient) ExchangeForAccessToken(code string) (string, error) { + req, err := http.NewRequest("POST", exchangeURL, nil) + if err != nil { + return "", err + } + + // Parameters are expected to be inside the query. + q := req.URL.Query() + q.Set("client_id", c.id) + q.Set("client_secret", c.secret) + q.Set("code", code) + req.URL.RawQuery = q.Encode() + + // GitHub defaults to encode the response in plain text. + // Explicitly asks for JSON encoding. + req.Header.Set("Accept", "application/json") + + res, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("%w: %s", errRequest, err) + } + defer res.Body.Close() + + if res.StatusCode != 200 { + return "", fmt.Errorf("%w: %s: %s", errStatusCode, res.Status, err) + } + + var t struct { + AccessToken string `json:"access_token"` + } + + if err := json.NewDecoder(res.Body).Decode(&t); err != nil { + return "", fmt.Errorf("%s: %s", errParse, err) + } + + return t.AccessToken, nil +} + +func (c OAuthClient) GetUser(token string) (name, email string, err error) { + req, err := http.NewRequest("GET", userURL, nil) + if err != nil { + return "", "", err + } + + req.Header.Set("Authorization", fmt.Sprintf("token %s", token)) + + res, err := http.DefaultClient.Do(req) + if err != nil { + return "", "", fmt.Errorf("%w: %s", errRequest, err) + } + defer res.Body.Close() + + if res.StatusCode != 200 { + return "", "", fmt.Errorf("%w: %s: %s", errStatusCode, res.Status, err) + } + + var u struct { + Name string `json:"name"` + Email string `json:"email"` + // All the available properties are defined in GitHub docs: + // https://docs.github.com/en/rest/reference/users#get-the-authenticated-user + } + + if err := json.NewDecoder(res.Body).Decode(&u); err != nil { + return "", "", fmt.Errorf("%s: %s", errParse, err) + } + + return u.Name, u.Email, nil +} diff --git a/services/postgresql/connection.go b/services/postgresql/connection.go index b62502f..ae7a3b0 100644 --- a/services/postgresql/connection.go +++ b/services/postgresql/connection.go @@ -7,11 +7,24 @@ import ( _ "github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/dialers/postgres" // blank import ) -type StatsService struct { +type Config struct { + Host string + User string + Password string + DBName string + IdleConn int + MaxConn int +} + +// Connection holds the connection to a PostgreSQL database. +// It must be passed to a service constructor. +type Connection struct { db *sql.DB } -func NewStatsService(config Config) (StatsService, error) { +// Connect opens a connection with a PostgreSQL database and +// returns a Connection to utilize it. +func Connect(config Config) (Connection, error) { dbURI := fmt.Sprintf("host=%s user=%s password=%s dbname=%s sslmode=disable", config.Host, config.User, @@ -20,25 +33,16 @@ func NewStatsService(config Config) (StatsService, error) { db, err := sql.Open("cloudsqlpostgres", dbURI) if err != nil { - return StatsService{}, ErrDatabaseConnection + return Connection{}, ErrDatabaseConnection } err = db.Ping() if err != nil { - return StatsService{}, ErrDatabasePing + return Connection{}, ErrDatabasePing } db.SetMaxIdleConns(config.IdleConn) db.SetMaxOpenConns(config.MaxConn) - return StatsService{db}, nil -} - -type Config struct { - Host string - User string - Password string - DBName string - IdleConn int - MaxConn int + return Connection{db}, nil } diff --git a/services/postgresql/error.go b/services/postgresql/error.go index f40e1b1..4e85ec0 100644 --- a/services/postgresql/error.go +++ b/services/postgresql/error.go @@ -20,4 +20,7 @@ var ( // ErrScanningRows is returned when server fails to scan the rows // returned by a query. ErrScanningRows = errors.New("error trying to scan result rows") + // ErrGettingIDInsertion is returned when server fails to get the ID + // of an object that has just been inserted into the database. + ErrGettingIDInsertion = errors.New("error retrieving ID of the row inserted") ) diff --git a/services/postgresql/stats.go b/services/postgresql/stats.go index 1536ef3..bab46fb 100644 --- a/services/postgresql/stats.go +++ b/services/postgresql/stats.go @@ -1,11 +1,24 @@ package postgresql import ( + "database/sql" + "github.com/lib/pq" "github.com/benchttp/server/benchttp" ) +// Ensure service implements interface. +var _ benchttp.StatsService = (*StatsService)(nil) + +type StatsService struct { + db *sql.DB +} + +func NewStatsService(conn Connection) StatsService { + return StatsService{conn.db} +} + func (s StatsService) ListAvailable(userID string) ([]benchttp.StatsDescriptor, error) { list := []benchttp.StatsDescriptor{} diff --git a/services/postgresql/user.go b/services/postgresql/user.go new file mode 100644 index 0000000..4ad7166 --- /dev/null +++ b/services/postgresql/user.go @@ -0,0 +1,64 @@ +package postgresql + +import ( + "database/sql" + + "github.com/benchttp/server/benchttp" +) + +// Ensure service implements interface. +var _ benchttp.UserService = (*UserService)(nil) + +type UserService struct { + db *sql.DB +} + +func NewUserService(conn Connection) UserService { + return UserService{conn.db} +} + +func (s UserService) Create(name, email string) (benchttp.User, error) { + user := benchttp.User{Name: name, Email: email} + stmt := ` +INSERT INTO public.users(name, email) +VALUES ($1, $2)`[1:] + result, err := s.db.Exec(stmt, name, email) + if err != nil { + return user, ErrExecutingPreparedStmt + } + id, err := result.LastInsertId() + if err != nil { + return user, ErrGettingIDInsertion + } + user.ID = id + return user, nil +} + +func (s UserService) GetByEmail(email string) (benchttp.User, error) { + user := benchttp.User{Email: email} + stmt := ` +SELECT id, name +FROM public.users +WHERE email = $1`[1:] + row := s.db.QueryRow(stmt, email) + err := row.Scan( + &user.ID, + &user.Name) + if err != nil { + return user, ErrScanningRows + } + return user, nil +} + +func (s UserService) Exists(email string) bool { + var exists bool + stmt := ` +SELECT EXISTS( +SELECT email +FROM public.users WHERE email=$1)`[1:] + err := s.db.QueryRow(stmt, email).Scan(&exists) + if err != nil && err != sql.ErrNoRows { + return false + } + return exists +} diff --git a/user.go b/user.go new file mode 100644 index 0000000..231db43 --- /dev/null +++ b/user.go @@ -0,0 +1,15 @@ +package server + +import ( + "net/http" +) + +func (s *Server) selfUser(w http.ResponseWriter, r *http.Request) { + user := userFromContext(r.Context()) + if user == nil { + writeError(w, &ErrInternal) + return + } + + writeJSON(w, user, 200) +}