diff --git a/.env b/.env index 7e43ac9..e45c7b4 100644 --- a/.env +++ b/.env @@ -1,3 +1,3 @@ -DATABASE_URL='postgresql://localhost?sslmode=disable' +DATABASE_URL='postgresql://localhost/cerealnotes?sslmode=disable' PORT=8080 TOKEN_SIGNING_KEY='AllYourBase' diff --git a/README.md b/README.md index fe1ec89..7bff1d2 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,6 @@ # Installation ## Locally -* postgres server installed and running - * `brew install postgres` - * `pg_ctl -D /usr/local/var/postgres start` +* postgres server installed and running: please refer to `migrations/README.md` for more info * heroku cli installed * `brew install heroku` * golang installed @@ -20,7 +18,4 @@ Assuming your local environment is setup correctly with Golang standards, you ca 1. `cd to this repo` 2. `go install && heroku local` -3. Visit `localhost:8080/login-or-signup` - -# Run DB migrations -More db information in `migrations/README.md` +3. Visit `localhost:8080/` diff --git a/databaseutil/databaseutil.go b/databaseutil/databaseutil.go deleted file mode 100644 index 02e24cf..0000000 --- a/databaseutil/databaseutil.go +++ /dev/null @@ -1,191 +0,0 @@ -/* -Package databaseutil abstracts away details about sql and postgres. - -These functions only accept and return primitive types. -*/ -package databaseutil - -import ( - "database/sql" - "errors" - "time" - - "github.com/lib/pq" -) - -var db *sql.DB - -// UniqueConstraintError is returned when a uniqueness constraint is violated during an insert. -var UniqueConstraintError = errors.New("postgres: unique constraint violation") - -// QueryResultContainedMultipleRowsError is returned when a query unexpectedly returns more than one row. -var QueryResultContainedMultipleRowsError = errors.New("query result unexpectedly contained multiple rows") - -// QueryResultContainedNoRowsError is returned when a query unexpectedly returns no rows. -var QueryResultContainedNoRowsError = errors.New("query result unexpectedly contained no rows") - -// ConnectToDatabase also pings the database to ensure a working connection. -func ConnectToDatabase(databaseUrl string) error { - { - tempDb, err := sql.Open("postgres", databaseUrl) - if err != nil { - return err - } - - db = tempDb - } - - if err := db.Ping(); err != nil { - return err - } - - return nil -} - -func InsertIntoUserTable( - displayName string, - emailAddress string, - password []byte, - creationTime time.Time, -) error { - sqlQuery := ` - INSERT INTO app_user (display_name, email_address, password, creation_time) - VALUES ($1, $2, $3, $4)` - - rows, err := db.Query(sqlQuery, displayName, emailAddress, password, creationTime) - if err != nil { - return convertPostgresError(err) - } - defer rows.Close() - - if err := rows.Err(); err != nil { - return convertPostgresError(err) - } - - return nil -} - -func GetPasswordForUserWithEmailAddress(emailAddress string) ([]byte, error) { - sqlQuery := ` - SELECT password FROM app_user - WHERE email_address = $1` - - rows, err := db.Query(sqlQuery, emailAddress) - if err != nil { - return nil, convertPostgresError(err) - } - defer rows.Close() - - var password []byte - for rows.Next() { - if password != nil { - return nil, QueryResultContainedMultipleRowsError - } - - if err := rows.Scan(&password); err != nil { - return nil, err - } - } - - if password == nil { - return nil, QueryResultContainedNoRowsError - } - - return password, nil -} - -func InsertNewNote(authorId int64, content string, creationTime time.Time) (int64, error) { - sqlQuery := ` - INSERT INTO note (author_id, content, creation_time) - VALUES ($1, $2, $3) - RETURNING id` - - rows, err := db.Query(sqlQuery, authorId, content, creationTime) - if err != nil { - return 0, convertPostgresError(err) - } - defer rows.Close() - - var noteId int64 = 0 - for rows.Next() { - - if noteId != 0 { - return 0, QueryResultContainedMultipleRowsError - } - - if err := rows.Scan(¬eId); err != nil { - return 0, convertPostgresError(err) - } - } - - if noteId == 0 { - return 0, QueryResultContainedNoRowsError - } - - if err := rows.Err(); err != nil { - return 0, convertPostgresError(err) - } - - return noteId, nil -} - -func InsertNoteCategoryRelationship(noteId int64, category string) error { - sqlQuery := ` - INSERT INTO note_to_category_relationship (note_id, category) - VALUES ($1, $2)` - - rows, err := db.Query(sqlQuery, noteId, category) - if err != nil { - return convertPostgresError(err) - } - defer rows.Close() - - if err := rows.Err(); err != nil { - return convertPostgresError(err) - } - - return nil -} - -func GetIdForUserWithEmailAddress(emailAddress string) (int64, error) { - sqlQuery := ` - SELECT id FROM app_user - WHERE email_address = $1` - - rows, err := db.Query(sqlQuery, emailAddress) - if err != nil { - return 0, convertPostgresError(err) - } - defer rows.Close() - - var userId int64 - for rows.Next() { - if userId != 0 { - return 0, QueryResultContainedMultipleRowsError - } - - if err := rows.Scan(&userId); err != nil { - return 0, err - } - } - - if userId == 0 { - return 0, QueryResultContainedNoRowsError - } - - return userId, nil -} - -// PRIVATE - -func convertPostgresError(err error) error { - const uniqueConstraintErrorCode = "23505" - - if postgresErr, ok := err.(*pq.Error); ok { - if postgresErr.Code == uniqueConstraintErrorCode { - return UniqueConstraintError - } - } - - return err -} diff --git a/handlers/handlers.go b/handlers/handlers.go index d64844a..a8a446c 100644 --- a/handlers/handlers.go +++ b/handlers/handlers.go @@ -11,8 +11,6 @@ import ( "github.com/atmiguel/cerealnotes/models" "github.com/atmiguel/cerealnotes/paths" - "github.com/atmiguel/cerealnotes/services/noteservice" - "github.com/atmiguel/cerealnotes/services/userservice" "github.com/dgrijalva/jwt-go" ) @@ -28,10 +26,15 @@ type JwtTokenClaim struct { jwt.StandardClaims } -var tokenSigningKey []byte +type Environment struct { + Db models.Datastore + TokenSigningKey []byte +} -func SetTokenSigningKey(key []byte) { - tokenSigningKey = key +func WrapUnauthenticatedEndpoint(env *Environment, handler UnauthenticatedEndpointHandlerType) http.HandlerFunc { + return func(responseWriter http.ResponseWriter, request *http.Request) { + handler(env, responseWriter, request) + } } // UNAUTHENTICATED HANDLERS @@ -39,12 +42,13 @@ func SetTokenSigningKey(key []byte) { // HandleLoginOrSignupPageRequest responds to unauthenticated GET requests with the login or signup page. // For authenticated requests, it redirects to the home page. func HandleLoginOrSignupPageRequest( + env *Environment, responseWriter http.ResponseWriter, request *http.Request, ) { switch request.Method { case http.MethodGet: - if _, err := getUserIdFromJwtToken(request); err == nil { + if _, err := getUserIdFromJwtToken(env, request); err == nil { http.Redirect( responseWriter, request, @@ -67,6 +71,7 @@ func HandleLoginOrSignupPageRequest( } func HandleUserApiRequest( + env *Environment, responseWriter http.ResponseWriter, request *http.Request, ) { @@ -86,12 +91,12 @@ func HandleUserApiRequest( } var statusCode int - if err := userservice.StoreNewUser( + if err := env.Db.StoreNewUser( signupForm.DisplayName, models.NewEmailAddress(signupForm.EmailAddress), signupForm.Password, ); err != nil { - if err == userservice.EmailAddressAlreadyInUseError { + if err == models.EmailAddressAlreadyInUseError { statusCode = http.StatusConflict } else { http.Error(responseWriter, err.Error(), http.StatusInternalServerError) @@ -105,7 +110,7 @@ func HandleUserApiRequest( case http.MethodGet: - if _, err := getUserIdFromJwtToken(request); err != nil { + if _, err := getUserIdFromJwtToken(env, request); err != nil { http.Error(responseWriter, err.Error(), http.StatusUnauthorized) return } @@ -136,6 +141,7 @@ func HandleUserApiRequest( // HandleSessionApiRequest responds to POST requests by authenticating and responding with a JWT. // It responds to DELETE requests by expiring the client's cookie. func HandleSessionApiRequest( + env *Environment, responseWriter http.ResponseWriter, request *http.Request, ) { @@ -153,12 +159,12 @@ func HandleSessionApiRequest( return } - if err := userservice.AuthenticateUserCredentials( + if err := env.Db.AuthenticateUserCredentials( models.NewEmailAddress(loginForm.EmailAddress), loginForm.Password, ); err != nil { statusCode := http.StatusInternalServerError - if err == userservice.CredentialsNotAuthorizedError { + if err == models.CredentialsNotAuthorizedError { statusCode = http.StatusUnauthorized } http.Error(responseWriter, err.Error(), statusCode) @@ -167,13 +173,13 @@ func HandleSessionApiRequest( // Set our cookie to have a valid JWT Token as the value { - userId, err := userservice.GetIdForUserWithEmailAddress(models.NewEmailAddress(loginForm.EmailAddress)) + userId, err := env.Db.GetIdForUserWithEmailAddress(models.NewEmailAddress(loginForm.EmailAddress)) if err != nil { http.Error(responseWriter, err.Error(), http.StatusInternalServerError) return } - token, err := createTokenAsString(userId, credentialTimeoutDuration) + token, err := CreateTokenAsString(env, userId, credentialTimeoutDuration) if err != nil { http.Error(responseWriter, err.Error(), http.StatusInternalServerError) return @@ -217,6 +223,7 @@ func HandleSessionApiRequest( } func HandleNoteApiRequest( + env *Environment, responseWriter http.ResponseWriter, request *http.Request, userId models.UserId, @@ -224,7 +231,7 @@ func HandleNoteApiRequest( switch request.Method { case http.MethodGet: - var notesById noteservice.NotesById = make(map[models.NoteId]*models.Note, 2) + var notesById models.NotesById = make(map[models.NoteId]*models.Note, 2) notesById[models.NoteId(1)] = &models.Note{ AuthorId: 1, @@ -272,7 +279,7 @@ func HandleNoteApiRequest( CreationTime: time.Now().UTC(), } - noteId, err := noteservice.StoreNewNote(note) + noteId, err := env.Db.StoreNewNote(note) if err != nil { http.Error(responseWriter, err.Error(), http.StatusInternalServerError) return @@ -299,6 +306,7 @@ func HandleNoteApiRequest( } func HandleNoteCateogryApiRequest( + env *Environment, responseWriter http.ResponseWriter, request *http.Request, userId models.UserId, @@ -309,25 +317,25 @@ func HandleNoteCateogryApiRequest( id, err := strconv.ParseInt(request.URL.Query().Get("id"), 10, 64) noteId := models.NoteId(id) - type CategoryForm struct { - Category string `json:"category"` + type NoteCategoryForm struct { + NoteCategory string `json:"category"` } - categoryForm := new(CategoryForm) + categoryForm := new(NoteCategoryForm) if err := json.NewDecoder(request.Body).Decode(categoryForm); err != nil { http.Error(responseWriter, err.Error(), http.StatusBadRequest) return } - category, err := models.DeserializeCategory(categoryForm.Category) + category, err := models.DeserializeNoteCategory(categoryForm.NoteCategory) if err != nil { http.Error(responseWriter, err.Error(), http.StatusBadRequest) return } - if err := noteservice.StoreNewNoteCategoryRelationship(models.NoteId(noteId), category); err != nil { + if err := env.Db.StoreNewNoteCategoryRelationship(models.NoteId(noteId), category); err != nil { http.Error(responseWriter, err.Error(), http.StatusInternalServerError) return } @@ -341,16 +349,25 @@ func HandleNoteCateogryApiRequest( } type AuthenticatedRequestHandlerType func( + *Environment, + http.ResponseWriter, + *http.Request, + models.UserId, +) + +type UnauthenticatedEndpointHandlerType func( + *Environment, http.ResponseWriter, *http.Request, - models.UserId) +) func AuthenticateOrRedirect( + env *Environment, authenticatedHandlerFunc AuthenticatedRequestHandlerType, redirectPath string, ) http.HandlerFunc { return func(responseWriter http.ResponseWriter, request *http.Request) { - if userId, err := getUserIdFromJwtToken(request); err != nil { + if userId, err := getUserIdFromJwtToken(env, request); err != nil { switch request.Method { // If not logged in, redirect to login page case http.MethodGet: @@ -364,21 +381,22 @@ func AuthenticateOrRedirect( respondWithMethodNotAllowed(responseWriter, http.MethodGet) } } else { - authenticatedHandlerFunc(responseWriter, request, userId) + authenticatedHandlerFunc(env, responseWriter, request, userId) } } } func AuthenticateOrReturnUnauthorized( + env *Environment, authenticatedHandlerFunc AuthenticatedRequestHandlerType, ) http.HandlerFunc { return func(responseWriter http.ResponseWriter, request *http.Request) { - if userId, err := getUserIdFromJwtToken(request); err != nil { + if userId, err := getUserIdFromJwtToken(env, request); err != nil { responseWriter.Header().Set("WWW-Authenticate", `Bearer realm="`+request.URL.Path+`"`) http.Error(responseWriter, err.Error(), http.StatusUnauthorized) } else { - authenticatedHandlerFunc(responseWriter, request, userId) + authenticatedHandlerFunc(env, responseWriter, request, userId) } } } @@ -404,6 +422,7 @@ func RedirectToPathHandler( // AUTHENTICATED HANDLERS func HandleHomePageRequest( + env *Environment, responseWriter http.ResponseWriter, request *http.Request, userId models.UserId, @@ -423,6 +442,7 @@ func HandleHomePageRequest( } func HandleNotesPageRequest( + env *Environment, responseWriter http.ResponseWriter, request *http.Request, userId models.UserId, diff --git a/handlers/tokenutil.go b/handlers/tokenutil.go index 7d379c5..b4f2f6b 100644 --- a/handlers/tokenutil.go +++ b/handlers/tokenutil.go @@ -2,8 +2,6 @@ package handlers import ( "errors" - "fmt" - "log" "net/http" "strings" "time" @@ -14,16 +12,17 @@ import ( var InvalidJWTokenError = errors.New("Token was invalid or unreadable") -func parseTokenFromString(tokenAsString string) (*jwt.Token, error) { +func ParseTokenFromString(env *Environment, tokenAsString string) (*jwt.Token, error) { return jwt.ParseWithClaims( strings.TrimSpace(tokenAsString), &JwtTokenClaim{}, func(*jwt.Token) (interface{}, error) { - return tokenSigningKey, nil + return env.TokenSigningKey, nil }) } -func createTokenAsString( +func CreateTokenAsString( + env *Environment, userId models.UserId, durationTilExpiration time.Duration, ) (string, error) { @@ -36,16 +35,16 @@ func createTokenAsString( } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - return token.SignedString(tokenSigningKey) + return token.SignedString(env.TokenSigningKey) } -func getUserIdFromJwtToken(request *http.Request) (models.UserId, error) { +func getUserIdFromJwtToken(env *Environment, request *http.Request) (models.UserId, error) { cookie, err := request.Cookie(cerealNotesCookieName) if err != nil { return 0, err } - token, err := parseTokenFromString(cookie.Value) + token, err := ParseTokenFromString(env, cookie.Value) if err != nil { return 0, err } @@ -56,27 +55,3 @@ func getUserIdFromJwtToken(request *http.Request) (models.UserId, error) { return 0, InvalidJWTokenError } - -func tokenTest1() { - var num models.UserId = 32 - bob, err := createTokenAsString(num, 1) - if err != nil { - fmt.Println("create error") - log.Fatal(err) - } - - token, err := parseTokenFromString(bob) - if err != nil { - fmt.Println("parse error") - log.Fatal(err) - } - fmt.Println(bob) - if claims, ok := token.Claims.(*JwtTokenClaim); ok && token.Valid { - if claims.UserId != 32 { - log.Fatal("error in token") - } - fmt.Printf("%v %v", claims.UserId, claims.StandardClaims.ExpiresAt) - } else { - fmt.Println("Token claims could not be read") - } -} diff --git a/handlers/tokenutil_test.go b/handlers/tokenutil_test.go new file mode 100644 index 0000000..e92ad84 --- /dev/null +++ b/handlers/tokenutil_test.go @@ -0,0 +1,34 @@ +package handlers_test + +import ( + "fmt" + "testing" + + "github.com/atmiguel/cerealnotes/handlers" + "github.com/atmiguel/cerealnotes/models" + "github.com/atmiguel/cerealnotes/test_util" +) + +func TestToken(t *testing.T) { + env := &handlers.Environment{nil, []byte("TheWorld")} + + var num models.UserId = 32 + bob, err := handlers.CreateTokenAsString(env, num, 1) + if err != nil { + panic(err) + } + + token, err := handlers.ParseTokenFromString(env, bob) + if err != nil { + panic(err) + } + + if claims, ok := token.Claims.(*handlers.JwtTokenClaim); ok && token.Valid { + test_util.Equals(t, int64(32), int64(claims.UserId)) + + fmt.Printf("%v %v", claims.UserId, claims.StandardClaims.ExpiresAt) + } else { + fmt.Println("Token claims could not be read") + t.FailNow() + } +} diff --git a/integration_test.go b/integration_test.go new file mode 100644 index 0000000..434fb7a --- /dev/null +++ b/integration_test.go @@ -0,0 +1,201 @@ +package main_test + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "strconv" + "testing" + + "github.com/atmiguel/cerealnotes/handlers" + "github.com/atmiguel/cerealnotes/models" + "github.com/atmiguel/cerealnotes/paths" + "github.com/atmiguel/cerealnotes/routers" + "github.com/atmiguel/cerealnotes/test_util" +) + +func TestLoginOrSignUpPage(t *testing.T) { + mockDb := &DiyMockDataStore{} + env := &handlers.Environment{mockDb, []byte("")} + + server := httptest.NewServer(routers.DefineRoutes(env)) + defer server.Close() + + resp, err := http.Get(server.URL) + test_util.Ok(t, err) + test_util.Equals(t, http.StatusOK, resp.StatusCode) +} + +func TestAuthenticatedFlow(t *testing.T) { + mockDb := &DiyMockDataStore{} + env := &handlers.Environment{mockDb, []byte("")} + + server := httptest.NewServer(routers.DefineRoutes(env)) + defer server.Close() + + // Create testing client + client := &http.Client{} + { + jar, err := cookiejar.New(&cookiejar.Options{}) + + if err != nil { + panic(err) + } + + client.Jar = jar + } + + // Test login + userIdAsInt := int64(1) + { + theEmail := "justsomeemail@gmail.com" + thePassword := "worldsBestPassword" + + mockDb.Func_AuthenticateUserCredentials = func(email *models.EmailAddress, password string) error { + if email.String() == theEmail && password == thePassword { + return nil + } + + return models.CredentialsNotAuthorizedError + } + + mockDb.Func_GetIdForUserWithEmailAddress = func(email *models.EmailAddress) (models.UserId, error) { + return models.UserId(userIdAsInt), nil + } + + userValues := map[string]string{"emailAddress": theEmail, "password": thePassword} + + userJsonValue, _ := json.Marshal(userValues) + + resp, err := client.Post(server.URL+paths.SessionApi, "application/json", bytes.NewBuffer(userJsonValue)) + + test_util.Ok(t, err) + + test_util.Equals(t, http.StatusCreated, resp.StatusCode) + } + + // Test Add Note + noteIdAsInt := int64(33) + content := "Duuude I just said something cool" + { + noteValues := map[string]string{"content": content} + + mockDb.Func_StoreNewNote = func(*models.Note) (models.NoteId, error) { + return models.NoteId(noteIdAsInt), nil + } + + noteJsonValue, _ := json.Marshal(noteValues) + + resp, err := client.Post(server.URL+paths.NoteApi, "application/json", bytes.NewBuffer(noteJsonValue)) + test_util.Ok(t, err) + test_util.Equals(t, http.StatusCreated, resp.StatusCode) + + type NoteResponse struct { + NoteId int64 `json:"noteId"` + } + + jsonNoteReponse := &NoteResponse{} + + err = json.NewDecoder(resp.Body).Decode(jsonNoteReponse) + test_util.Ok(t, err) + + test_util.Equals(t, noteIdAsInt, jsonNoteReponse.NoteId) + + resp.Body.Close() + } + + // Test get notes + { + resp, err := client.Get(server.URL + paths.NoteApi) + test_util.Ok(t, err) + test_util.Equals(t, http.StatusOK, resp.StatusCode) + + // TODO when we implement a real get notes feature we should enhance this code. + } + + // Test Add category + { + type NoteCategoryForm struct { + NoteCategory string `json:"category"` + } + + metaNoteCategory := models.META + + categoryForm := &NoteCategoryForm{NoteCategory: metaNoteCategory.String()} + + mockDb.Func_StoreNewNoteCategoryRelationship = func(noteId models.NoteId, cat models.NoteCategory) error { + if int64(noteId) == noteIdAsInt && cat == metaNoteCategory { + return nil + } + + return errors.New("Incorrect Data Arrived") + } + + jsonValue, _ := json.Marshal(categoryForm) + + resp, err := sendPutRequest(client, server.URL+paths.NoteCategoryApi+"?id="+strconv.FormatInt(noteIdAsInt, 10), "application/json", bytes.NewBuffer(jsonValue)) + test_util.Ok(t, err) + test_util.Equals(t, http.StatusCreated, resp.StatusCode) + } + +} + +func sendPutRequest(client *http.Client, myUrl string, contentType string, body io.Reader) (resp *http.Response, err error) { + req, err := http.NewRequest("PUT", myUrl, body) + + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", contentType) + return client.Do(req) +} + +func printBody(resp *http.Response) { + buf, bodyErr := ioutil.ReadAll(resp.Body) + if bodyErr != nil { + fmt.Print("bodyErr ", bodyErr.Error()) + return + } + + rdr1 := ioutil.NopCloser(bytes.NewBuffer(buf)) + rdr2 := ioutil.NopCloser(bytes.NewBuffer(buf)) + fmt.Printf("BODY: %q", rdr1) + resp.Body = rdr2 +} + +// Helpers + +type DiyMockDataStore struct { + Func_StoreNewNote func(*models.Note) (models.NoteId, error) + Func_StoreNewNoteCategoryRelationship func(models.NoteId, models.NoteCategory) error + Func_StoreNewUser func(string, *models.EmailAddress, string) error + Func_AuthenticateUserCredentials func(*models.EmailAddress, string) error + Func_GetIdForUserWithEmailAddress func(*models.EmailAddress) (models.UserId, error) +} + +func (mock *DiyMockDataStore) StoreNewNote(note *models.Note) (models.NoteId, error) { + return mock.Func_StoreNewNote(note) +} + +func (mock *DiyMockDataStore) StoreNewNoteCategoryRelationship(noteId models.NoteId, cat models.NoteCategory) error { + return mock.Func_StoreNewNoteCategoryRelationship(noteId, cat) +} + +func (mock *DiyMockDataStore) StoreNewUser(str1 string, email *models.EmailAddress, str2 string) error { + return mock.Func_StoreNewUser(str1, email, str2) +} + +func (mock *DiyMockDataStore) AuthenticateUserCredentials(email *models.EmailAddress, str string) error { + return mock.Func_AuthenticateUserCredentials(email, str) +} + +func (mock *DiyMockDataStore) GetIdForUserWithEmailAddress(email *models.EmailAddress) (models.UserId, error) { + return mock.Func_GetIdForUserWithEmailAddress(email) +} diff --git a/main.go b/main.go index 60fb3ab..f02a2f3 100644 --- a/main.go +++ b/main.go @@ -6,8 +6,8 @@ import ( "net/http" "os" - "github.com/atmiguel/cerealnotes/databaseutil" "github.com/atmiguel/cerealnotes/handlers" + "github.com/atmiguel/cerealnotes/models" "github.com/atmiguel/cerealnotes/routers" ) @@ -53,15 +53,23 @@ func determineTokenSigningKey() ([]byte, error) { func main() { // Set up db + + env := &handlers.Environment{} + { databaseUrl, err := determineDatabaseUrl() if err != nil { log.Fatal(err) } - if err := databaseutil.ConnectToDatabase(databaseUrl); err != nil { + db, err := models.ConnectToDatabase(databaseUrl) + + if err != nil { log.Fatal(err) } + + env.Db = db + } // Set up token signing key @@ -70,8 +78,7 @@ func main() { if err != nil { log.Fatal(err) } - - handlers.SetTokenSigningKey(tokenSigningKey) + env.TokenSigningKey = tokenSigningKey } // Start server @@ -83,7 +90,7 @@ func main() { log.Printf("Listening on %s...\n", port) - if err := http.ListenAndServe(port, routers.DefineRoutes()); err != nil { + if err := http.ListenAndServe(port, routers.DefineRoutes(env)); err != nil { log.Fatal(err) } } diff --git a/migrations/README.md b/migrations/README.md index 9111fcf..0de104c 100644 --- a/migrations/README.md +++ b/migrations/README.md @@ -1,9 +1,14 @@ -# Locally: -1. first make sure that postgres is running. - * If installed via homebrew on a macOS: `pg_ctl -D /usr/local/var/postgres start` -2. then run migration locally - * `psql < *MIGRATION_NAME*` +# Locally: +1. install & setup postgres + * `brew install postgres` + * ``createdb `whoami` `` +2. Run postgres daemon. + * `pg_ctl start -D /usr/local/var/postgres` +3. Create cerealnotes databases + * `psql < tools/createDatabases.sql` +3. Run all the migrations on both "unittest" database (`cerealnotes_test`) and as well as the "live" database (`cerealnotes`). + * `psql [DATABASENAME] < [MIGRATION_NAME]` -# On Heroku: +# On Heroku: -1. `heroku pg:psql < *MIGRATION_NAME*` +1. `heroku pg:psql < [MIGRATION_NAME]` \ No newline at end of file diff --git a/migrations/tools/createDatabases.sql b/migrations/tools/createDatabases.sql new file mode 100644 index 0000000..6e6b188 --- /dev/null +++ b/migrations/tools/createDatabases.sql @@ -0,0 +1,2 @@ +CREATE DATABASE cerealnotes_test; +CREATE DATABASE cerealnotes; \ No newline at end of file diff --git a/migrations/tools/drop_everything.sql b/migrations/tools/drop_everything.sql index ac4fd8a..5463955 100644 --- a/migrations/tools/drop_everything.sql +++ b/migrations/tools/drop_everything.sql @@ -1,11 +1,11 @@ DROP TYPE category_type CASCADE; -DROP TABLE app_user CASCADE; +DROP TABLE note_to_category_relationship CASCADE; + +DROP TABLE note_to_publication_relationship CASCADE; DROP TABLE publication CASCADE; DROP TABLE note CASCADE; -DROP TABLE note_to_type_relationship CASCADE; - -DROP TABLE note_to_publication_relationship CASCADE; \ No newline at end of file +DROP TABLE app_user CASCADE; \ No newline at end of file diff --git a/migrations/tools/truncate_tables.sql b/migrations/tools/truncate_tables.sql new file mode 100644 index 0000000..f5b94c6 --- /dev/null +++ b/migrations/tools/truncate_tables.sql @@ -0,0 +1,9 @@ +TRUNCATE note_to_publication_relationship CASCADE; + +TRUNCATE publication CASCADE; + +TRUNCATE note_to_category_relationship CASCADE; + +TRUNCATE note CASCADE; + +TRUNCATE app_user CASCADE; \ No newline at end of file diff --git a/models/datastore.go b/models/datastore.go new file mode 100644 index 0000000..d76600e --- /dev/null +++ b/models/datastore.go @@ -0,0 +1,55 @@ +package models + +import ( + "database/sql" + "errors" + + "github.com/lib/pq" +) + +// UniqueConstraintError is returned when a uniqueness constraint is violated during an insert. +var UniqueConstraintError = errors.New("postgres: unique constraint violation") + +// QueryResultContainedMultipleRowsError is returned when a query unexpectedly returns more than one row. +var QueryResultContainedMultipleRowsError = errors.New("query result unexpectedly contained multiple rows") + +// QueryResultContainedNoRowsError is returned when a query unexpectedly returns no rows. +var QueryResultContainedNoRowsError = errors.New("query result unexpectedly contained no rows") + +// ConnectToDatabase also pings the database to ensure a working connection. +func ConnectToDatabase(databaseUrl string) (*DB, error) { + tempDb, err := sql.Open("postgres", databaseUrl) + if err != nil { + return nil, err + } + + if err := tempDb.Ping(); err != nil { + return nil, err + } + + return &DB{tempDb}, nil +} + +type Datastore interface { + StoreNewNote(*Note) (NoteId, error) + StoreNewNoteCategoryRelationship(NoteId, NoteCategory) error + StoreNewUser(string, *EmailAddress, string) error + AuthenticateUserCredentials(*EmailAddress, string) error + GetIdForUserWithEmailAddress(*EmailAddress) (UserId, error) +} + +type DB struct { + *sql.DB +} + +func convertPostgresError(err error) error { + const uniqueConstraintErrorCode = "23505" + + if postgresErr, ok := err.(*pq.Error); ok { + if postgresErr.Code == uniqueConstraintErrorCode { + return UniqueConstraintError + } + } + + return err +} diff --git a/models/datastore_test.go b/models/datastore_test.go new file mode 100644 index 0000000..1cb8c5c --- /dev/null +++ b/models/datastore_test.go @@ -0,0 +1,113 @@ +package models_test + +import ( + "fmt" + "strconv" + "testing" + "time" + + "github.com/atmiguel/cerealnotes/models" + "github.com/atmiguel/cerealnotes/test_util" +) + +var postgresUrl = "postgresql://localhost/cerealnotes_test?sslmode=disable" + +const noteTable = "note" +const publicationTable = "publication" +const noteToPublicationTable = "note_to_publication_relationship" +const noteToCategoryTable = "note_to_category_relationship" +const userTable = "app_user" + +var tables = []string{ + noteToPublicationTable, + publicationTable, + noteToCategoryTable, + noteTable, + userTable, +} + +func ClearAllValuesInTable(db *models.DB) { + for _, val := range tables { + if err := ClearValuesInTable(db, val); err != nil { + panic(err) + } + } +} + +func ClearValuesInTable(db *models.DB, table string) error { + // db.Query() doesn't allow variables to replace columns or table names. + sqlQuery := fmt.Sprintf(`TRUNCATE %s CASCADE;`, table) + + _, err := db.Exec(sqlQuery) + if err != nil { + return err + } + + return nil +} + +func TestUser(t *testing.T) { + db, err := models.ConnectToDatabase(postgresUrl) + test_util.Ok(t, err) + ClearValuesInTable(db, userTable) + + displayName := "boby" + password := "aPassword" + emailAddress := models.NewEmailAddress("thisIsMyOtherEmail@gmail.com") + + err = db.StoreNewUser(displayName, emailAddress, password) + test_util.Ok(t, err) + + _, err = db.GetIdForUserWithEmailAddress(emailAddress) + test_util.Ok(t, err) + + err = db.AuthenticateUserCredentials(emailAddress, password) + test_util.Ok(t, err) +} + +func TestNote(t *testing.T) { + db, err := models.ConnectToDatabase(postgresUrl) + test_util.Ok(t, err) + ClearValuesInTable(db, userTable) + ClearValuesInTable(db, noteTable) + + displayName := "bob" + password := "aPassword" + emailAddress := models.NewEmailAddress("thisIsMyEmail@gmail.com") + + err = db.StoreNewUser(displayName, emailAddress, password) + test_util.Ok(t, err) + + userId, err := db.GetIdForUserWithEmailAddress(emailAddress) + test_util.Ok(t, err) + + note := &models.Note{AuthorId: userId, Content: "I'm a note", CreationTime: time.Now()} + id, err := db.StoreNewNote(note) + test_util.Ok(t, err) + test_util.Assert(t, int64(id) > 0, "Note Id was not a valid index: "+strconv.Itoa(int(id))) +} + +func TestCategory(t *testing.T) { + db, err := models.ConnectToDatabase(postgresUrl) + test_util.Ok(t, err) + ClearValuesInTable(db, userTable) + ClearValuesInTable(db, noteTable) + ClearValuesInTable(db, noteToCategoryTable) + + displayName := "bob" + password := "aPassword" + emailAddress := models.NewEmailAddress("thisyetAnotherIsMyEmail@gmail.com") + + err = db.StoreNewUser(displayName, emailAddress, password) + test_util.Ok(t, err) + + userId, err := db.GetIdForUserWithEmailAddress(emailAddress) + test_util.Ok(t, err) + + note := &models.Note{AuthorId: userId, Content: "I'm a note", CreationTime: time.Now()} + noteId, err := db.StoreNewNote(note) + test_util.Ok(t, err) + + err = db.StoreNewNoteCategoryRelationship(noteId, models.META) + test_util.Ok(t, err) +} diff --git a/models/note.go b/models/note.go index 323c245..ae1f516 100644 --- a/models/note.go +++ b/models/note.go @@ -1,50 +1,57 @@ package models import ( - "errors" "time" ) type NoteId int64 -type Category int +type Note struct { + AuthorId UserId `json:"authorId"` + Content string `json:"content"` + CreationTime time.Time `json:"creationTime"` +} -const ( - MARGINALIA Category = iota - META - QUESTIONS - PREDICTIONS -) +// DB methods -var categoryStrings = [...]string{ - "marginalia", - "meta", - "questions", - "predictions", -} +func (db *DB) StoreNewNote( + note *Note, +) (NoteId, error) { -var CannotDeserializeCategoryStringError = errors.New("String does not correspond to a Note Category") + authorId := int64(note.AuthorId) + content := note.Content + creationTime := note.CreationTime -func DeserializeCategory(input string) (Category, error) { - for i := 0; i < len(categoryStrings); i++ { - if input == categoryStrings[i] { - return Category(i), nil - } + sqlQuery := ` + INSERT INTO note (author_id, content, creation_time) + VALUES ($1, $2, $3) + RETURNING id` + + rows, err := db.Query(sqlQuery, authorId, content, creationTime) + if err != nil { + return 0, convertPostgresError(err) } - return 0, CannotDeserializeCategoryStringError -} + defer rows.Close() -func (category Category) String() string { + var noteId int64 = 0 + for rows.Next() { - if category < MARGINALIA || category > PREDICTIONS { - return "Unknown" + if noteId != 0 { + return 0, QueryResultContainedMultipleRowsError + } + + if err := rows.Scan(¬eId); err != nil { + return 0, convertPostgresError(err) + } } - return categoryStrings[category] -} + if noteId == 0 { + return 0, QueryResultContainedNoRowsError + } -type Note struct { - AuthorId UserId `json:"authorId"` - Content string `json:"content"` - CreationTime time.Time `json:"creationTime"` + if err := rows.Err(); err != nil { + return 0, convertPostgresError(err) + } + + return NoteId(noteId), nil } diff --git a/models/note_category.go b/models/note_category.go new file mode 100644 index 0000000..e7992c9 --- /dev/null +++ b/models/note_category.go @@ -0,0 +1,62 @@ +package models + +import ( + "errors" +) + +type NoteCategory int + +const ( + MARGINALIA NoteCategory = iota + META + QUESTIONS + PREDICTIONS +) + +var categoryStrings = [...]string{ + "marginalia", + "meta", + "questions", + "predictions", +} + +var CannotDeserializeNoteCategoryStringError = errors.New("String does not correspond to a Note Category") + +func DeserializeNoteCategory(input string) (NoteCategory, error) { + for i := 0; i < len(categoryStrings); i++ { + if input == categoryStrings[i] { + return NoteCategory(i), nil + } + } + return 0, CannotDeserializeNoteCategoryStringError +} + +func (category NoteCategory) String() string { + + if category < MARGINALIA || category > PREDICTIONS { + return "Unknown" + } + + return categoryStrings[category] +} + +func (db *DB) StoreNewNoteCategoryRelationship( + noteId NoteId, + category NoteCategory, +) error { + sqlQuery := ` + INSERT INTO note_to_category_relationship (note_id, category) + VALUES ($1, $2)` + + rows, err := db.Query(sqlQuery, int64(noteId), category.String()) + if err != nil { + return convertPostgresError(err) + } + defer rows.Close() + + if err := rows.Err(); err != nil { + return convertPostgresError(err) + } + + return nil +} diff --git a/models/note_test.go b/models/note_test.go new file mode 100644 index 0000000..94003e0 --- /dev/null +++ b/models/note_test.go @@ -0,0 +1,26 @@ +package models_test + +import ( + "testing" + + "github.com/atmiguel/cerealnotes/models" + "github.com/atmiguel/cerealnotes/test_util" +) + +var deserializationTests = []models.NoteCategory{ + models.MARGINALIA, + models.META, + models.QUESTIONS, + models.PREDICTIONS, +} + +func TestDeserialization(t *testing.T) { + for _, val := range deserializationTests { + t.Run(val.String(), func(t *testing.T) { + cat, err := models.DeserializeNoteCategory(val.String()) + test_util.Ok(t, err) + test_util.Equals(t, val, cat) + }) + } + +} diff --git a/models/notesById.go b/models/notesById.go new file mode 100644 index 0000000..1edd58f --- /dev/null +++ b/models/notesById.go @@ -0,0 +1,19 @@ +package models + +import ( + "encoding/json" + "fmt" +) + +type NotesById map[NoteId]*Note + +func (notesById NotesById) ToJson() ([]byte, error) { + // json doesn't support int indexed maps + notesByIdString := make(map[string]Note, len(notesById)) + + for id, note := range notesById { + notesByIdString[fmt.Sprint(id)] = *note + } + + return json.Marshal(notesByIdString) +} diff --git a/models/user.go b/models/user.go index 0ecac0c..b90198f 100644 --- a/models/user.go +++ b/models/user.go @@ -1,6 +1,12 @@ package models -import "strings" +import ( + "errors" + "strings" + "time" + + "golang.org/x/crypto/bcrypt" +) type UserId int64 @@ -20,3 +26,109 @@ func NewEmailAddress(emailAddressAsString string) *EmailAddress { func (emailAddress *EmailAddress) String() string { return emailAddress.emailAddressAsString } + +var EmailAddressAlreadyInUseError = errors.New("Email address already in use") + +var CredentialsNotAuthorizedError = errors.New("The provided credentials were not found") + +// + +func (db *DB) StoreNewUser( + displayName string, + emailAddress *EmailAddress, + password string, +) error { + hashedPassword, err := bcrypt.GenerateFromPassword( + []byte(password), + bcrypt.DefaultCost) + if err != nil { + return err + } + + creationTime := time.Now().UTC() + + sqlQuery := ` + INSERT INTO app_user (display_name, email_address, password, creation_time) + VALUES ($1, $2, $3, $4)` + + rows, err := db.Query(sqlQuery, displayName, emailAddress.String(), hashedPassword, creationTime) + if err != nil { + return convertPostgresError(err) + } + defer rows.Close() + + if err := rows.Err(); err != nil { + return convertPostgresError(err) + } + + return nil +} + +func (db *DB) AuthenticateUserCredentials(emailAddress *EmailAddress, password string) error { + sqlQuery := ` + SELECT password FROM app_user + WHERE email_address = $1` + + rows, err := db.Query(sqlQuery, emailAddress.String()) + if err != nil { + return convertPostgresError(err) + } + defer rows.Close() + + var storedHashedPassword []byte + for rows.Next() { + if storedHashedPassword != nil { + return QueryResultContainedMultipleRowsError + } + + if err := rows.Scan(&storedHashedPassword); err != nil { + return err + } + } + + if storedHashedPassword == nil { + return QueryResultContainedNoRowsError + } + + if err := bcrypt.CompareHashAndPassword( + storedHashedPassword, + []byte(password), + ); err != nil { + if err == bcrypt.ErrMismatchedHashAndPassword { + return CredentialsNotAuthorizedError + } + + return err + } + + return nil +} + +func (db *DB) GetIdForUserWithEmailAddress(emailAddress *EmailAddress) (UserId, error) { + sqlQuery := ` + SELECT id FROM app_user + WHERE email_address = $1` + + rows, err := db.Query(sqlQuery, emailAddress.String()) + if err != nil { + return 0, convertPostgresError(err) + } + defer rows.Close() + + var userId int64 + for rows.Next() { + if userId != 0 { + return 0, QueryResultContainedMultipleRowsError + } + + if err := rows.Scan(&userId); err != nil { + return 0, err + } + } + + if userId == 0 { + return 0, QueryResultContainedNoRowsError + } + + return UserId(userId), nil +} diff --git a/paths/paths.go b/paths/paths.go index a36e22b..e9db54b 100644 --- a/paths/paths.go +++ b/paths/paths.go @@ -8,8 +8,8 @@ const ( HomePage = "/home" NotesPage = "/notes" - UserApi = "/api/user" - SessionApi = "/api/session" - NoteApi = "/api/note" - CategoryApi = "/api/note-category" + UserApi = "/api/user" + SessionApi = "/api/session" + NoteApi = "/api/note" + NoteCategoryApi = "/api/note-category" ) diff --git a/routers/routers.go b/routers/routers.go index dfcee5c..ff12b73 100644 --- a/routers/routers.go +++ b/routers/routers.go @@ -15,21 +15,31 @@ type routeHandler struct { } func (mux *routeHandler) handleAuthenticatedPage( + env *handlers.Environment, pattern string, handlerFunc handlers.AuthenticatedRequestHandlerType, ) { - mux.HandleFunc(pattern, handlers.AuthenticateOrRedirect(handlerFunc, paths.LoginOrSignupPage)) + mux.HandleFunc(pattern, handlers.AuthenticateOrRedirect(env, handlerFunc, paths.LoginOrSignupPage)) } func (mux *routeHandler) handleAuthenticatedApi( + env *handlers.Environment, pattern string, handlerFunc handlers.AuthenticatedRequestHandlerType, ) { - mux.HandleFunc(pattern, handlers.AuthenticateOrReturnUnauthorized(handlerFunc)) + mux.HandleFunc(pattern, handlers.AuthenticateOrReturnUnauthorized(env, handlerFunc)) +} + +func (mux *routeHandler) handleUnAutheticedRequest( + env *handlers.Environment, + pattern string, + handlerFunc handlers.UnauthenticatedEndpointHandlerType, +) { + mux.HandleFunc(pattern, handlers.WrapUnauthenticatedEndpoint(env, handlerFunc)) } // DefineRoutes returns a new servemux with all the required path and handler pairs attached. -func DefineRoutes() http.Handler { +func DefineRoutes(env *handlers.Environment) http.Handler { mux := &routeHandler{http.NewServeMux()} // static files { @@ -49,18 +59,18 @@ func DefineRoutes() http.Handler { mux.HandleFunc("/favicon.ico", handlers.RedirectToPathHandler("/static/favicon.ico")) // pages - mux.HandleFunc(paths.LoginOrSignupPage, handlers.HandleLoginOrSignupPageRequest) + mux.handleUnAutheticedRequest(env, paths.LoginOrSignupPage, handlers.HandleLoginOrSignupPageRequest) - mux.handleAuthenticatedPage(paths.HomePage, handlers.HandleHomePageRequest) - mux.handleAuthenticatedPage(paths.NotesPage, handlers.HandleNotesPageRequest) + mux.handleAuthenticatedPage(env, paths.HomePage, handlers.HandleHomePageRequest) + mux.handleAuthenticatedPage(env, paths.NotesPage, handlers.HandleNotesPageRequest) // api - mux.HandleFunc(paths.UserApi, handlers.HandleUserApiRequest) - mux.HandleFunc(paths.SessionApi, handlers.HandleSessionApiRequest) + mux.handleUnAutheticedRequest(env, paths.UserApi, handlers.HandleUserApiRequest) + mux.handleUnAutheticedRequest(env, paths.SessionApi, handlers.HandleSessionApiRequest) - mux.handleAuthenticatedApi(paths.NoteApi, handlers.HandleNoteApiRequest) - mux.handleAuthenticatedApi(paths.CategoryApi, handlers.HandleNoteCateogryApiRequest) + mux.handleAuthenticatedApi(env, paths.NoteApi, handlers.HandleNoteApiRequest) + mux.handleAuthenticatedApi(env, paths.NoteCategoryApi, handlers.HandleNoteCateogryApiRequest) return mux } diff --git a/services/noteservice/noteservice.go b/services/noteservice/noteservice.go deleted file mode 100644 index dede981..0000000 --- a/services/noteservice/noteservice.go +++ /dev/null @@ -1,48 +0,0 @@ -/* -Package noteservice handles interactions with database layer. -*/ -package noteservice - -import ( - "encoding/json" - "fmt" - - "github.com/atmiguel/cerealnotes/databaseutil" - "github.com/atmiguel/cerealnotes/models" -) - -func StoreNewNote( - note *models.Note, -) (models.NoteId, error) { - - id, err := databaseutil.InsertNewNote(int64(note.AuthorId), note.Content, note.CreationTime) - if err != nil { - return models.NoteId(0), err - } - - return models.NoteId(id), nil -} - -func StoreNewNoteCategoryRelationship( - noteId models.NoteId, - category models.Category, -) error { - if err := databaseutil.InsertNoteCategoryRelationship(int64(noteId), category.String()); err != nil { - return err - } - - return nil -} - -type NotesById map[models.NoteId]*models.Note - -func (notesById NotesById) ToJson() ([]byte, error) { - // json doesn't support int indexed maps - notesByIdString := make(map[string]models.Note, len(notesById)) - - for id, note := range notesById { - notesByIdString[fmt.Sprint(id)] = *note - } - - return json.Marshal(notesByIdString) -} diff --git a/services/userservice/userservice.go b/services/userservice/userservice.go deleted file mode 100644 index fab7356..0000000 --- a/services/userservice/userservice.go +++ /dev/null @@ -1,92 +0,0 @@ -/* -Package userservice handles interactions with database layer. -*/ -package userservice - -import ( - "errors" - "time" - - "github.com/atmiguel/cerealnotes/databaseutil" - "github.com/atmiguel/cerealnotes/models" - "golang.org/x/crypto/bcrypt" -) - -var EmailAddressAlreadyInUseError = errors.New("Email address already in use") - -var CredentialsNotAuthorizedError = errors.New("The provided credentials were not found") - -func StoreNewUser( - displayName string, - emailAddress *models.EmailAddress, - password string, -) error { - hashedPassword, err := bcrypt.GenerateFromPassword( - []byte(password), - bcrypt.DefaultCost) - if err != nil { - return err - } - - creationTime := time.Now().UTC() - - if err := databaseutil.InsertIntoUserTable( - displayName, - emailAddress.String(), - hashedPassword, - creationTime, - ); err != nil { - if err == databaseutil.UniqueConstraintError { - return EmailAddressAlreadyInUseError - } - - return err - } - - return nil -} - -func AuthenticateUserCredentials(emailAddress *models.EmailAddress, password string) error { - storedHashedPassword, err := databaseutil.GetPasswordForUserWithEmailAddress(emailAddress.String()) - if err != nil { - if err == databaseutil.QueryResultContainedMultipleRowsError { - return err // would normally throw a runtime here - } - - if err == databaseutil.QueryResultContainedNoRowsError { - return CredentialsNotAuthorizedError - } - - return err - } - - if err := bcrypt.CompareHashAndPassword( - storedHashedPassword, - []byte(password), - ); err != nil { - if err == bcrypt.ErrMismatchedHashAndPassword { - return CredentialsNotAuthorizedError - } - - return err - } - - return nil -} - -func GetIdForUserWithEmailAddress(emailAddress *models.EmailAddress) (models.UserId, error) { - userIdAsInt, err := databaseutil.GetIdForUserWithEmailAddress(emailAddress.String()) - if err != nil { - if err == databaseutil.QueryResultContainedMultipleRowsError { - return 0, err // would normally throw a runtime here - } - - if err == databaseutil.QueryResultContainedNoRowsError { - return 0, err - } - - return 0, err - } - - return models.UserId(userIdAsInt), nil -} diff --git a/static/js/notes.js b/static/js/notes.js index bb8bff9..f3d1d38 100644 --- a/static/js/notes.js +++ b/static/js/notes.js @@ -22,7 +22,7 @@ const $createDivider = function() { return $('', {text: ' - '}); }; -const $createNote = function(note) { +const $createNote = function(noteId, note) { const $author = $createAuthor(note.authorId); const $type = $createType(note.type); const $creationTime = $createCreationTime(note.creationTime); @@ -45,11 +45,9 @@ $(function() { $.get('/api/note', function(notes) { const $notes = $('#notes'); - notes.forEach((note) => { - $notes.append( - $createNote(note) - ); - }); + for (const key of Object.keys(notes)) { + $notes.append($createNote(key, notes[key])); + } }); }); }); diff --git a/test_util/test_util.go b/test_util/test_util.go new file mode 100644 index 0000000..e61c628 --- /dev/null +++ b/test_util/test_util.go @@ -0,0 +1,33 @@ +package test_util + +import ( + "fmt" + "path/filepath" + "reflect" + "runtime" + "testing" +) + +func Assert(tb testing.TB, condition bool, msg string, v ...interface{}) { + if !condition { + _, file, line, _ := runtime.Caller(1) + fmt.Printf("\033[31m%s:%d: "+msg+"\033[39m\n\n", append([]interface{}{filepath.Base(file), line}, v...)...) + tb.FailNow() + } +} + +func Ok(tb testing.TB, err error) { + if err != nil { + _, file, line, _ := runtime.Caller(1) + fmt.Printf("\033[31m%s:%d: unexpected error: %s\033[39m\n\n", filepath.Base(file), line, err.Error()) + tb.FailNow() + } +} + +func Equals(tb testing.TB, expected interface{}, actual interface{}) { + if !reflect.DeepEqual(expected, actual) { + _, file, line, _ := runtime.Caller(1) + fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, expected, actual) + tb.FailNow() + } +}