diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9b44840..a4ff896 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -5,18 +5,19 @@ on: branches: [main] pull_request: branches: [main] + workflow_dispatch: jobs: audit: runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Set up Go - uses: actions/setup-go@v2 + uses: actions/setup-go@v3 with: - go-version: 1.18 + go-version: 1.19 - name: Verify dependencies run: go mod verify @@ -37,4 +38,7 @@ jobs: run: golint ./... - name: Run tests - run: go test -race -vet=off ./... \ No newline at end of file + run: go test -race -vet=off ./... + + - name: Update coverage report + uses: ncruces/go-coverage-report@main \ No newline at end of file diff --git a/go.mod b/go.mod index 03d1e16..94e8601 100644 --- a/go.mod +++ b/go.mod @@ -1,7 +1,3 @@ module github.com/tsawler/toolbox -go 1.18 - -require github.com/gabriel-vasile/mimetype v1.4.0 - -require golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4 // indirect +go 1.19 diff --git a/go.sum b/go.sum index 9876ebe..e69de29 100644 --- a/go.sum +++ b/go.sum @@ -1,11 +0,0 @@ -github.com/gabriel-vasile/mimetype v1.4.0 h1:Cn9dkdYsMIu56tGho+fqzh7XmvY2YyGU0FnbhiOsEro= -github.com/gabriel-vasile/mimetype v1.4.0/go.mod h1:fA8fi6KUiG7MgQQ+mEWotXoEOvmxRtOJlERCzSmRvr8= -golang.org/x/net v0.0.0-20210505024714-0287a6fb4125 h1:Ugb8sMTWuWRC3+sz5WeN/4kejDx9BvIwnPUiJBjJE+8= -golang.org/x/net v0.0.0-20210505024714-0287a6fb4125/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4 h1:HVyaeDAYux4pnY+D/SiwmLOR36ewZ4iGQIIrtnuCjFA= -golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/load_sql.go b/load_sql.go new file mode 100644 index 0000000..24271e8 --- /dev/null +++ b/load_sql.go @@ -0,0 +1,78 @@ +package toolbox + +import ( + "bufio" + "fmt" + "os" + "strings" +) + +// LoadSQLQueries loads SQL queries from a file and populates the QUERY map. +// This tool aims to facilitate the use of the go language's database/sql standard library. +// Writing SQL queries directly in the code can make it messy, so writing SQL queries in .sql files +// and then calling them from the code helps prevent code clutter, +// allowing SQL queries to be centralized in one place for better organization. +func (t *Tools) LoadSQLQueries(fileName string) (map[string]string, error) { + query := make(map[string]string) + + file, err := os.Open(fileName) + if err != nil { + return query, err + } + defer func() { + _ = file.Close() + }() + + query, err = parseSQLQueries(file, query) + return query, err +} + +// parseSQLQueries reads the SQL queries from the provided file and populates the QUERY map. +func parseSQLQueries(file *os.File, query map[string]string) (map[string]string, error) { + scanner := bufio.NewScanner(file) + var key string + var queryBuilder strings.Builder + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if isSQLQuery(line) || len(key) > 0 { + if len(key) > 0 { + if strings.HasSuffix(line, ";") { + queryBuilder.WriteString(line) + query[key] = queryBuilder.String() + key, queryBuilder = "", strings.Builder{} + } else { + queryBuilder.WriteString(line + " ") + } + } else { + key = extractKey(line) + } + } + } + if err := scanner.Err(); err != nil { + return query, fmt.Errorf("error reading file: %w", err) + } + return query, nil +} + +// isSQLQuery checks if the given line is an SQL query or a comment. +func isSQLQuery(line string) bool { + return hasPrefixInList(line, []string{"-- ", "SELECT", "INSERT", "UPDATE", "DELETE"}) +} + +// extractKey extracts the key from the comment line. +func extractKey(line string) string { + if strings.HasPrefix(line, "-- ") { + return strings.Split(line, "-- ")[1] + } + return "" +} + +// hasPrefixInList is a prefix checker +func hasPrefixInList(str string, prefixes []string) bool { + for _, prefix := range prefixes { + if strings.HasPrefix(str, prefix) { + return true + } + } + return false +} diff --git a/load_sql_test.go b/load_sql_test.go new file mode 100644 index 0000000..4bb8417 --- /dev/null +++ b/load_sql_test.go @@ -0,0 +1,100 @@ +package toolbox + +import ( + "os" + "strconv" + "strings" + "testing" +) + +var tools Tools + +func TestLoadSQLQueries(t *testing.T) { + tests := []struct { + fileName string + key string + value string + equal bool + err bool + }{ + {fileName: "./testdata/not.sql", key: "", value: "", equal: true, err: true}, + {fileName: "./testdata/not.sql", key: "not", value: "", equal: true, err: true}, + {fileName: "./testdata/not.sql", key: "not", value: "equal", equal: false, err: true}, + {fileName: "./testdata/test.sql", key: "TEST1", value: "WHERE ass.id=$1;", equal: true, err: false}, + {fileName: "./testdata/test.sql", key: "TEST1", value: "WHERE ass.id=$1", equal: false, err: false}, + {fileName: "./testdata/test.sql", key: "TEST2", value: "id = $1;", equal: true, err: false}, + } + for i, tt := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + query, err := tools.LoadSQLQueries(tt.fileName) + if (err != nil) != tt.err { + t.Errorf("LoadSQLQueries() error: %v, except: %v", err, tt.err) + } + if strings.HasSuffix(query[tt.key], tt.value) != tt.equal { + t.Errorf("LoadSQLQueries() error: %v, except: %v", err, tt.equal) + } + }) + } +} + +func TestParseSQLQueries(t *testing.T) { + file, err := os.Open("./testdata/test.sql") + defer func(file *os.File) { + _ = file.Close() + }(file) + if (err != nil) != false { + t.Errorf("File Open result: %v, expect: %v", false, true) + } + _, err = parseSQLQueries(file, make(map[string]string)) + if (err != nil) != false { + t.Errorf("parseSQLQueries() result: %v, expect: %v", false, true) + } +} + +func TestIsSQLQuery(t *testing.T) { + if result := isSQLQuery("-- "); result != true { + t.Errorf("isSQLQuery() result: %v, expect: %v", result, true) + } + if result := isSQLQuery("--"); result != false { + t.Errorf("isSQLQuery() result: %v, expect: %v", result, false) + } +} + +func TestExtractKey(t *testing.T) { + tests := []struct { + value string + expect string + }{ + {value: "-- ABC", expect: "ABC"}, + {value: "DEF", expect: ""}, + } + for i, tt := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + if result := extractKey(tt.value); result != tt.expect { + t.Errorf("extractKey() result: %v, expect: %v", result, tt.expect) + } + }) + } +} + +func TestHasPrefixInList(t *testing.T) { + type args struct { + key string + value []string + } + tests := []struct { + args args + expect bool + }{ + {args: args{key: "abc"}, expect: false}, + {args: args{key: "abc", value: []string{"abc", "def"}}, expect: true}, + {args: args{key: "xyz", value: []string{"abc", "def"}}, expect: false}, + } + for i, tt := range tests { + t.Run(strconv.Itoa(i), func(t *testing.T) { + if result := hasPrefixInList(tt.args.key, tt.args.value); result != tt.expect { + t.Errorf("hasPrefixInList() result: %v, expect: %v", result, tt.expect) + } + }) + } +} diff --git a/readme.md b/readme.md index f75e8f0..f0abb7c 100644 --- a/readme.md +++ b/readme.md @@ -1,8 +1,11 @@ -[![Version](https://img.shields.io/badge/goversion-1.18.x-blue.svg)](https://golang.org) -[![License](http://img.shields.io/badge/license-mit-blue.svg?style=flat-square)](https://raw.githubusercontent.com/tsawler/goblender/master/LICENSE) +[![Version](https://img.shields.io/badge/goversion-1.19.x-blue.svg)](https://golang.org) +Built with GoLang +[![License](http://img.shields.io/badge/license-mit-blue.svg?style=flat-square)](https://raw.githubusercontent.com/tsawler/toolbox/master/LICENSE) [![Go Report Card](https://goreportcard.com/badge/github.com/tsawler/toolbox)](https://goreportcard.com/report/github.com/tsawler/toolbox) ![Tests](https://github.com/tsawler/toolbox/actions/workflows/tests.yml/badge.svg) +[![Go Coverage](https://github.com/tsawler/toolbox/wiki/coverage.svg)](https://raw.githack.com/wiki/tsawler/toolbox/coverage.html) + # Toolbox A simple example of how to create a reusable Go module with commonly used tools. @@ -12,18 +15,23 @@ The included tools are: - Read JSON - Write JSON - Produce a JSON encoded error response +- Write XML +- Read XML +- Produce an XML encoded error response - Upload a file to a specified directory - Download a static file - Get a random string of length n - Post JSON to a remote service - Create a directory, including all parent directories, if it does not already exist - -**Not for production -- used in a course.** +- Create a URL safe slug from a string +- ContainsElement checks if a value exists in a slice +- Creating a QUERY map from a SQL file ## Installation `go get -u github.com/tsawler/toolbox` + ## Usage ```go @@ -107,4 +115,387 @@ func (app *Config) SomeHandler(w http.ResponseWriter, r *http.Request) { // keep going in the handler... } +``` + +### Uploading a File: + +To upload a file to a specific directory, with this for HTML: + +```html + + + + + + + + Upload test + + +
+
+
+

Upload a file

+
+ +
+ +
+ + +
+ + + +
+ +
+
+
+ + +``` +And this for a Go application: + +```go +package main + +import ( + "fmt" + "github.com/tsawler/toolbox" + "log" + "net/http" +) + +func main() { + + // handle html route (http://localhost:8080/) + http.Handle("/", http.StripPrefix("/", http.FileServer(http.Dir(".")))) + + // Post handler + http.HandleFunc("/upload", func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + t := toolbox.Tools{ + MaxFileSize: 1024 * 1024 * 1024, + AllowedFileTypes: []string{"image/gif", "image/png", "image/jpeg"}, + } + + + // Upload the file(s). Note that if you don't want the files to be renamed, + // you can add an optional final parameter -- true will rename the files (the default) + // and false will preserve the original filenames, for example: + // files, err := t.UploadFiles(r, "./uploads", false) + // n.b.: if the "./uploads" directory does not exist, we attempt to create it. + files, err := t.UploadFiles(r, "./uploads") + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + // the returned variable, files, will be a slice of the type toolbox.UploadedFile + _, _ = w.Write([]byte(fmt.Sprintf("Uploaded %d file(s) to the uploads folder", len(files)))) + }) + + // print a log message + log.Println("Starting server on port 8080") + + // start the server + http.ListenAndServe(":8080", nil) +} +``` + +### Calling a Remote API + +To make a JSON post to a remote URI, with this html: + +```html + + + + + + + + + JSON functionality + + + +
+
+
+

JSON functionality

+
+ +
+ + +
+ + +
+ + + Push JSON +
+
+

Response from server:

+
+
No response from server yet...
+
+ +
+
+
+ + + + +``` + +You can use this kind of Go code: + +```go +package main + +import ( + "github.com/tsawler/toolbox" + "log" + "net/http" +) + +func main() { + // create a default server mux + mux := http.NewServeMux() + + // register routes + mux.Handle("/", http.StripPrefix("/", http.FileServer(http.Dir(".")))) + mux.HandleFunc("/receive-post", receivePost) + mux.HandleFunc("/remote-service", remoteService) + + // print a log message + log.Println("Starting server on port 8081") + + // start the server + err := http.ListenAndServe(":8081", mux) + if err != nil { + log.Fatal(err) + } +} + +// RequestPayload describes the JSON that this service accepts as an HTTP Post request +type RequestPayload struct { + Action string `json:"action"` + Message string `json:"message"` +} + +// ResponsePayload is the structure used for sending a JSON response +type ResponsePayload struct { + Message string `json:"message"` + StatusCode int `json:"status_code,omitempty"` +} + +func receivePost(w http.ResponseWriter, r *http.Request) { + // get the posted json and decode it + var requestPayload RequestPayload + var t toolbox.Tools + + err := t.ReadJSON(w, r, &requestPayload) + if err != nil { + _ = t.ErrorJSON(w, err) + return + } + + // Call remote service. Note that we are ignoring the first return parameter, which is the + // entire response from the remote service, but you have access to it if you need it. + _, statusCode, err := t.PushJSONToRemote("http://localhost:8081/remote-service", requestPayload) + if err != nil { + _ = t.ErrorJSON(w, err) + return + } + + // send response + payload := ResponsePayload{ + Message: "hit the service ok", + StatusCode: statusCode, + } + + err = t.WriteJSON(w, http.StatusAccepted, payload) + if err != nil { + log.Println(err) + } +} + +// remoteService just simulates calling some remote API +func remoteService(w http.ResponseWriter, r *http.Request) { + payload := ResponsePayload{ + Message: "OK", + } + var t toolbox.Tools + + _ = t.WriteJSON(w, http.StatusOK, payload) +} +``` + +### Create a slug from a string + +To slugify a string, we simply remove all non URL safe characters and return the +original string with a hyphen where spaces would be. Example: + +```go +package main + +import ( + "fmt" + "github.com/tsawler/toolbox" +) + +func main() { + toSlugify := "hello, world! These are unsafe chars: こんにちは世界*!&^%" + fmt.Println("To slugify:", toSlugify) + var tools toolbox.Tools + + slug, err := tools.Slugify(toSlugify) + if err != nil { + fmt.Println(err) + } + + fmt.Println("Slugified:", slug) +} +``` + +Output from this is: + +``` +To slugify: hello, world! These are unsafe chars: こんにちは世界*!&^% +Slugified: hello-world-these-are-unsafe-chars +``` + +### Value exists in a slice +It is a method we often use when writing code. Is the value we are looking for present in the slice? The type of this slice can be of any type. Example: + +```go +package main + +import( + "fmt" + "github.com/tsawler/toolbox" +) + +func main(){ + var tools toolbox.Tools + + tests := []test{{name: "abc"}, {name: "def"}} + t1 := test{name: "def"} + t2 := test{name: "xyz"} + + if tools.ContainsElement(t1, tests) { + fmt.Println("This slice contains the key you are looking for.") + } + + if !tools.ContainsElement(t2, tests) { + fmt.Println("This slice does not contain the key you are looking for.") + } +} +``` + +### Creating a QUERY map from a SQL file +This tool aims to facilitate the use of the go language's database/sql standard library. Writing SQL queries directly in the code can make it messy, so writing SQL queries in .sql files and then calling them from the code helps prevent code clutter, allowing SQL queries to be centralized in one place for better organization. Example: + +```go +package main + +import( + "database/sql" + "encoding/json" + "log" + "github.com/tsawler/toolbox" +) + +type City struct { + ID int `json:"id,omitempty"` + Name string `json:"name,omitempty"` + CountryID int `json:"country_id,omitempty"` +} + +var DB *sql.DB +var QUERY map[string]string + +func main(){ + var tools toolbox.Tools + + // Load Sql + QUERY = make(map[string]string) + if query, err := tools.LoadSQLQueries("./testdata/test.sql"); err != nil { + log.Fatalf("Load Sql Error: %v", err) + } else { + QUERY = query + } + + var cities []City + rows, err := DB.Query(QUERY["CITIES"]) + if err != nil { + log.Printf("GET Cities Error %v", err) + return + } + defer func(rows *sql.Rows) { + _ = rows.Close() + }(rows) + for rows.Next() { + var city City + _ = rows.Scan(&city.ID, &city.Name, &city.CountryID) + cities = append(cities, city) + } + + if marshal, err := json.MarshalIndent(cities, "", " "); err != nil { + return + }else{ + log.Println(string(marshal)) + } +} ``` \ No newline at end of file diff --git a/testdata/test.sql b/testdata/test.sql new file mode 100644 index 0000000..da7ce85 --- /dev/null +++ b/testdata/test.sql @@ -0,0 +1,11 @@ +-- TEST1 +SELECT * FROM assistant_services as ass + JOIN assistants as a ON a.id=ass.assistant_id + JOIN vehicles as v ON v.id=a.vehicle_id + WHERE ass.id=$1; + +-- TEST2 +UPDATE assistants SET provider_id = $1 WHERE id = $1; + +-- CITIES +SELECT c.id, c.name, c.country_id FROM cities; \ No newline at end of file diff --git a/tools.go b/tools.go index 5ff341e..87fb192 100644 --- a/tools.go +++ b/tools.go @@ -4,86 +4,169 @@ import ( "bytes" "crypto/rand" "encoding/json" + "encoding/xml" "errors" "fmt" "io" + "log" "net/http" "os" "path" - - "github.com/gabriel-vasile/mimetype" + "path/filepath" + "reflect" + "regexp" + "strings" ) +// randomStringSource is the source for generating random strings. const randomStringSource = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0987654321_+" -// Tools is the type for this package. Create a variable of this type and you have access -// to all the methods with the receiver type *Tools. +// defaultMaxUpload is the default max upload size (10 mb) +const defaultMaxUpload = 10485760 + +// Tools is the type for this package. Create a variable of this type, and you have access +// to all the exported methods with the receiver type *Tools. type Tools struct { - MaxFileSize int + MaxJSONSize int // maximum size of JSON file we'll process + MaxXMLSize int // maximum size of XML file we'll process + MaxFileSize int // maximum size of uploaded files in bytes + AllowedFileTypes []string // allowed file types for upload (e.g. image/jpeg) + AllowUnknownFields bool // if set to true, allow unknown fields in JSON + ErrorLog *log.Logger // the info log. + InfoLog *log.Logger // the error log. +} + +// New returns a new toolbox with sensible defaults. +func New() Tools { + return Tools{ + MaxJSONSize: defaultMaxUpload, + MaxXMLSize: defaultMaxUpload, + MaxFileSize: defaultMaxUpload, + InfoLog: log.New(os.Stdout, "INFO\t", log.Ldate|log.Ltime), + ErrorLog: log.New(os.Stdout, "ERROR\t", log.Ldate|log.Ltime|log.Lshortfile), + } } -// JSONResponse is the type used for sending JSON around +// JSONResponse is the type used for sending JSON around. type JSONResponse struct { - Error bool `json:"error"` - Message string `json:"message"` - Data any `json:"data,omitempty"` + Error bool `json:"error"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// XMLResponse is the type used for sending XML around. +type XMLResponse struct { + Error bool `xml:"error"` + Message string `xml:"message"` + Data interface{} `xml:"data,omitempty"` } -// ReadJSON tries to read the body of a request and converts it into JSON -func (t *Tools) ReadJSON(w http.ResponseWriter, r *http.Request, data any) error { - maxBytes := 1048576 // one megabyte - if t.MaxFileSize > 0 { - maxBytes = t.MaxFileSize +// ReadJSON tries to read the body of a request and converts it from JSON to a variable. The third parameter, data, +// is expected to be a pointer, so that we can read data into it. +func (t *Tools) ReadJSON(w http.ResponseWriter, r *http.Request, data interface{}) error { + + // Check content-type header; it should be application/json. If it's not specified, + // try to decode the body anyway. + if r.Header.Get("Content-Type") != "" { + contentType := r.Header.Get("Content-Type") + if strings.ToLower(contentType) != "application/json" { + return errors.New("the Content-Type header is not application/json") + } } + // Set a sensible default for the maximum payload size. + maxBytes := defaultMaxUpload + + // If MaxJSONSize is set, use that value instead of default. + if t.MaxJSONSize > 0 { + maxBytes = t.MaxJSONSize + } r.Body = http.MaxBytesReader(w, r.Body, int64(maxBytes)) dec := json.NewDecoder(r.Body) + + // Should we allow unknown fields? + if !t.AllowUnknownFields { + dec.DisallowUnknownFields() + } + + // Attempt to decode the data, and figure out what the error is, if any, to send back a human-readable + // response. err := dec.Decode(data) if err != nil { - return err + var syntaxError *json.SyntaxError + var unmarshalTypeError *json.UnmarshalTypeError + var invalidUnmarshalError *json.InvalidUnmarshalError + + switch { + case errors.As(err, &syntaxError): + return fmt.Errorf("body contains badly-formed JSON (at character %d)", syntaxError.Offset) + + case errors.Is(err, io.ErrUnexpectedEOF): + return errors.New("body contains badly-formed JSON") + + case errors.As(err, &unmarshalTypeError): + return fmt.Errorf("body contains incorrect JSON type for field %q at offset %d", unmarshalTypeError.Field, unmarshalTypeError.Offset) + + case errors.Is(err, io.EOF): + return errors.New("body must not be empty") + + case strings.HasPrefix(err.Error(), "json: unknown field "): + fieldName := strings.TrimPrefix(err.Error(), "json: unknown field ") + return fmt.Errorf("body contains unknown key %s", fieldName) + + case err.Error() == "http: request body too large": + return fmt.Errorf("body must not be larger than %d bytes", maxBytes) + + case errors.As(err, &invalidUnmarshalError): + return fmt.Errorf("error unmarshalling json: %s", err.Error()) + + default: + return err + } } err = dec.Decode(&struct{}{}) if err != io.EOF { - return errors.New("body must have only a single json value") + return errors.New("body must only contain a single JSON value") } return nil } -// WriteJSON takes a response status code and arbitrary data and writes a json response to the client -func (t *Tools) WriteJSON(w http.ResponseWriter, status int, data any, headers ...http.Header) error { +// WriteJSON takes a response status code and arbitrary data and writes a JSON response to the client. +func (t *Tools) WriteJSON(w http.ResponseWriter, status int, data interface{}, headers ...http.Header) error { out, err := json.Marshal(data) if err != nil { return err } + // If we have a value as the last parameter in the function call, then we are setting a custom header. if len(headers) > 0 { for key, value := range headers[0] { w.Header()[key] = value } } + // Set the content type and send response. w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) - _, err = w.Write(out) - if err != nil { - return err - } + _, _ = w.Write(out) return nil } // ErrorJSON takes an error, and optionally a response status code, and generates and sends -// a json error response +// a JSON error response. func (t *Tools) ErrorJSON(w http.ResponseWriter, err error, status ...int) error { statusCode := http.StatusBadRequest + // If a custom response code is specified, use that instead of bad request. if len(status) > 0 { statusCode = status[0] } + // Build the JSON payload. var payload JSONResponse payload.Error = true payload.Message = err.Error() @@ -91,7 +174,7 @@ func (t *Tools) ErrorJSON(w http.ResponseWriter, err error, status ...int) error return t.WriteJSON(w, statusCode, payload) } -// RandomString returns a random string of letters of length n +// RandomString returns a random string of letters of length n, using characters specified in randomStringSource. func (t *Tools) RandomString(n int) string { s, r := make([]rune, n), []rune(randomStringSource) for i := range s { @@ -102,33 +185,40 @@ func (t *Tools) RandomString(n int) string { return string(s) } -// PushJSONToRemote posts arbitrary json to some url, and returns error, -// if any, as well as the response status code -func (t *Tools) PushJSONToRemote(client *http.Client, uri string, data any) (int, error) { +// PushJSONToRemote posts arbitrary json to some url, and returns the response, the response +// status code, and error, if any. The final parameter, client, is optional, and will default +// to the standard http.Client. It exists to make testing possible without an active remote +// url. +func (t *Tools) PushJSONToRemote(uri string, data interface{}, client ...*http.Client) (*http.Response, int, error) { // create json we'll send - jsonData, err := json.MarshalIndent(data, "", "\t") + jsonData, err := json.Marshal(data) if err != nil { - return 0, err + return nil, 0, err } - // build the request and set header + httpClient := &http.Client{} + if len(client) > 0 { + httpClient = client[0] + } + + // Build the request and set header. request, err := http.NewRequest("POST", uri, bytes.NewBuffer(jsonData)) if err != nil { - return 0, err + return nil, 0, err } request.Header.Set("Content-Type", "application/json") - // call the uri - response, err := client.Do(request) + // Call the url. + response, err := httpClient.Do(request) if err != nil { - return 0, err + return nil, 0, err } defer response.Body.Close() - return response.StatusCode, nil + return response, response.StatusCode, nil } -// DownloadStaticFile downloads a file, and tries to force the browser to avoid displaying it in +// DownloadStaticFile downloads a file to the remote user, and tries to force the browser to avoid displaying it in // the browser window by setting content-disposition. It also allows specification of the display name. func (t *Tools) DownloadStaticFile(w http.ResponseWriter, r *http.Request, p, file, displayName string) { fp := path.Join(p, file) @@ -137,62 +227,130 @@ func (t *Tools) DownloadStaticFile(w http.ResponseWriter, r *http.Request, p, fi http.ServeFile(w, r, fp) } -// UploadedFile is a struct used to +// UploadedFile is the type used for the uploaded file. type UploadedFile struct { NewFileName string OriginalFileName string FileSize int64 } -// UploadOneFile uploads one file to a specified directory, and gives it a random name. -// It returns the newly named file, the original file name, and potentially an error. -func (t *Tools) UploadOneFile(r *http.Request, uploadDir string) (*UploadedFile, error) { - // parse the form so we have access to the file - err := r.ParseMultipartForm(1024 * 1024 * 1024) +// UploadOneFile is just a convenience method that calls UploadFiles, but expects only one file to +// be in the upload. +func (t *Tools) UploadOneFile(r *http.Request, uploadDir string, rename ...bool) (*UploadedFile, error) { + renameFile := true + if len(rename) > 0 { + renameFile = rename[0] + } + + files, err := t.UploadFiles(r, uploadDir, renameFile) + if err != nil { + return nil, err + } + + return files[0], nil +} + +// UploadFiles uploads one or more file to a specified directory, and gives the files a random name. +// It returns a slice containing the newly named files, the original file names, the size of the files, +// and potentially an error. If the optional last parameter is set to true, then we will not rename +// the files, but will use the original file names. +func (t *Tools) UploadFiles(r *http.Request, uploadDir string, rename ...bool) ([]*UploadedFile, error) { + // check to see if we are renaming the uploadedFiles with the optional last parameter. + renameFile := true + if len(rename) > 0 { + renameFile = rename[0] + } + + var uploadedFiles []*UploadedFile + + // Create the upload directory if it does not exist. + err := t.CreateDirIfNotExist(uploadDir) if err != nil { return nil, err } - var uploadedFile UploadedFile + + // Sanity check on t.MaxFileSize. + if t.MaxFileSize == 0 { + t.MaxFileSize = defaultMaxUpload + } + + // Parse the form, so we have access to the file. + err = r.ParseMultipartForm(int64(t.MaxFileSize)) + if err != nil { + return nil, fmt.Errorf("error parsing form data: %v", err) + } for _, fHeaders := range r.MultipartForm.File { for _, hdr := range fHeaders { - infile, err := hdr.Open() - if err != nil { - return nil, err - } - defer infile.Close() + uploadedFiles, err = func(uploadedFiles []*UploadedFile) ([]*UploadedFile, error) { + var uploadedFile UploadedFile + infile, err := hdr.Open() + if err != nil { + return nil, err + } + defer infile.Close() - ext, err := mimetype.DetectReader(infile) - if err != nil { - fmt.Println(err) - return nil, err - } + if hdr.Size > int64(t.MaxFileSize) { + return nil, fmt.Errorf("the uploaded file is too big, and must be less than %d", t.MaxFileSize) + } - _, err = infile.Seek(0, 0) - if err != nil { - fmt.Println(err) - return nil, err - } + buff := make([]byte, 512) + _, err = infile.Read(buff) + if err != nil { + return nil, err + } + + allowed := false + filetype := http.DetectContentType(buff) + if len(t.AllowedFileTypes) > 0 { + for _, x := range t.AllowedFileTypes { + if strings.EqualFold(filetype, x) { + allowed = true + } + } + } else { + allowed = true + } + + if !allowed { + return nil, errors.New("the uploaded file type is not permitted") + } + + _, err = infile.Seek(0, 0) + if err != nil { + fmt.Println(err) + return nil, err + } - uploadedFile.NewFileName = t.RandomString(25) + ext.Extension() - uploadedFile.OriginalFileName = hdr.Filename + if renameFile { + uploadedFile.NewFileName = fmt.Sprintf("%s%s", t.RandomString(25), filepath.Ext(hdr.Filename)) + } else { + uploadedFile.NewFileName = hdr.Filename + } + uploadedFile.OriginalFileName = hdr.Filename - var outfile *os.File - defer outfile.Close() + var outfile *os.File + defer outfile.Close() - if outfile, err = os.Create(uploadDir + uploadedFile.NewFileName); nil != err { - return nil, err - } else { + if outfile, err = os.Create(filepath.Join(uploadDir, uploadedFile.NewFileName)); nil != err { + return nil, err + } fileSize, err := io.Copy(outfile, infile) if err != nil { return nil, err } uploadedFile.FileSize = fileSize + + uploadedFiles = append(uploadedFiles, &uploadedFile) + + return uploadedFiles, nil + }(uploadedFiles) + if err != nil { + return uploadedFiles, err } } - } - return &uploadedFile, nil + return uploadedFiles, nil } // CreateDirIfNotExist creates a directory, and all necessary parent directories, if it does not exist. @@ -206,3 +364,102 @@ func (t *Tools) CreateDirIfNotExist(path string) error { } return nil } + +// Slugify is a (very) simple means of creating a slug from a provided string. +func (t *Tools) Slugify(s string) (string, error) { + if s == "" { + return "", errors.New("empty string not permitted") + } + var re = regexp.MustCompile(`[^a-z\d]+`) + slug := strings.Trim(re.ReplaceAllString(strings.ToLower(s), "-"), "-") + if len(slug) == 0 { + return "", errors.New("after removing characters, slug is zero length") + } + + return slug, nil +} + +// WriteXML takes a response status code and arbitrary data and writes an XML response to the client. +// The Content-Type header is set to application/xml. +func (t *Tools) WriteXML(w http.ResponseWriter, status int, data interface{}, headers ...http.Header) error { + out, err := xml.Marshal(data) + if err != nil { + return err + } + + // If we have a value as the last parameter in the function call, then we are setting a custom header. + if len(headers) > 0 { + for key, value := range headers[0] { + w.Header()[key] = value + } + } + + // Set the content type and send response. According to RFC 7303, text/xml and application/xml are to be + // treated as the same, so we'll just pick one. + w.Header().Set("Content-Type", "application/xml") + w.WriteHeader(status) + + // Add the XML header. + xmlOut := []byte(xml.Header + string(out)) + _, _ = w.Write(xmlOut) + + return nil +} + +// ReadXML tries to read the body of an XML request into a variable. The third parameter, data, +// is expected to be a pointer, so that we can read data into it. +func (t *Tools) ReadXML(w http.ResponseWriter, r *http.Request, data interface{}) error { + maxBytes := defaultMaxUpload + + // If MaxXMLSize is set, use that value instead of default. + if t.MaxXMLSize > 0 { + maxBytes = t.MaxXMLSize + } + r.Body = http.MaxBytesReader(w, r.Body, int64(maxBytes)) + + dec := xml.NewDecoder(r.Body) + + // Attempt to decode the data. + err := dec.Decode(data) + if err != nil { + return err + } + + err = dec.Decode(&struct{}{}) + if err != io.EOF { + return errors.New("body must only contain a single XML value") + } + + return nil +} + +// ErrorXML takes an error, and optionally a response status code, and generates and sends +// an XML error response. +func (t *Tools) ErrorXML(w http.ResponseWriter, err error, status ...int) error { + statusCode := http.StatusBadRequest + + // If a custom response code is specified, use that instead of bad request. + if len(status) > 0 { + statusCode = status[0] + } + + var payload XMLResponse + payload.Error = true + payload.Message = err.Error() + + return t.WriteXML(w, statusCode, payload) +} + +// ContainsElement checks if a value exists in a slice. +func (t *Tools) ContainsElement(val interface{}, array interface{}) bool { + arr := reflect.ValueOf(array) + if arr.Kind() != reflect.Slice { + return false + } + for i := 0; i < arr.Len(); i++ { + if reflect.DeepEqual(val, arr.Index(i).Interface()) { + return true + } + } + return false +} diff --git a/tools_test.go b/tools_test.go index 406539b..82be010 100644 --- a/tools_test.go +++ b/tools_test.go @@ -3,16 +3,17 @@ package toolbox import ( "bytes" "encoding/json" + "encoding/xml" "errors" "fmt" "image" "image/png" "io" - "io/ioutil" "mime/multipart" "net/http" "net/http/httptest" "os" + "sync" "testing" ) @@ -31,46 +32,137 @@ func NewTestClient(fn RoundTripFunc) *http.Client { } } +type testData struct { + Data any `json:"bar"` +} + +var pushTests = []struct { + name string + payload any + errorExpected bool +}{ + { + name: "valid", + payload: testData{ + Data: "bar", + }, + errorExpected: false, + }, + { + name: "invalid", + payload: make(chan int), + errorExpected: true, + }, +} + +func TestNew(t *testing.T) { + tools := New() + if tools.MaxXMLSize != defaultMaxUpload { + t.Error("wrong MaxXMLSize") + } +} + func TestTools_PushJSONToRemote(t *testing.T) { - client := NewTestClient(func(req *http.Request) *http.Response { - // Test request parameters - return &http.Response{ - StatusCode: http.StatusOK, - // Send response to be tested - Body: ioutil.NopCloser(bytes.NewBufferString(`OK`)), - // Must be set to non-nil value or it panics - Header: make(http.Header), - } - }) - - var testApp Tools - var foo struct { - Bar string `json:"bar"` - } - foo.Bar = "bar" - _, err := testApp.PushJSONToRemote(client, "http://example.com/some/path", foo) - if err != nil { - t.Error("failed to call remote url", err) + for _, e := range pushTests { + client := NewTestClient(func(req *http.Request) *http.Response { + // Test request parameters + return &http.Response{ + StatusCode: http.StatusOK, + // Send response to be tested + Body: io.NopCloser(bytes.NewBufferString(`OK`)), + // Must be set to non-nil value or it panics + Header: make(http.Header), + } + }) + + var testTools Tools + + _, _, err := testTools.PushJSONToRemote("http://example.com/some/path", e.payload, client) + if err == nil && e.errorExpected { + t.Errorf("%s: error expected, but none received", e.name) + } + + if err != nil && !e.errorExpected { + t.Errorf("%s: no error expected, but one received: %v", e.name, err) + } } } +var jsonTests = []struct { + name string + json string + errorExpected bool + maxSize int + allowUnknown bool + contentType string +}{ + {name: "good json", json: `{"foo": "bar"}`, errorExpected: false, maxSize: 1024, allowUnknown: false}, + {name: "badly formatted json", json: `{"foo":"}`, errorExpected: true, maxSize: 1024, allowUnknown: false}, + {name: "incorrect type", json: `{"foo": 1}`, errorExpected: true, maxSize: 1024, allowUnknown: false}, + {name: "incorrect type", json: `{1: 1}`, errorExpected: true, maxSize: 1024, allowUnknown: false}, + {name: "two json files", json: `{"foo": "bar"}{"alpha": "beta"}`, errorExpected: true, maxSize: 1024, allowUnknown: false}, + {name: "empty body", json: ``, errorExpected: true, maxSize: 1024, allowUnknown: false}, + {name: "syntax error in json", json: `{"foo": 1"}`, errorExpected: true, maxSize: 1024, allowUnknown: false}, + {name: "unknown field in json", json: `{"fooo": "bar"}`, errorExpected: true, maxSize: 1024, allowUnknown: false}, + {name: "incorrect type for field", json: `{"foo": 10.2}`, errorExpected: true, maxSize: 1024, allowUnknown: false}, + {name: "allow unknown field in json", json: `{"fooo": "bar"}`, errorExpected: false, maxSize: 1024, allowUnknown: true}, + {name: "missing field name", json: `{jack: "bar"}`, errorExpected: true, maxSize: 1024, allowUnknown: false}, + {name: "file too large", json: `{"foo": "bar"}`, errorExpected: true, maxSize: 5, allowUnknown: false}, + {name: "not json", json: `Hello, world`, errorExpected: true, maxSize: 1024, allowUnknown: false}, + {name: "wrong header", json: `{"foo": "bar"}`, errorExpected: true, maxSize: 1024, allowUnknown: false, contentType: "application/xml"}, +} + func TestTools_ReadJSON(t *testing.T) { - var testApp Tools - testApp.MaxFileSize = 1048576 * 2 + for _, e := range jsonTests { + var testTools Tools + // set max file size + testTools.MaxJSONSize = e.maxSize - // create a sample JSON file and add it to body - sampleJSON := map[string]interface{}{ - "foo": "bar", - } - body, _ := json.Marshal(sampleJSON) + // allow/disallow unknown fields. + testTools.AllowUnknownFields = e.allowUnknown + + // declare a variable to read the decoded json into. + var decodedJSON struct { + Foo string `json:"foo"` + } + + // create a request with the body. + req, err := http.NewRequest("POST", "/", bytes.NewReader([]byte(e.json))) + if err != nil { + t.Log("Error", err) + } + if e.contentType != "" { + req.Header.Add("Content-Type", e.contentType) + } else { + req.Header.Add("Content-Type", "application/json") + } + + // create a test response recorder, which satisfies the requirements + // for a ResponseWriter. + rr := httptest.NewRecorder() - // declare a variable to read the decoded json into - var decodedJSON struct { - Foo string `json:"foo"` + // call ReadJSON and check for an error. + err = testTools.ReadJSON(rr, req, &decodedJSON) + + // if we expect an error, but do not get one, something went wrong. + if e.errorExpected && err == nil { + t.Errorf("%s: error expected, but none received", e.name) + } + + // if we do not expect an error, but get one, something went wrong. + if !e.errorExpected && err != nil { + t.Errorf("%s: error not expected, but one received: %s \n%s", e.name, err.Error(), e.json) + } + req.Body.Close() } +} + +func TestTools_ReadJSONAndMarshal(t *testing.T) { + // set max file size + var testTools Tools // create a request with the body - req, err := http.NewRequest("POST", "/", bytes.NewReader(body)) + req, err := http.NewRequest("POST", "/", bytes.NewReader([]byte(`{"foo": "bar"}`))) if err != nil { t.Log("Error", err) } @@ -78,57 +170,63 @@ func TestTools_ReadJSON(t *testing.T) { // create a test response recorder, which satisfies the requirements // for a ResponseWriter rr := httptest.NewRecorder() - defer req.Body.Close() - // call readJSON and check for an error - err = testApp.ReadJSON(rr, req, &decodedJSON) - if err != nil { - t.Error("failed to decode json", err) - } - - // create json with two json entries - badJSON := ` - { - "foo": "bar" - } - { - "alpha": "beta" - }` + // call readJSON and check for an error; since we are using nil for the final parameter, + // we should get an error + err = testTools.ReadJSON(rr, req, nil) - // create a request with the body - req, err = http.NewRequest("POST", "/", bytes.NewReader([]byte(badJSON))) - if err != nil { - t.Log("Error", err) - } - - err = testApp.ReadJSON(rr, req, &decodedJSON) + // we expect an error, but did not get one, so something went wrong if err == nil { - t.Error("did not get an error with bad json") + t.Error("error expected, but none received") } + + req.Body.Close() +} + +var writeJSONTests = []struct { + name string + payload any + errorExpected bool +}{ + { + name: "valid", + payload: JSONResponse{ + Error: false, + Message: "foo", + }, + errorExpected: false, + }, + { + name: "invalid", + payload: make(chan int), + errorExpected: true, + }, } func TestTools_WriteJSON(t *testing.T) { - var testApp Tools + for _, e := range writeJSONTests { + // create a variable of type toolbox.Tools, and just use the defaults. + var testTools Tools - rr := httptest.NewRecorder() - payload := JSONResponse{ - Error: false, - Message: "foo", - } + rr := httptest.NewRecorder() - headers := make(http.Header) - headers.Add("FOO", "BAR") - err := testApp.WriteJSON(rr, http.StatusOK, payload, headers) - if err != nil { - t.Errorf("failed to write JSON: %v", err) + headers := make(http.Header) + headers.Add("FOO", "BAR") + err := testTools.WriteJSON(rr, http.StatusOK, e.payload, headers) + if err == nil && e.errorExpected { + t.Errorf("%s: expected error, but did not get one", e.name) + } + if err != nil && !e.errorExpected { + t.Errorf("%s: did not expect error, but got one: %v", e.name, err) + } } } func TestTools_ErrorJSON(t *testing.T) { - var testApp Tools + var testTools Tools rr := httptest.NewRecorder() - err := testApp.ErrorJSON(rr, errors.New("some error")) + err := testTools.ErrorJSON(rr, errors.New("some error"), http.StatusServiceUnavailable) if err != nil { t.Error(err) } @@ -144,29 +242,27 @@ func TestTools_ErrorJSON(t *testing.T) { t.Error("error set to false in response from ErrorJSON, and should be set to true") } - // test with status - err = testApp.ErrorJSON(rr, errors.New("another error"), http.StatusServiceUnavailable) - if err != nil { - t.Error(err) + if rr.Code != http.StatusServiceUnavailable { + t.Errorf("wrong status code returned; expected 503, but got %d", rr.Code) } } func TestTools_RandomString(t *testing.T) { - var testApp Tools + var testTools Tools - s := testApp.RandomString(10) + s := testTools.RandomString(10) if len(s) != 10 { t.Error("wrong length random string returned") } } -func TestTools_DownloadStaticFile(t *testing.T) { +func TestTools_DownloadLargeStaticFile(t *testing.T) { rr := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/", nil) - var testApp Tools + var testTools Tools - testApp.DownloadStaticFile(rr, req, "./testdata", "tgg.jpg", "gatsby.jpg") + testTools.DownloadStaticFile(rr, req, "./testdata", "tgg.jpg", "gatsby.jpg") res := rr.Result() defer res.Body.Close() @@ -179,58 +275,160 @@ func TestTools_DownloadStaticFile(t *testing.T) { t.Error("wrong content disposition of", res.Header["Content-Disposition"][0]) } - _, err := ioutil.ReadAll(res.Body) + _, err := io.ReadAll(res.Body) if err != nil { t.Error(err) } } -func TestTools_UploadOneFile(t *testing.T) { - // set up a pipe to avoid buffering - pr, pw := io.Pipe() - writer := multipart.NewWriter(pw) - - go func() { - defer writer.Close() - // create the form data field 'fileupload' - part, err := writer.CreateFormFile("file", "./testdata/img.png") - if err != nil { - t.Error(err) - } +var uploadTests = []struct { + name string + allowedTypes []string + renameFile bool + errorExpected bool + maxSize int + uploadDir string +}{ + {name: "allowed no rename", allowedTypes: []string{"image/jpeg", "image/png"}, renameFile: false, errorExpected: false, maxSize: 0, uploadDir: ""}, + {name: "allowed rename", allowedTypes: []string{"image/jpeg", "image/png"}, renameFile: true, errorExpected: false, maxSize: 0, uploadDir: ""}, + {name: "allowed no filetype specified", allowedTypes: []string{}, renameFile: true, errorExpected: false, maxSize: 0, uploadDir: ""}, + {name: "not allowed", allowedTypes: []string{"image/jpeg"}, errorExpected: true, maxSize: 0, uploadDir: ""}, + {name: "too big", allowedTypes: []string{"image/jpeg,", "image/png"}, errorExpected: true, maxSize: 10, uploadDir: ""}, + {name: "invalid directory", allowedTypes: []string{"image/jpeg,", "image/png"}, errorExpected: true, maxSize: 0, uploadDir: "//"}, +} - f, err := os.Open("./testdata/img.png") - if err != nil { - t.Error(err) +func TestTools_UploadFiles(t *testing.T) { + for _, e := range uploadTests { + // set up a pipe to avoid buffering + pr, pw := io.Pipe() + writer := multipart.NewWriter(pw) + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer writer.Close() + defer wg.Done() + + // create the form data field 'file' + part, err := writer.CreateFormFile("file", "./testdata/img.png") + if err != nil { + t.Error(err) + } + + f, err := os.Open("./testdata/img.png") + if err != nil { + t.Error(err) + } + defer f.Close() + img, _, err := image.Decode(f) + if err != nil { + t.Error("error decoding image", err) + } + + err = png.Encode(part, img) + if err != nil { + t.Error(err) + } + }() + + // read from the pipe which receives data + request := httptest.NewRequest("POST", "/", pr) + request.Header.Add("Content-Type", writer.FormDataContentType()) + + var testTools Tools + testTools.AllowedFileTypes = e.allowedTypes + if e.maxSize > 0 { + testTools.MaxFileSize = e.maxSize } - defer f.Close() - img, _, err := image.Decode(f) - if err != nil { - t.Error("error decoding image", err) + + var uploadDir = "./testdata/uploads/" + if e.uploadDir != "" { + uploadDir = e.uploadDir } - err = png.Encode(part, img) - if err != nil { + uploadedFiles, err := testTools.UploadFiles(request, uploadDir, e.renameFile) + if err != nil && !e.errorExpected { t.Error(err) } - }() - // read from the pipe which receives data - request := httptest.NewRequest("POST", "/", pr) - request.Header.Add("Content-Type", writer.FormDataContentType()) + if !e.errorExpected { + if _, err := os.Stat(fmt.Sprintf("./testdata/uploads/%s", uploadedFiles[0].NewFileName)); os.IsNotExist(err) { + t.Errorf("%s: expected file to exist: %s", e.name, err.Error()) + } - var testTools Tools + // clean up + _ = os.Remove(fmt.Sprintf("./testdata/uploads/%s", uploadedFiles[0].NewFileName)) + } - uploadedFile, err := testTools.UploadOneFile(request, "./testdata/uploads/") - if err != nil { - t.Error(err) - } + if e.errorExpected && err == nil { + t.Errorf("%s: error expected, but none received", e.name) + } - if _, err := os.Stat(fmt.Sprintf("./testdata/uploads/%s", uploadedFile.NewFileName)); os.IsNotExist(err) { - t.Error("Expected file to exist", err) + // we're running table tests, so have to use a waitgroup + wg.Wait() } +} + +var uploadOneTests = []struct { + name string + uploadDir string + errorExpected bool +}{ + {name: "valid", uploadDir: "./testdata/uploads/", errorExpected: false}, + {name: "invalid", uploadDir: "//", errorExpected: true}, +} + +func TestTools_UploadOneFile(t *testing.T) { + for _, e := range uploadOneTests { + // set up a pipe to avoid buffering + pr, pw := io.Pipe() + writer := multipart.NewWriter(pw) + + go func() { + defer writer.Close() + + // create the form data field 'file' + part, err := writer.CreateFormFile("file", "./testdata/img.png") + if err != nil { + t.Error(err) + } + + f, err := os.Open("./testdata/img.png") + if err != nil { + t.Error(err) + } + defer f.Close() + img, _, err := image.Decode(f) + if err != nil { + t.Error("error decoding image", err) + } + + err = png.Encode(part, img) + if err != nil { + t.Error(err) + } + }() + + // read from the pipe which receives data + request := httptest.NewRequest("POST", "/", pr) + request.Header.Add("Content-Type", writer.FormDataContentType()) + + var testTools Tools + testTools.AllowedFileTypes = []string{"image/png"} + + uploadedFiles, err := testTools.UploadOneFile(request, e.uploadDir, true) + if e.errorExpected && err == nil { + t.Errorf("%s: error expected, but none received", e.name) + } + + if !e.errorExpected { + if _, err := os.Stat(fmt.Sprintf("./testdata/uploads/%s", uploadedFiles.NewFileName)); os.IsNotExist(err) { + t.Errorf("%s: expected file to exist: %s", e.name, err.Error()) + } - // clean up - _ = os.Remove(fmt.Sprintf("./testdata/uploads/%s", uploadedFile.NewFileName)) + // clean up + _ = os.Remove(fmt.Sprintf("./testdata/uploads/%s", uploadedFiles.NewFileName)) + } + } } func TestTools_CreateDirIfNotExist(t *testing.T) { @@ -248,3 +446,197 @@ func TestTools_CreateDirIfNotExist(t *testing.T) { _ = os.Remove("./testdata/myDir") } + +func TestTools_CreateDirIfNotExistInvalidDirectory(t *testing.T) { + var testTool Tools + + // we should not be able to create a directory at the root level (no permissions) + err := testTool.CreateDirIfNotExist("/mydir") + if err == nil { + t.Error(errors.New("able to create a directory where we should not be able to")) + } +} + +var slugTests = []struct { + name string + s string + expected string + errorExpected bool +}{ + {name: "valid string", s: "now is the time", expected: "now-is-the-time", errorExpected: false}, + {name: "empty string", s: "", expected: "", errorExpected: true}, + {name: "complex string", s: "Now is the time for all GOOD men! + Fish & such &^?123", expected: "now-is-the-time-for-all-good-men-fish-such-123", errorExpected: false}, + {name: "japanese string", s: "こんにちは世界", expected: "", errorExpected: true}, + {name: "japanese string plus roman characters", s: "こんにちは世界 hello world", expected: "hello-world", errorExpected: false}, +} + +func TestTools_Slugify(t *testing.T) { + var testTool Tools + + for _, e := range slugTests { + slug, err := testTool.Slugify(e.s) + if err != nil && !e.errorExpected { + t.Errorf("%s: error received when none expected: %s", e.name, err.Error()) + } + + if !e.errorExpected && slug != e.expected { + t.Errorf("%s: wrong slug returned; expected %s but got %s", e.name, e.expected, slug) + } + } +} + +var writeXMLTests = []struct { + name string + payload any + errorExpected bool +}{ + { + name: "valid", + payload: XMLResponse{ + Error: false, + Message: "foo", + }, + errorExpected: false, + }, + { + name: "invalid", + payload: make(chan int), + errorExpected: true, + }, +} + +func TestTools_WriteXML(t *testing.T) { + for _, e := range writeXMLTests { + // create a variable of type toolbox.Tools, and just use the defaults. + var testTools Tools + + rr := httptest.NewRecorder() + + headers := make(http.Header) + headers.Add("FOO", "BAR") + err := testTools.WriteXML(rr, http.StatusOK, e.payload, headers) + if err != nil && !e.errorExpected { + t.Errorf("%s, failed to write XML: %v", e.name, err) + } + + if err == nil && e.errorExpected { + t.Errorf("%s: error expected, but none received", e.name) + } + } +} + +var xmlTests = []struct { + name string + xml string + maxBytes int + errorExpected bool +}{ + { + name: "good xml", + xml: `John SmithJane Jones`, + errorExpected: false, + }, + { + name: "badly formatted xml", + xml: `John SmithJane Jones`, + errorExpected: true, + }, + { + name: "too big", + xml: `John SmithJane Jones`, + maxBytes: 10, + errorExpected: true, + }, + { + name: "double xml", + xml: `John SmithJane Jones + Luke SkywalkerR2D2`, + errorExpected: true, + }, +} + +func TestTools_ReadXML(t *testing.T) { + for _, e := range xmlTests { + // create a variable of type toolbox.Tools, and just use the defaults. + var tools Tools + + if e.maxBytes != 0 { + tools.MaxXMLSize = e.maxBytes + } + // create a request with the body. + req, err := http.NewRequest("POST", "/", bytes.NewReader([]byte(e.xml))) + if err != nil { + t.Log("Error", err) + } + + // create a test response recorder, which satisfies the requirements + // for a ResponseWriter. + rr := httptest.NewRecorder() + + // call ReadXML and check for an error. + var note struct { + To string `xml:"to"` + From string `xml:"from"` + } + + err = tools.ReadXML(rr, req, ¬e) + if e.errorExpected && err == nil { + t.Errorf("%s: expected an error, but did not get one", e.name) + } else if !e.errorExpected && err != nil { + t.Errorf("%s: did not expect an error, but got one: %s", e.name, err) + } + } +} + +func TestTools_ErrorXML(t *testing.T) { + var testTools Tools + + rr := httptest.NewRecorder() + err := testTools.ErrorXML(rr, errors.New("some error"), http.StatusServiceUnavailable) + if err != nil { + t.Error(err) + } + + var requestPayload XMLResponse + decoder := xml.NewDecoder(rr.Body) + err = decoder.Decode(&requestPayload) + if err != nil { + t.Error("received error when decoding ErrorXML payload:", err) + } + + if !requestPayload.Error { + t.Error("error set to false in response from ErrorXML, and should be set to true") + } + + if rr.Code != http.StatusServiceUnavailable { + t.Errorf("wrong status code returned; expected 503, but got %d", rr.Code) + } +} + +func TestTools_InArray(t *testing.T) { + var testTools Tools + if result := testTools.ContainsElement(55, []int{23, 45, 46, 68}); result != false { + t.Errorf("ContainsElement() result: %v, expect: %v", result, false) + } + if result := testTools.ContainsElement(45, []int{23, 45, 46, 68}); result != true { + t.Errorf("ContainsElement() result: %v, expect: %v", result, true) + } + if result := testTools.ContainsElement(45, []string{"abc", "def"}); result != false { + t.Errorf("ContainsElement() result: %v, expect: %v", result, false) + } + if result := testTools.ContainsElement("def", []string{"abc", "def"}); result != true { + t.Errorf("ContainsElement() result: %v, expect: %v", result, true) + } + type test struct { + name string + } + tests := []test{{name: "abc"}, {name: "def"}} + t1 := test{name: "def"} + t2 := test{name: "xyz"} + if result := testTools.ContainsElement(t1, tests); result != true { + t.Errorf("ContainsElement() result: %v, expect: %v", result, true) + } + if result := testTools.ContainsElement(t2, tests); result != false { + t.Errorf("ContainsElement() result: %v, expect: %v", result, false) + } +}