diff --git a/traffic_ops/traffic_ops_golang/api/errors.go b/traffic_ops/traffic_ops_golang/api/errors.go index 22e2aef14a..2c275087af 100644 --- a/traffic_ops/traffic_ops_golang/api/errors.go +++ b/traffic_ops/traffic_ops_golang/api/errors.go @@ -23,6 +23,8 @@ import ( "errors" "fmt" "net/http" + + "github.com/apache/trafficcontrol/v8/lib/go-util" ) // Errs is the concrete implementation of Errors, which is used so that we can @@ -221,7 +223,7 @@ func NewSystemErrorf(format string, args ...any) Errors { // and has the appropriate response code. func NewUserError(err error) Errors { return &Errs{ - code: http.StatusInternalServerError, + code: http.StatusBadRequest, systemError: nil, userError: err, } @@ -231,7 +233,7 @@ func NewUserError(err error) Errors { // having the appropriate response code and containing the given message. func NewUserErrorString(err string) Errors { return &Errs{ - code: http.StatusInternalServerError, + code: http.StatusBadRequest, systemError: nil, userError: errors.New(err), } @@ -243,7 +245,7 @@ func NewUserErrorString(err string) Errors { // wrapping). func NewUserErrorf(format string, args ...any) Errors { return &Errs{ - code: http.StatusInternalServerError, + code: http.StatusBadRequest, systemError: nil, userError: fmt.Errorf(format, args...), } @@ -258,3 +260,25 @@ func NewResourceModifiedError() Errors { userError: ResourceModifiedError, } } + +// NewUserErrorFromErrorList creates a new user-facing error (400 Bad Request) +// by concatenating the given list of errors. Uniquely, this can return nil +// if the passed slice is empty (or nil). +func NewUserErrorFromErrorList(errs []error) Errors { + err := util.JoinErrs(errs) + if err == nil { + return nil + } + return NewUserError(err) +} + +// NewNotFoundError creates an Errors that contains an HTTP Not Found status +// code and the given error message as a user-visible error (supports wrapping +// with the '%w' format specifier verb). +func NewNotFoundError(format string, args ...any) Errors { + return &Errs{ + code: http.StatusNotFound, + systemError: nil, + userError: fmt.Errorf(format, args...), + } +} diff --git a/traffic_ops/traffic_ops_golang/api/errors_test.go b/traffic_ops/traffic_ops_golang/api/errors_test.go index f2fff8a410..ec1a1606ac 100644 --- a/traffic_ops/traffic_ops_golang/api/errors_test.go +++ b/traffic_ops/traffic_ops_golang/api/errors_test.go @@ -160,17 +160,32 @@ func ExampleNewSystemErrorf() { } func ExampleNewUserError() { fmt.Println(NewUserError(errors.New("testquest")).String()) - // Output: 500 Internal Server Error, SystemError='', UserError='testquest' + // Output: 400 Bad Request, SystemError='', UserError='testquest' } func ExampleNewUserErrorString() { fmt.Println(NewUserErrorString("testquest").String()) - // Output: 500 Internal Server Error, SystemError='', UserError='testquest' + // Output: 400 Bad Request, SystemError='', UserError='testquest' } func ExampleNewUserErrorf() { fmt.Println(NewUserErrorf("test: %w", errors.New("quest")).String()) - // Output: 500 Internal Server Error, SystemError='', UserError='test: quest' + // Output: 400 Bad Request, SystemError='', UserError='test: quest' } func ExampleNewResourceModifiedError() { fmt.Println(NewResourceModifiedError().String()) // Output: 412 Precondition Failed, SystemError='', UserError='resource was modified since the time specified by the request headers' } +func ExampleNewUserErrorFromErrorList() { + errs := []error{} + fmt.Println(NewUserErrorFromErrorList(errs)) + + errs = append(errs, errors.New("Test")) + errs = append(errs, errors.New("Quest")) + fmt.Println(NewUserErrorFromErrorList(errs)) + + // Output: + // Test, Quest +} +func ExampleNewNotFoundError() { + fmt.Println(NewNotFoundError("test: %s", "quest").String()) + // Output: 404 Not Found, SystemError='', UserError='test: quest' +} diff --git a/traffic_ops/traffic_ops_golang/api/info.go b/traffic_ops/traffic_ops_golang/api/info.go index f6462c672d..bd224b1c44 100644 --- a/traffic_ops/traffic_ops_golang/api/info.go +++ b/traffic_ops/traffic_ops_golang/api/info.go @@ -257,9 +257,9 @@ func (inf *Info) Close() { // // This CANNOT be used by any Info that wasn't constructed for the caller by // Wrap - ing a Handler (yet). -func (inf Info) WriteOKResponse(resp any) (int, error, error) { +func (inf Info) WriteOKResponse(resp any) error { WriteResp(inf.w, inf.request, resp) - return http.StatusOK, nil, nil + return nil } // WriteOKResponseWithSummary writes a 200 OK response with the given object as @@ -272,9 +272,9 @@ func (inf Info) WriteOKResponse(resp any) (int, error, error) { // Deprecated: Summary sections on responses were intended to cover up for a // deficiency in jQuery-based tables on the front-end, so now that we aren't // using those anymore it serves no purpose. -func (inf Info) WriteOKResponseWithSummary(resp any, count uint64) (int, error, error) { +func (inf Info) WriteOKResponseWithSummary(resp any, count uint64) error { WriteRespWithSummary(inf.w, inf.request, resp, count) - return http.StatusOK, nil, nil + return nil } // WriteNotModifiedResponse writes a 304 Not Modified response with the given @@ -282,19 +282,19 @@ func (inf Info) WriteOKResponseWithSummary(resp any, count uint64) (int, error, // // This CANNOT be used by any Info that wasn't constructed for the caller by // Wrap - ing a Handler (yet). -func (inf Info) WriteNotModifiedResponse(lastModified time.Time) (int, error, error) { +func (inf Info) WriteNotModifiedResponse(lastModified time.Time) error { inf.w.Header().Set(rfc.LastModified, FormatLastModified(lastModified)) inf.w.WriteHeader(http.StatusNotModified) setRespWritten(inf.request) - return http.StatusNotModified, nil, nil + return nil } // WriteSuccessResponse writes the given response object as the `response` // property of the response body, with the accompanying message as a // success-level Alert. -func (inf Info) WriteSuccessResponse(resp any, message string) (int, error, error) { +func (inf Info) WriteSuccessResponse(resp any, message string) error { WriteAlertsObj(inf.w, inf.request, http.StatusOK, tc.CreateAlerts(tc.SuccessLevel, message), resp) - return http.StatusOK, nil, nil + return nil } // WriteCreatedResponse writes the given response object as the `response` @@ -302,11 +302,11 @@ func (inf Info) WriteSuccessResponse(resp any, message string) (int, error, erro // accompanying message as a success-level Alert. It also sets the Location // header to the given path. This will be automatically prefaced with the // correct path to the API version the client requested. -func (inf Info) WriteCreatedResponse(resp any, message, path string) (int, error, error) { +func (inf Info) WriteCreatedResponse(resp any, message, path string) error { inf.w.Header().Set(rfc.Location, strings.Join([]string{"/api", inf.Version.String(), strings.TrimPrefix(path, "/")}, "/")) inf.w.WriteHeader(http.StatusCreated) WriteAlertsObj(inf.w, inf.request, http.StatusCreated, tc.CreateAlerts(tc.SuccessLevel, message), resp) - return http.StatusCreated, nil, nil + return nil } // RequestHeaders returns the headers sent by the client in the API request. @@ -427,3 +427,18 @@ func (inf Info) DefaultSort(param string) { inf.Params["orderby"] = param } } + +// HandleErrors handles errors that occur during handling API operations - as +// represented by an appropriate Errors. +func (inf Info) HandleErrors(errs Errors) { + HandleErr(inf.w, inf.request, inf.Tx.Tx, errs.Code(), errs.UserError(), errs.SystemError()) +} + +// HandleDBError handles errors from database actions. This is identical to +// ParseDBError followed by handling what that returns, but does the +// intermediary step for you. +func (inf Info) HandleDBError(err error) error { + userErr, sysErr, code := ParseDBError(err) + HandleErr(inf.w, inf.request, inf.Tx.Tx, code, userErr, sysErr) + return nil +} diff --git a/traffic_ops/traffic_ops_golang/api/info_test.go b/traffic_ops/traffic_ops_golang/api/info_test.go index dafcd4b034..b3d1f9d629 100644 --- a/traffic_ops/traffic_ops_golang/api/info_test.go +++ b/traffic_ops/traffic_ops_golang/api/info_test.go @@ -330,15 +330,9 @@ func TestInfo_WriteOKResponse(t *testing.T) { request: r, w: w, } - code, userErr, sysErr := inf.WriteOKResponse("test") - if code != http.StatusOK { - t.Errorf("WriteOKResponse should return a %d %s code, got: %d %s", http.StatusOK, http.StatusText(http.StatusOK), code, http.StatusText(code)) - } - if userErr != nil { - t.Errorf("Unexpected user error: %v", userErr) - } - if sysErr != nil { - t.Errorf("Unexpected system error: %v", sysErr) + err := inf.WriteOKResponse("test") + if err != nil { + t.Errorf("Unexpected error: %v", err) } if w.Code != http.StatusOK { @@ -354,15 +348,9 @@ func TestInfo_WriteOKResponseWithSummary(t *testing.T) { request: r, w: w, } - code, userErr, sysErr := inf.WriteOKResponseWithSummary("test", 42) - if code != http.StatusOK { - t.Errorf("WriteOKResponseWithSummary should return a %d %s code, got: %d %s", http.StatusOK, http.StatusText(http.StatusOK), code, http.StatusText(code)) - } - if userErr != nil { - t.Errorf("Unexpected user error: %v", userErr) - } - if sysErr != nil { - t.Errorf("Unexpected system error: %v", sysErr) + err := inf.WriteOKResponseWithSummary("test", 42) + if err != nil { + t.Errorf("Unexpected error: %v", err) } if w.Code != http.StatusOK { @@ -378,15 +366,9 @@ func TestInfo_WriteNotModifiedResponse(t *testing.T) { request: r, w: w, } - code, userErr, sysErr := inf.WriteNotModifiedResponse(time.Time{}) - if code != http.StatusNotModified { - t.Errorf("WriteNotModifiedResponse should return a %d %s code, got: %d %s", http.StatusNotModified, http.StatusText(http.StatusNotModified), code, http.StatusText(code)) - } - if userErr != nil { - t.Errorf("Unexpected user error: %v", userErr) - } - if sysErr != nil { - t.Errorf("Unexpected system error: %v", sysErr) + err := inf.WriteNotModifiedResponse(time.Time{}) + if err != nil { + t.Errorf("Unexpected error: %v", err) } if w.Code != http.StatusNotModified { @@ -402,15 +384,9 @@ func TestInfo_WriteSuccessResponse(t *testing.T) { request: r, w: w, } - code, userErr, sysErr := inf.WriteSuccessResponse("test", "quest") - if code != http.StatusOK { - t.Errorf("WriteSuccessResponse should return a %d %s code, got: %d %s", http.StatusOK, http.StatusText(http.StatusOK), code, http.StatusText(code)) - } - if userErr != nil { - t.Errorf("Unexpected user error: %v", userErr) - } - if sysErr != nil { - t.Errorf("Unexpected system error: %v", sysErr) + err := inf.WriteSuccessResponse("test", "quest") + if err != nil { + t.Errorf("Unexpected error: %v", err) } if w.Code != http.StatusOK { @@ -442,15 +418,9 @@ func TestInfo_WriteCreatedResponse(t *testing.T) { Version: &Version{Major: 420, Minor: 9001}, w: w, } - code, userErr, sysErr := inf.WriteCreatedResponse("test", "quest", "mypath") - if code != http.StatusCreated { - t.Errorf("WriteCreatedResponse should return a %d %s code, got: %d %s", http.StatusCreated, http.StatusText(http.StatusCreated), code, http.StatusText(code)) - } - if userErr != nil { - t.Errorf("Unexpected user error: %v", userErr) - } - if sysErr != nil { - t.Errorf("Unexpected system error: %v", sysErr) + err := inf.WriteCreatedResponse("test", "quest", "mypath") + if err != nil { + t.Errorf("Unexpected error: %v", err) } if w.Code != http.StatusCreated { @@ -802,3 +772,70 @@ func ExampleInfo_DefaultSort() { // Output: testquest // testquest } + +func TestInfo_HandleErrors(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/", nil) + + inf := Info{ + request: r, + w: w, + Tx: &sqlx.Tx{}, + } + + code := http.StatusFailedDependency + inf.HandleErrors(NewErrors(code, errors.New("test"), errors.New("quest"))) + + if w.Code != code { + t.Errorf("incorrect response status code; want: %d, got: %d", code, w.Code) + } + + var alerts tc.Alerts + if err := json.NewDecoder(w.Body).Decode(&alerts); err != nil { + t.Fatalf("couldn't decode response body: %v", err) + } + + if len(alerts.Alerts) != 1 { + t.Fatalf("expected exactly one alert; got: %d", len(alerts.Alerts)) + } + alert := alerts.Alerts[0] + if alert.Level != tc.ErrorLevel.String() { + t.Errorf("Incorrect alert level; want: %s, got: %s", tc.ErrorLevel, alert.Level) + } + if alert.Text != "test" { + t.Errorf("Incorrect alert text; want: 'test', got: '%s'", alert.Text) + } +} + +func TestInfo_HandleDBError(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodPost, "/", nil) + + inf := Info{ + request: r, + w: w, + Tx: &sqlx.Tx{}, + } + + inf.HandleDBError(errors.New("a non-parsable error")) + + if code := http.StatusInternalServerError; w.Code != code { + t.Errorf("incorrect response status code; want: %d, got: %d", code, w.Code) + } + + var alerts tc.Alerts + if err := json.NewDecoder(w.Body).Decode(&alerts); err != nil { + t.Fatalf("couldn't decode response body: %v", err) + } + + if len(alerts.Alerts) != 1 { + t.Fatalf("expected exactly one alert; got: %d", len(alerts.Alerts)) + } + alert := alerts.Alerts[0] + if alert.Level != tc.ErrorLevel.String() { + t.Errorf("Incorrect alert level; want: %s, got: %s", tc.ErrorLevel, alert.Level) + } + if expected := http.StatusText(http.StatusInternalServerError); alert.Text != expected { + t.Errorf("Incorrect alert text; want: '%s', got: '%s'", expected, alert.Text) + } +} diff --git a/traffic_ops/traffic_ops_golang/api/shared_handlers.go b/traffic_ops/traffic_ops_golang/api/shared_handlers.go index 44ed2e57f0..9321ad13c3 100644 --- a/traffic_ops/traffic_ops_golang/api/shared_handlers.go +++ b/traffic_ops/traffic_ops_golang/api/shared_handlers.go @@ -679,10 +679,13 @@ func parseMultipleCreates(data []byte, desiredType reflect.Type, inf *Info) ([]C } // A Handler is an API endpoint handlers. They take in Info helper objects and -// return - in order - an HTTP response status code, a user-facing error (if one -// occurred), and a system-only error not safe for exposure to clients (if one -// occurred). -type Handler = func(*Info) (int, error, error) +// return errors that occur in handling. If no error occurs, the response MUST +// have been written. It is strongly recommended that the returned error be an +// Errors type as defined by this package, as that will allow for the most +// versatile handling. Any other kind of error returned by a Handler is treated +// as a "system-only" error, which returns no information to the user beyond +// that the request failed. +type Handler = func(*Info) error // Wrap wraps an API endpoint handler in the more generic HTTP request handler // type from the http package. This constructs and provides the Info for the @@ -709,9 +712,16 @@ func Wrap(h Handler, requiredParams, intParams []string) http.HandlerFunc { } inf.w = w - errCode, userErr, sysErr = h(inf) - if userErr != nil || sysErr != nil { - HandleErr(w, r, inf.Tx.Tx, errCode, userErr, sysErr) + err := h(inf) + if err == nil { + return + } + + var apiErr Errors + if errors.As(err, &apiErr) { + inf.HandleErrors(apiErr) + } else { + HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, err) } } } diff --git a/traffic_ops/traffic_ops_golang/api/shared_handlers_test.go b/traffic_ops/traffic_ops_golang/api/shared_handlers_test.go index 98bc2f0d97..d58f4e52fa 100644 --- a/traffic_ops/traffic_ops_golang/api/shared_handlers_test.go +++ b/traffic_ops/traffic_ops_golang/api/shared_handlers_test.go @@ -351,11 +351,11 @@ func TestDeleteHandler(t *testing.T) { // The constructed handler will return an error if fail is true, or nothing // special otherwise. func testingHandler(fail bool) Handler { - return func(inf *Info) (int, error, error) { + return func(inf *Info) error { if fail { - return http.StatusBadRequest, errors.New("testing user error"), errors.New("testing system error") + return NewErrors(http.StatusBadRequest, errors.New("testing user error"), errors.New("testing system error")) } - return http.StatusOK, nil, nil + return nil } } @@ -416,3 +416,42 @@ func TestWrap(t *testing.T) { t.Errorf("not all expectations were met: %v", err) } } + +func TestWrap_HandleSimpleErrors(t *testing.T) { + var handler Handler = func(inf *Info) error { + return errors.New("testquest") + } + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to open a stub database connection: %v", err) + } + defer func() { + mock.ExpectClose() + err := db.Close() + if err != nil { + t.Errorf("failed to close database: %v", err) + } + err = mock.ExpectationsWereMet() + if err != nil { + t.Errorf("expectations unmet: %v", err) + } + }() + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + r = wrapContext(r, ConfigContextKey, &config.Config{ConfigTrafficOpsGolang: config.ConfigTrafficOpsGolang{DBQueryTimeoutSeconds: 1000}}) + r = wrapContext(r, DBContextKey, &sqlx.DB{DB: db}) + r = wrapContext(r, TrafficVaultContextKey, &disabled.Disabled{}) + r = wrapContext(r, ReqIDContextKey, uint64(0)) + r = wrapContext(r, auth.CurrentUserKey, auth.CurrentUser{}) + r = wrapContext(r, PathParamsKey, make(map[string]string)) + mock.ExpectBegin() + mock.ExpectRollback() + h := Wrap(handler, nil, nil) + h(w, r) + if w.Code != http.StatusInternalServerError { + t.Errorf("wrong status code when the trivial handler is used without an API version; want: %d, got: %d", http.StatusInternalServerError, w.Code) + } + +} diff --git a/traffic_ops/traffic_ops_golang/cdnfederation/cdnfederations.go b/traffic_ops/traffic_ops_golang/cdnfederation/cdnfederations.go index 8cf97086ae..4c4975c019 100644 --- a/traffic_ops/traffic_ops_golang/cdnfederation/cdnfederations.go +++ b/traffic_ops/traffic_ops_golang/cdnfederation/cdnfederations.go @@ -499,17 +499,17 @@ func addTenancyStmt(where string) string { return where } -func getCDNFederations(inf *api.Info) ([]tc.CDNFederationV5, time.Time, int, error, error) { +func getCDNFederations(inf *api.Info) ([]tc.CDNFederationV5, time.Time, api.Errors) { tenantList, err := tenant.GetUserTenantIDListTx(inf.Tx.Tx, inf.User.TenantID) if err != nil { - return nil, time.Time{}, http.StatusInternalServerError, nil, fmt.Errorf("getting tenant list for user: %w", err) + return nil, time.Time{}, api.NewSystemErrorf("getting tenant list for user: %w", err) } cols := paramColumnInfo(*inf.Version) where, orderBy, pagination, queryValues, errs := dbhelpers.BuildWhereAndOrderByAndPagination(inf.Params, cols) if len(errs) > 0 { - return nil, time.Time{}, http.StatusBadRequest, util.JoinErrs(errs), nil + return nil, time.Time{}, api.NewUserErrorFromErrorList(errs) } queryValues["tenantIDs"] = pq.Array(tenantList) @@ -520,7 +520,7 @@ func getCDNFederations(inf *api.Info) ([]tc.CDNFederationV5, time.Time, int, err cont, max := ims.TryIfModifiedSinceQuery(inf.Tx, inf.RequestHeaders(), queryValues, query) if !cont { log.Debugln("IMS HIT") - return nil, max, http.StatusNotModified, nil, nil + return nil, max, api.NewErrors(http.StatusNotModified, nil, nil) } log.Debugln("IMS MISS") } else { @@ -531,7 +531,7 @@ func getCDNFederations(inf *api.Info) ([]tc.CDNFederationV5, time.Time, int, err rows, err := inf.Tx.NamedQuery(query, queryValues) if err != nil { userErr, sysErr, code := api.ParseDBError(err) - return nil, time.Time{}, code, userErr, sysErr + return nil, time.Time{}, api.NewErrors(code, userErr, sysErr) } defer log.Close(rows, "closing CDNFederation rows") @@ -552,21 +552,21 @@ func getCDNFederations(inf *api.Info) ([]tc.CDNFederationV5, time.Time, int, err &fed.DeliveryService.XMLID, ) if err != nil { - return nil, time.Time{}, http.StatusInternalServerError, nil, fmt.Errorf("scanning a CDN Federation: %w", err) + return nil, time.Time{}, api.NewSystemErrorf("scanning a CDN Federation: %w", err) } ret = append(ret, fed) } - return ret, time.Time{}, http.StatusOK, nil, nil + return ret, time.Time{}, nil } // Read handles GET requests to `cdns/{{name}}/federations`. -func Read(inf *api.Info) (int, error, error) { +func Read(inf *api.Info) error { api.DefaultSort(inf, "cname") - feds, max, code, userErr, sysErr := getCDNFederations(inf) - if userErr != nil || sysErr != nil { - return code, userErr, sysErr + feds, max, err := getCDNFederations(inf) + if err != nil { + return err } if feds == nil { return inf.WriteNotModifiedResponse(max) @@ -575,10 +575,10 @@ func Read(inf *api.Info) (int, error, error) { } // ReadID handles GET requests to `cdns/{{name}}/federations/{{ID}}`. -func ReadID(inf *api.Info) (int, error, error) { - feds, max, code, userErr, sysErr := getCDNFederations(inf) - if userErr != nil || sysErr != nil { - return code, userErr, sysErr +func ReadID(inf *api.Info) error { + feds, max, err := getCDNFederations(inf) + if err != nil { + return err } if feds == nil { return inf.WriteNotModifiedResponse(max) @@ -586,10 +586,10 @@ func ReadID(inf *api.Info) (int, error, error) { id := inf.IntParams["id"] if len(feds) == 0 { - return http.StatusNotFound, fmt.Errorf("no such Federation #%d in CDN %s", id, inf.Params["name"]), nil + return api.NewNotFoundError("no such Federation #%d in CDN %s", id, inf.Params["name"]) } if len(feds) > 1 { - return http.StatusInternalServerError, nil, fmt.Errorf("%d CDN federations found by ID: %d", len(feds), id) + return fmt.Errorf("%d CDN federations found by ID: %d", len(feds), id) } return inf.WriteOKResponse(feds[0]) } @@ -611,10 +611,10 @@ func validate(fed tc.CDNFederationV5) error { } // Create handles POST requests to `cdns/{{name}}/federations`. -func Create(inf *api.Info) (int, error, error) { +func Create(inf *api.Info) error { var fed tc.CDNFederationV5 if err := inf.DecodeBody(&fed); err != nil { - return http.StatusBadRequest, fmt.Errorf("parsing request body: %w", err), nil + return api.NewUserErrorf("parsing request body: %w", err) } // You can't set this at creation time, but if it was in the request it @@ -624,13 +624,13 @@ func Create(inf *api.Info) (int, error, error) { err := validate(fed) if err != nil { - return http.StatusBadRequest, err, nil + return api.NewUserError(err) } err = inf.Tx.Tx.QueryRow(insertQuery, fed.CName, fed.TTL, fed.Description).Scan(&fed.ID, &fed.LastUpdated) if err != nil { userErr, sysErr, code := api.ParseDBError(err) - return code, userErr, fmt.Errorf("inserting a CDN Federation: %w", sysErr) + return api.NewErrors(code, userErr, fmt.Errorf("inserting a CDN Federation: %w", sysErr)) } changeLogMsg := fmt.Sprintf("CDNFEDERATION: %s, ID:%d, ACTION: Created cdnFederation", fed.CName, fed.ID) api.CreateChangeLogRawTx(api.Created, changeLogMsg, inf.User, inf.Tx.Tx) @@ -638,10 +638,10 @@ func Create(inf *api.Info) (int, error, error) { } // Update handles PUT requests to `cdns/{{name}}/federations/{{id}}`. -func Update(inf *api.Info) (int, error, error) { +func Update(inf *api.Info) error { var fed tc.CDNFederationV5 if err := inf.DecodeBody(&fed); err != nil { - return http.StatusBadRequest, fmt.Errorf("parsing request body: %w", err), nil + return api.NewUserErrorf("parsing request body: %w", err) } id := inf.IntParams["id"] @@ -649,10 +649,10 @@ func Update(inf *api.Info) (int, error, error) { var lastModified time.Time err := inf.Tx.QueryRow("SELECT last_updated FROM federation WHERE id = $1", id).Scan(&lastModified) if err != nil { - return http.StatusInternalServerError, nil, fmt.Errorf("getting last modified time for Federation #%d: %w", id, err) + return fmt.Errorf("getting last modified time for Federation #%d: %w", id, err) } if !api.IsUnmodified(inf.RequestHeaders(), lastModified) { - return http.StatusPreconditionFailed, api.ResourceModifiedError, nil + return api.NewResourceModifiedError() } // You can't set this via a PUT request, but if it was in the request it @@ -663,13 +663,12 @@ func Update(inf *api.Info) (int, error, error) { err = validate(fed) if err != nil { - return http.StatusBadRequest, err, nil + return api.NewUserError(err) } err = inf.Tx.Tx.QueryRow(updateQuery, fed.CName, fed.TTL, fed.Description, id).Scan(&fed.LastUpdated) if err != nil { - userErr, sysErr, code := api.ParseDBError(err) - return code, userErr, sysErr + return inf.HandleDBError(err) } changeLogMsg := fmt.Sprintf("CDNFEDERATION: %s, ID:%d, ACTION: Updated cdnFederation", fed.CName, id) @@ -678,14 +677,13 @@ func Update(inf *api.Info) (int, error, error) { } // Delete handles DELETE requests to `cdns/{{name}}/federations/{{id}}`. -func Delete(inf *api.Info) (int, error, error) { +func Delete(inf *api.Info) error { id := inf.IntParams["id"] var fed tc.CDNFederationV5 err := inf.Tx.QueryRow(deleteQuery, id).Scan(&fed.CName, &fed.Description, &fed.ID, &fed.LastUpdated, &fed.TTL) if err != nil { - userErr, sysErr, code := api.ParseDBError(err) - return code, userErr, sysErr + return inf.HandleDBError(err) } changeLogMsg := fmt.Sprintf("CDNFEDERATION:%s, ID:%d, ACTION: Deleted cdnFederation", fed.CName, fed.ID) api.CreateChangeLogRawTx(api.Deleted, changeLogMsg, inf.User, inf.Tx.Tx) diff --git a/traffic_ops/traffic_ops_golang/cdnfederation/cdnfederations_test.go b/traffic_ops/traffic_ops_golang/cdnfederation/cdnfederations_test.go index 27d087f259..f34fbd88b7 100644 --- a/traffic_ops/traffic_ops_golang/cdnfederation/cdnfederations_test.go +++ b/traffic_ops/traffic_ops_golang/cdnfederation/cdnfederations_test.go @@ -106,25 +106,28 @@ func gettingUserTenantListFails(t *testing.T) { tx.Rollback() }() - err := errors.New("unknown failure") - mock.ExpectQuery("WITH RECURSIVE").WillReturnError(err) - _, _, code, userErr, sysErr := getCDNFederations(&api.Info{Tx: tx, User: &auth.CurrentUser{TenantID: 1}}) + e := errors.New("unknown failure") + mock.ExpectQuery("WITH RECURSIVE").WillReturnError(e) + _, _, err := getCDNFederations(&api.Info{Tx: tx, User: &auth.CurrentUser{TenantID: 1}}) + if err == nil { + t.Fatal("expected an error to occur") + } - if code != http.StatusInternalServerError { + if code := err.Code(); code != http.StatusInternalServerError { t.Errorf("Incorrect response code when getting user tenants fails; want: %d, got: %d", http.StatusInternalServerError, code) } - if userErr != nil { + if userErr := err.UserError(); userErr != nil { t.Errorf("Unexpected user-facing error: %v", userErr) } - if sysErr == nil { + if err.SystemError() == nil { t.Fatal("Expected a system error but didn't get one") } // You can't use `errors.Is` here because sqlmock doesn't wrap the error you // give it, so we have to resort to comparing text and praying there's no // weird coincidence going on behind the scenes. - if !strings.Contains(sysErr.Error(), err.Error()) { - t.Errorf("Incorrect system error returned; want: %v, got: %v", err, sysErr) + if !strings.Contains(err.SystemError().Error(), e.Error()) { + t.Errorf("Incorrect system error returned; want: %v, got: %v", e, err.SystemError()) } } @@ -149,18 +152,22 @@ func buildingQueryPartsFails(t *testing.T) { User: &auth.CurrentUser{TenantID: 1}, Version: &api.Version{Major: 5}, } - _, _, code, userErr, sysErr := getCDNFederations(&inf) - if code != http.StatusBadRequest { + _, _, err := getCDNFederations(&inf) + if err == nil { + t.Fatal("Expected an error to occur") + } + + if code := err.Code(); code != http.StatusBadRequest { t.Errorf("Incorrect response code when getting user tenants fails; want: %d, got: %d", http.StatusBadRequest, code) } - if sysErr != nil { + if sysErr := err.SystemError(); sysErr != nil { t.Errorf("Unexpected system error: %v", sysErr) } - if userErr == nil { + if err.UserError() == nil { t.Fatal("Expected a user-facing error, but didn't get one") } - if !strings.Contains(userErr.Error(), "dsID") { - t.Errorf("Incorrect user error; expected it to mention 'dsID', but it didn't: %v", userErr) + if !strings.Contains(err.UserError().Error(), "dsID") { + t.Errorf("Incorrect user error; expected it to mention 'dsID', but it didn't: %v", err.UserError()) } } @@ -193,12 +200,9 @@ func everythingWorks(t *testing.T) { fedRows.AddRow(1, fed.ID, fed.CName, fed.TTL, fed.Description, fed.LastUpdated, fed.DeliveryService.ID, fed.DeliveryService.XMLID) mock.ExpectQuery("SELECT").WillReturnRows(fedRows) - feds, _, _, userErr, sysErr := getCDNFederations(&api.Info{Tx: tx, User: &auth.CurrentUser{TenantID: 1}, Version: &api.Version{Major: 5}}) - if userErr != nil { - t.Errorf("Unexpected user-facing error: %v", userErr) - } - if sysErr != nil { - t.Errorf("Unexpected system error: %v", sysErr) + feds, _, err := getCDNFederations(&api.Info{Tx: tx, User: &auth.CurrentUser{TenantID: 1}, Version: &api.Version{Major: 5}}) + if err != nil { + t.Errorf("Unexpected error: %+v", err) } if l := len(feds); l != 1 { t.Fatalf("Expected one federation to be returned; got: %d", l) diff --git a/traffic_ops/traffic_ops_golang/server/servers.go b/traffic_ops/traffic_ops_golang/server/servers.go index 120d183d92..91a464a5a4 100644 --- a/traffic_ops/traffic_ops_golang/server/servers.go +++ b/traffic_ops/traffic_ops_golang/server/servers.go @@ -392,7 +392,7 @@ func validateCommon(s *tc.CommonServerProperties, tx *sql.Tx) ([]error, error) { return errs, nil } -func validateCommonV40(s *tc.ServerV40, tx *sql.Tx) ([]error, error) { +func validateCommonV40(s *tc.ServerV40, tx *sql.Tx) api.Errors { noSpaces := validation.NewStringRule(tovalidate.NoSpaces, "cannot contain spaces") @@ -410,7 +410,7 @@ func validateCommonV40(s *tc.ServerV40, tx *sql.Tx) ([]error, error) { }) if len(errs) > 0 { - return errs, nil + return api.NewUserErrorFromErrorList(errs) } if _, err := tc.ValidateTypeID(tx, s.TypeID, "server"); err != nil { @@ -428,9 +428,9 @@ func validateCommonV40(s *tc.ServerV40, tx *sql.Tx) ([]error, error) { if errors.Is(err, sql.ErrNoRows) { errs = append(errs, fmt.Errorf("no such profileName: '%s'", profile)) } else { - return nil, fmt.Errorf("unable to get CDN ID for profile name '%s': %w", profile, err) + return api.NewSystemErrorf("unable to get CDN ID for profile name '%s': %w", profile, err) } - return errs, nil + return api.NewUserErrorFromErrorList(errs) } log.Infof("got cdn id: %d from profile and cdn id: %d from server", cdnID, *s.CDNID) @@ -439,7 +439,7 @@ func validateCommonV40(s *tc.ServerV40, tx *sql.Tx) ([]error, error) { } } - return errs, nil + return api.NewUserErrorFromErrorList(errs) } func validateMTU(mtu interface{}) error { @@ -457,9 +457,9 @@ func validateMTU(mtu interface{}) error { return nil } -func validateV4(s *tc.ServerV40, tx *sql.Tx) (string, error, error) { +func validateV4(s *tc.ServerV40, tx *sql.Tx) (string, api.Errors) { if len(s.Interfaces) == 0 { - return "", errors.New("a server must have at least one interface"), nil + return "", api.NewUserErrorString("a server must have at least one interface") } var errs []error var serviceAddrV4Found bool @@ -520,10 +520,10 @@ func validateV4(s *tc.ServerV40, tx *sql.Tx) (string, error, error) { if !serviceAddrV6Found && !serviceAddrV4Found { errs = append(errs, errors.New("a server must have at least one service address")) } - usrErr, sysErr := validateCommonV40(s, tx) - errs = append(errs, usrErr...) - if sysErr != nil || len(errs) > 0 { - return serviceInterface, util.JoinErrs(errs), sysErr + err := validateCommonV40(s, tx) + if err != nil { + errs = append(errs, err.UserError()) + return "", api.NewErrors(err.Code(), util.JoinErrs(errs), err.SystemError()) } query := ` SELECT tmp.server, ip.address @@ -536,22 +536,22 @@ JOIN ip_address ip on ip.server = tmp.server WHERE (profiles = $1::text[]) ` var rows *sql.Rows - var err error + var e error //ProfileID already validated if s.ID != nil { - rows, err = tx.Query(query+" and tmp.server != $2", pq.Array(s.ProfileNames), *s.ID) + rows, e = tx.Query(query+" and tmp.server != $2", pq.Array(s.ProfileNames), *s.ID) } else { - rows, err = tx.Query(query, pq.Array(s.ProfileNames)) + rows, e = tx.Query(query, pq.Array(s.ProfileNames)) } - if err != nil { + if e != nil { errs = append(errs, errors.New("unable to determine service address uniqueness")) } else if rows != nil { defer rows.Close() for rows.Next() { var id int var ipaddress string - err = rows.Scan(&id, &ipaddress) - if err != nil { + e = rows.Scan(&id, &ipaddress) + if e != nil { errs = append(errs, errors.New("unable to determine service address uniqueness")) } else if (ipaddress == ipv4 || ipaddress == ipv6) && (s.ID == nil || *s.ID != id) { errs = append(errs, fmt.Errorf("there exists a server with id %v on the same profile that has the same service address %s", id, ipaddress)) @@ -559,13 +559,13 @@ WHERE (profiles = $1::text[]) } } - return serviceInterface, util.JoinErrs(errs), nil + return serviceInterface, api.NewUserErrorFromErrorList(errs) } -func validateV3(s *tc.ServerV30, tx *sql.Tx) (string, error, error) { +func validateV3(s *tc.ServerV30, tx *sql.Tx) (string, api.Errors) { if len(s.Interfaces) == 0 { - return "", errors.New("a server must have at least one interface"), nil + return "", api.NewUserErrorString("a server must have at least one interface") } var errs []error var serviceAddrV4Found bool @@ -630,8 +630,11 @@ func validateV3(s *tc.ServerV30, tx *sql.Tx) (string, error, error) { commonErrs, sysErr := validateCommon(&s.CommonServerProperties, tx) errs = append(errs, commonErrs...) - if len(errs) > 0 || sysErr != nil { - return serviceInterface, util.JoinErrs(errs), sysErr + if sysErr != nil { + return "", api.NewSystemError(sysErr) + } + if len(errs) > 0 { + return serviceInterface, api.NewUserErrorFromErrorList(errs) } query := ` SELECT s.ID, ip.address FROM server s @@ -650,7 +653,7 @@ and p.id = $1 rows, err = tx.Query(query, *s.ProfileID) } if err != nil { - return serviceInterface, util.JoinErrs(errs), fmt.Errorf("unable to determine service address uniqueness: querying: %w", err) + return serviceInterface, api.NewErrors(http.StatusInternalServerError, util.JoinErrs(errs), fmt.Errorf("unable to determine service address uniqueness: querying: %w", err)) } else if rows != nil { defer rows.Close() for rows.Next() { @@ -658,27 +661,27 @@ and p.id = $1 var ipaddress string err = rows.Scan(&id, &ipaddress) if err != nil { - return serviceInterface, util.JoinErrs(errs), fmt.Errorf("unable to determine service address uniqueness: scanning: %w", err) + return serviceInterface, api.NewErrors(http.StatusInternalServerError, util.JoinErrs(errs), fmt.Errorf("unable to determine service address uniqueness: scanning: %w", err)) } else if (ipaddress == ipv4 || ipaddress == ipv6) && (s.ID == nil || *s.ID != id) { errs = append(errs, fmt.Errorf("there exists a server with id %v on the same profile that has the same service address %s", id, ipaddress)) } } } - return serviceInterface, util.JoinErrs(errs), nil + return serviceInterface, api.NewUserErrorFromErrorList(errs) } // Read is the handler for GET requests to /servers. -func Read(inf *api.Info) (int, error, error) { +func Read(inf *api.Info) error { useIMS := inf.UseIMS() version := inf.Version - servers, serverCount, userErr, sysErr, errCode, maxTime := getServers(inf.RequestHeaders(), inf.Params, inf.Tx, inf.User, useIMS, *version, inf.Config.RoleBasedPermissions) - if useIMS && maxTime != nil && errCode == http.StatusNotModified { - return inf.WriteNotModifiedResponse(*maxTime) + servers, serverCount, maxTime, err := getServers(inf.RequestHeaders(), inf.Params, inf.Tx, inf.User, useIMS, *version, inf.Config.RoleBasedPermissions) + if err != nil { + return err } - if userErr != nil || sysErr != nil { - return errCode, userErr, sysErr + if useIMS && maxTime != nil { + return inf.WriteNotModifiedResponse(*maxTime) } if version.GreaterThanOrEqualTo(&api.Version{Major: 5}) { return inf.WriteOKResponse(servers) @@ -697,12 +700,12 @@ func Read(inf *api.Info) (int, error, error) { for i, server := range downgraded { csp, err := dbhelpers.GetCommonServerPropertiesFromV4(server, tx) if err != nil { - return http.StatusInternalServerError, nil, fmt.Errorf("failed to get common server properties from V4 server struct: %w", err) + return api.NewSystemErrorf("failed to get common server properties from V4 server struct: %w", err) } v3Server, err := server.ToServerV3FromV4(csp) if err != nil { - return http.StatusInternalServerError, nil, fmt.Errorf("failed to convert servers to V3 format: %w", err) + return api.NewSystemErrorf("failed to convert servers to V3 format: %w", err) } v3Servers[i] = v3Server } @@ -735,7 +738,7 @@ func getServerCount(tx *sqlx.Tx, query string, queryValues map[string]interface{ return serverCount, nil } -func getServers(h http.Header, params map[string]string, tx *sqlx.Tx, user *auth.CurrentUser, useIMS bool, version api.Version, roleBasedPerms bool) ([]tc.ServerV5, uint64, error, error, int, *time.Time) { +func getServers(h http.Header, params map[string]string, tx *sqlx.Tx, user *auth.CurrentUser, useIMS bool, version api.Version, roleBasedPerms bool) ([]tc.ServerV5, uint64, *time.Time, api.Errors) { var maxTime time.Time var runSecond bool // Query Parameters to Database Query column mappings @@ -782,22 +785,21 @@ func getServers(h http.Header, params map[string]string, tx *sqlx.Tx, user *auth // don't allow query on ds outside user's tenant dsID, err = strconv.Atoi(dsIDStr) if err != nil { - return nil, 0, errors.New("dsId must be an integer"), nil, http.StatusNotFound, nil + return nil, 0, nil, api.NewNotFoundError("dsId must be an integer") } cdnID, _, err = dbhelpers.GetDSCDNIdFromID(tx.Tx, dsID) if err != nil { - return nil, 0, nil, err, http.StatusInternalServerError, nil + return nil, 0, nil, api.NewSystemError(err) } userErr, sysErr, _ := tenant.CheckID(tx.Tx, user, dsID) if userErr != nil || sysErr != nil { - return nil, 0, errors.New("Forbidden"), sysErr, http.StatusForbidden, nil + return nil, 0, nil, api.NewErrors(http.StatusForbidden, fmt.Errorf("Forbidden: %w", userErr), sysErr) } var joinSubQuery string if err := tx.QueryRow(deliveryservice.GetRequiredCapabilitiesQuery, dsID).Scan(pq.Array(&requiredCapabilities)); err != nil && err != sql.ErrNoRows { - err = fmt.Errorf("unable to get required capabilities for deliveryservice %d: %w", dsID, err) - return nil, 0, nil, err, http.StatusInternalServerError, nil + return nil, 0, nil, api.NewSystemErrorf("unable to get required capabilities for deliveryservice %d: %w", dsID, err) } if requiredCapabilities != nil && len(requiredCapabilities) > 0 { dsHasRequiredCapabilities = true @@ -809,7 +811,7 @@ func getServers(h http.Header, params map[string]string, tx *sqlx.Tx, user *auth // depending on ds type, also need to add mids dsType, _, _, err := dbhelpers.GetDeliveryServiceTypeAndCDNName(dsID, tx.Tx) if err != nil { - return nil, 0, nil, err, http.StatusInternalServerError, nil + return nil, 0, nil, api.NewSystemError(err) } usesMids = dsType.UsesMidCache() log.Debugf("Servers for ds %d; uses mids? %v\n", dsID, usesMids) @@ -827,7 +829,7 @@ func getServers(h http.Header, params map[string]string, tx *sqlx.Tx, user *auth where += requiredCapabilitiesCondition } if len(errs) > 0 { - return nil, 0, util.JoinErrs(errs), nil, http.StatusBadRequest, nil + return nil, 0, nil, api.NewUserError(util.JoinErrs(errs)) } var queryString, countQueryString string @@ -851,7 +853,7 @@ JOIN server_profile sp ON s.id = sp.server` } serverCount, err = getServerCount(tx, countQuery, queryValues) if err != nil { - return nil, 0, nil, fmt.Errorf("failed to get servers count: %v", err), http.StatusInternalServerError, nil + return nil, 0, nil, api.NewSystemErrorf("failed to get servers count: %w", err) } serversList := []tc.ServerV5{} @@ -859,7 +861,7 @@ JOIN server_profile sp ON s.id = sp.server` runSecond, maxTime = ims.TryIfModifiedSinceQuery(tx, h, queryValues, selectMaxLastUpdatedQuery(queryAddition, where)) if !runSecond { log.Debugln("IMS HIT") - return serversList, 0, nil, nil, http.StatusNotModified, &maxTime + return serversList, 0, &maxTime, nil } log.Debugln("IMS MISS") } else { @@ -875,7 +877,7 @@ JOIN server_profile sp ON s.id = sp.server` log.Debugln("Query is ", query) rows, err := tx.NamedQuery(query, queryValues) if err != nil { - return nil, serverCount, nil, errors.New("querying: " + err.Error()), http.StatusInternalServerError, nil + return nil, serverCount, nil, api.NewSystemErrorf("querying: %w", err) } defer rows.Close() @@ -925,7 +927,7 @@ JOIN server_profile sp ON s.id = sp.server` &s.StatusLastUpdated, ) if err != nil { - return nil, serverCount, nil, fmt.Errorf("getting servers: %w", err), http.StatusInternalServerError, nil + return nil, serverCount, nil, api.NewSystemErrorf("getting servers: %w", err) } if (version.GreaterThanOrEqualTo(&api.Version{Major: 4}) && roleBasedPerms) || version.GreaterThanOrEqualTo(&api.Version{Major: 5}) { if !user.Can(tc.PermSecureServerRead) { @@ -938,7 +940,7 @@ JOIN server_profile sp ON s.id = sp.server` } if _, ok := servers[s.ID]; ok { - return nil, serverCount, nil, fmt.Errorf("found more than one server with ID #%d", s.ID), http.StatusInternalServerError, nil + return nil, serverCount, nil, api.NewSystemErrorf("found more than one server with ID #%d", s.ID) } servers[s.ID] = s ids = append(ids, s.ID) @@ -946,30 +948,28 @@ JOIN server_profile sp ON s.id = sp.server` // if ds requested uses mid-tier caches, add those to the list as well if usesMids { - midIDs, userErr, sysErr, errCode := getMidServers(ids, servers, dsID, cdnID, tx, dsHasRequiredCapabilities) - - log.Debugf("getting mids: %v, %v, %s\n", userErr, sysErr, http.StatusText(errCode)) + midIDs, err := getMidServers(ids, servers, dsID, cdnID, tx, dsHasRequiredCapabilities) serverCount = serverCount + uint64(len(midIDs)) - if userErr != nil || sysErr != nil { - return nil, serverCount, userErr, sysErr, errCode, nil + if err != nil { + return nil, serverCount, nil, err } ids = append(ids, midIDs...) } if len(ids) < 1 { - return []tc.ServerV5{}, serverCount, nil, nil, http.StatusOK, nil + return []tc.ServerV5{}, serverCount, nil, nil } query, args, err := sqlx.In(`SELECT max_bandwidth, monitor, mtu, name, server, router_host_name, router_port_name FROM interface WHERE server IN (?)`, ids) if err != nil { - return nil, serverCount, nil, fmt.Errorf("building interfaces query: %v", err), http.StatusInternalServerError, nil + return nil, serverCount, nil, api.NewSystemErrorf("building interfaces query: %w", err) } query = tx.Rebind(query) interfaces := map[int]map[string]tc.ServerInterfaceInfoV40{} interfaceRows, err := tx.Queryx(query, args...) if err != nil { - return nil, serverCount, nil, fmt.Errorf("querying for interfaces: %v", err), http.StatusInternalServerError, nil + return nil, serverCount, nil, api.NewSystemErrorf("querying for interfaces: %w", err) } defer interfaceRows.Close() @@ -983,7 +983,7 @@ JOIN server_profile sp ON s.id = sp.server` var routerHostName string var routerPort string if err = interfaceRows.Scan(&iface.MaxBandwidth, &iface.Monitor, &iface.MTU, &iface.Name, &server, &routerHostName, &routerPort); err != nil { - return nil, serverCount, nil, fmt.Errorf("getting server interfaces: %v", err), http.StatusInternalServerError, nil + return nil, serverCount, nil, api.NewSystemErrorf("getting server interfaces: %w", err) } if _, ok := servers[server]; !ok { @@ -1000,12 +1000,12 @@ JOIN server_profile sp ON s.id = sp.server` query, args, err = sqlx.In(`SELECT address, gateway, service_address, server, interface FROM ip_address WHERE server IN (?)`, ids) if err != nil { - return nil, serverCount, nil, fmt.Errorf("building IP addresses query: %v", err), http.StatusInternalServerError, nil + return nil, serverCount, nil, api.NewSystemErrorf("building IP addresses query: %w", err) } query = tx.Rebind(query) ipRows, err := tx.Tx.Query(query, args...) if err != nil { - return nil, serverCount, nil, fmt.Errorf("querying for IP addresses: %v", err), http.StatusInternalServerError, nil + return nil, serverCount, nil, api.NewSystemErrorf("querying for IP addresses: %w", err) } defer ipRows.Close() @@ -1015,7 +1015,7 @@ JOIN server_profile sp ON s.id = sp.server` var iface string if err = ipRows.Scan(&ip.Address, &ip.Gateway, &ip.ServiceAddress, &server, &iface); err != nil { - return nil, serverCount, nil, fmt.Errorf("getting server IP addresses: %v", err), http.StatusInternalServerError, nil + return nil, serverCount, nil, api.NewSystemErrorf("getting server IP addresses: %w", err) } if _, ok := interfaces[server]; !ok { @@ -1039,13 +1039,13 @@ JOIN server_profile sp ON s.id = sp.server` returnable = append(returnable, server) } - return returnable, serverCount, nil, nil, http.StatusOK, &maxTime + return returnable, serverCount, &maxTime, nil } // getMidServers gets the mids used by the edges provided with an option to filter for a given cdn -func getMidServers(edgeIDs []int, servers map[int]tc.ServerV5, dsID int, cdnID int, tx *sqlx.Tx, includeCapabilities bool) ([]int, error, error, int) { +func getMidServers(edgeIDs []int, servers map[int]tc.ServerV5, dsID int, cdnID int, tx *sqlx.Tx, includeCapabilities bool) ([]int, api.Errors) { if len(edgeIDs) == 0 { - return nil, nil, nil, http.StatusOK + return nil, nil } filters := map[string]interface{}{ @@ -1061,7 +1061,7 @@ func getMidServers(edgeIDs []int, servers map[int]tc.ServerV5, dsID int, cdnID i q := selectIDQuery + midWhereClause rows, err := tx.NamedQuery(q, filters) if err != nil { - return nil, err, nil, http.StatusBadRequest + return nil, api.NewUserError(err) } defer rows.Close() @@ -1069,7 +1069,7 @@ func getMidServers(edgeIDs []int, servers map[int]tc.ServerV5, dsID int, cdnID i var midID int if err := rows.Scan(&midID); err != nil { log.Errorf("could not scan mid server id: %s\n", err) - return nil, nil, err, http.StatusInternalServerError + return nil, api.NewSystemError(err) } midIDs = append(midIDs, midID) } @@ -1104,7 +1104,7 @@ func getMidServers(edgeIDs []int, servers map[int]tc.ServerV5, dsID int, cdnID i rows, err := tx.NamedQuery(query, filters) if err != nil { - return nil, err, nil, http.StatusBadRequest + return nil, api.NewUserError(err) } defer rows.Close() @@ -1151,7 +1151,7 @@ func getMidServers(edgeIDs []int, servers map[int]tc.ServerV5, dsID int, cdnID i &s.XMPPPasswd, &s.StatusLastUpdated); err != nil { log.Errorf("could not scan mid servers: %s\n", err) - return nil, nil, err, http.StatusInternalServerError + return nil, api.NewSystemError(err) } // This may mean that the server was caught by other query parameters, @@ -1163,49 +1163,49 @@ func getMidServers(edgeIDs []int, servers map[int]tc.ServerV5, dsID int, cdnID i } - return ids, nil, nil, http.StatusOK + return ids, nil } -func checkTypeChangeSafety(server tc.ServerV5, tx *sqlx.Tx) (error, error, int) { +func checkTypeChangeSafety(server tc.ServerV5, tx *sqlx.Tx) api.Errors { // see if cdn or type changed var cdnID int var typeID int if err := tx.QueryRow("SELECT type, cdn_id FROM server WHERE id = $1", server.ID).Scan(&typeID, &cdnID); err != nil { if err == sql.ErrNoRows { - return errors.New("no server found with this ID"), nil, http.StatusNotFound + return api.NewNotFoundError("no server found with this ID") } - return nil, fmt.Errorf("getting current server type: %w", err), http.StatusInternalServerError + return api.NewSystemErrorf("getting current server type: %w", err) } var dsIDs []int64 if err := tx.QueryRowx("SELECT ARRAY(SELECT deliveryservice FROM deliveryservice_server WHERE server = $1)", server.ID).Scan(pq.Array(&dsIDs)); err != nil && !errors.Is(err, sql.ErrNoRows) { - return nil, fmt.Errorf("getting server assigned delivery services: %w", err), http.StatusInternalServerError + return api.NewSystemErrorf("getting server assigned delivery services: %w", err) } // If type is changing ensure it isn't assigned to any DSes. if typeID != server.TypeID { if len(dsIDs) != 0 { - return errors.New("server type can not be updated when it is currently assigned to Delivery Services"), nil, http.StatusConflict + return api.NewErrors(http.StatusConflict, errors.New("server type can not be updated when it is currently assigned to Delivery Services"), nil) } } // Check to see if the user is trying to change the CDN of a server, which is already linked with a DS if cdnID != server.CDNID && len(dsIDs) != 0 { - return errors.New("server cdn can not be updated when it is currently assigned to delivery services"), nil, http.StatusConflict + return api.NewErrors(http.StatusConflict, errors.New("server cdn can not be updated when it is currently assigned to delivery services"), nil) } - return nil, nil, http.StatusOK + return nil } -func updateStatusLastUpdatedTime(id int, statusLastUpdatedTime *time.Time, tx *sql.Tx) (error, error, int) { +func updateStatusLastUpdatedTime(id int, statusLastUpdatedTime *time.Time, tx *sql.Tx) api.Errors { query := `UPDATE server SET status_last_updated=$1 WHERE id=$2 ` if _, err := tx.Exec(query, statusLastUpdatedTime, id); err != nil { - return errors.New("updating status last updated: " + err.Error()), nil, http.StatusInternalServerError + return api.NewSystemErrorf("updating status last updated: %w", err) } - return nil, nil, http.StatusOK + return nil } -func createInterfaces(id int, interfaces []tc.ServerInterfaceInfoV40, tx *sql.Tx) (error, error, int) { +func createInterfaces(id int, interfaces []tc.ServerInterfaceInfoV40, tx *sql.Tx) api.Errors { ifaceQry := ` INSERT INTO interface ( max_bandwidth, @@ -1247,7 +1247,8 @@ func createInterfaces(id int, interfaces []tc.ServerInterfaceInfoV40, tx *sql.Tx _, err := tx.Exec(ifaceQry, ifaceArgs...) if err != nil { - return api.ParseDBError(err) + userErr, sysErr, code := api.ParseDBError(err) + return api.NewErrors(code, userErr, sysErr) } ipQry += strings.Join(ipQueryParts, ",") @@ -1255,37 +1256,41 @@ func createInterfaces(id int, interfaces []tc.ServerInterfaceInfoV40, tx *sql.Tx _, err = tx.Exec(ipQry, ipArgs...) if err != nil { - return api.ParseDBError(err) + userErr, sysErr, code := api.ParseDBError(err) + return api.NewErrors(code, userErr, sysErr) } - return nil, nil, http.StatusOK + return nil } -func deleteInterfaces(id int, tx *sql.Tx) (error, error, int) { - if _, err := tx.Exec(deleteIPsQuery, id); err != nil && err != sql.ErrNoRows { - return api.ParseDBError(err) +func deleteInterfaces(id int, tx *sql.Tx) api.Errors { + if _, err := tx.Exec(deleteIPsQuery, id); err != nil && !errors.Is(err, sql.ErrNoRows) { + userErr, sysErr, code := api.ParseDBError(err) + return api.NewErrors(code, userErr, sysErr) } - if _, err := tx.Exec(deleteInterfacesQuery, id); err != nil && err != sql.ErrNoRows { - return api.ParseDBError(err) + if _, err := tx.Exec(deleteInterfacesQuery, id); err != nil && !errors.Is(err, sql.ErrNoRows) { + userErr, sysErr, code := api.ParseDBError(err) + return api.NewErrors(code, userErr, sysErr) } - return nil, nil, http.StatusOK + return nil } // Update is the handler for PUT requests to /servers. -func Update(inf *api.Info) (int, error, error) { +func Update(inf *api.Info) error { id := inf.IntParams["id"] // Get original server - originals, _, userErr, sysErr, errCode, _ := getServers(inf.RequestHeaders(), inf.Params, inf.Tx, inf.User, false, *inf.Version, inf.Config.RoleBasedPermissions) - if userErr != nil || sysErr != nil { - return errCode, userErr, sysErr + var err error + originals, _, _, err := getServers(inf.RequestHeaders(), inf.Params, inf.Tx, inf.User, false, *inf.Version, inf.Config.RoleBasedPermissions) + if err != nil { + return err } if len(originals) < 1 { - return http.StatusNotFound, errors.New("the server doesn't exist, cannot update"), nil + return api.NewNotFoundError("the server doesn't exist, cannot update") } if len(originals) > 1 { - return http.StatusInternalServerError, nil, fmt.Errorf("too many servers by ID %d: %d", id, len(originals)) + return fmt.Errorf("too many servers by ID %d: %d", id, len(originals)) } original := originals[0] @@ -1311,7 +1316,7 @@ func Update(inf *api.Info) (int, error, error) { if inf.Version.GreaterThanOrEqualTo(&api.Version{Major: 5}) { server.ID = inf.IntParams["id"] if err := inf.DecodeBody(&server); err != nil { - return http.StatusBadRequest, err, nil + return api.NewUserError(err) } if server.StatusID != originalStatusID { currentTime := time.Now() @@ -1322,19 +1327,16 @@ func Update(inf *api.Info) (int, error, error) { statusLastUpdatedTime = *original.StatusLastUpdated } tmp := server.Downgrade() - _, userErr, sysErr := validateV4(&tmp, tx) - if userErr != nil || sysErr != nil { - if sysErr != nil { - return http.StatusInternalServerError, userErr, sysErr - } - return http.StatusBadRequest, userErr, sysErr + _, err := validateV4(&tmp, tx) + if err != nil { + return err } server = tmp.Upgrade() } else if inf.Version.GreaterThanOrEqualTo(&api.Version{Major: 4}) { var serverV4 tc.ServerV4 serverV4.ID = util.Ptr(inf.IntParams["id"]) if err := inf.DecodeBody(&serverV4); err != nil { - return http.StatusBadRequest, err, nil + return api.NewUserError(err) } if serverV4.StatusID == nil || *serverV4.StatusID != originalStatusID { currentTime := time.Now() @@ -1344,19 +1346,17 @@ func Update(inf *api.Info) (int, error, error) { server.StatusLastUpdated = original.StatusLastUpdated statusLastUpdatedTime = *original.StatusLastUpdated } - _, userErr, sysErr := validateV4(&serverV4, tx) - if userErr != nil || sysErr != nil { - if sysErr != nil { - return http.StatusInternalServerError, userErr, sysErr - } - return http.StatusBadRequest, userErr, sysErr + _, err := validateV4(&serverV4, tx) + if err != nil { + return err } server = serverV4.Upgrade() } else { serverV3.ID = new(int) *serverV3.ID = inf.IntParams["id"] - if err := inf.DecodeBody(&serverV3); err != nil { - return http.StatusBadRequest, err, nil + err := inf.DecodeBody(&serverV3) + if err != nil { + return api.NewUserError(err) } if serverV3.StatusID != nil && *serverV3.StatusID != originalStatusID { currentTime := time.Now() @@ -1366,26 +1366,23 @@ func Update(inf *api.Info) (int, error, error) { serverV3.StatusLastUpdated = original.StatusLastUpdated statusLastUpdatedTime = *original.StatusLastUpdated } - _, userErr, sysErr := validateV3(&serverV3, tx) - if userErr != nil || sysErr != nil { - if sysErr != nil { - return http.StatusInternalServerError, userErr, sysErr - } - return http.StatusBadRequest, userErr, sysErr + _, err = validateV3(&serverV3, tx) + if err != nil { + return err } profileName, exists, err := dbhelpers.GetProfileNameFromID(*serverV3.ProfileID, tx) if err != nil { - return http.StatusInternalServerError, nil, err + return err } if !exists { - return http.StatusNotFound, errors.New("profile does not exist"), nil + return api.NewNotFoundError("profile does not exist") } profileNames := []string{profileName} upgraded, err := serverV3.UpgradeToV40(profileNames) if err != nil { - return http.StatusInternalServerError, nil, fmt.Errorf("error upgrading valid V3 server to V4 structure: %w", err) + return fmt.Errorf("error upgrading valid V3 server to V4 structure: %w", err) } server = upgraded.Upgrade() } @@ -1393,89 +1390,86 @@ func Update(inf *api.Info) (int, error, error) { if original.CacheGroupID != server.CacheGroupID || original.CDNID != server.CDNID { hasDSOnCDN, err := dbhelpers.CachegroupHasTopologyBasedDeliveryServicesOnCDN(tx, original.CacheGroupID, original.CDNID) if err != nil { - return http.StatusInternalServerError, nil, err + return err } CDNIDs := []int{} if hasDSOnCDN { CDNIDs = append(CDNIDs, original.CDNID) } if err = topology_validation.CheckForEmptyCacheGroups(inf.Tx, []int{original.CacheGroupID}, CDNIDs, true, []int{original.ID}); err != nil { - return http.StatusBadRequest, fmt.Errorf("server is the last one in its Cache Group, which is used by a Topology, so it cannot be moved to another Cache Group: %w", err), nil + return api.NewUserErrorf("server is the last one in its Cache Group, which is used by a Topology, so it cannot be moved to another Cache Group: %w", err) } } status, ok, err := dbhelpers.GetStatusByID(server.StatusID, tx) if err != nil { - return http.StatusInternalServerError, nil, fmt.Errorf("getting server #%d status (#%d): %v", id, server.StatusID, err) + return fmt.Errorf("getting server #%d status (#%d): %v", id, server.StatusID, err) } if !ok { log.Warnf("previously existent status #%d not found when fetching later", server.StatusID) - return http.StatusBadRequest, fmt.Errorf("no such Status: #%d", server.StatusID), nil + return api.NewUserErrorf("no such Status: #%d", server.StatusID) } if status.Name == nil { - return http.StatusInternalServerError, nil, fmt.Errorf("status #%d had no name", server.StatusID) + return fmt.Errorf("status #%d had no name", server.StatusID) } if *status.Name != string(tc.CacheStatusOnline) && *status.Name != string(tc.CacheStatusReported) { dsIDs, err := getActiveDeliveryServicesThatOnlyHaveThisServerAssigned(id, original.Type, tx) if err != nil { - return http.StatusInternalServerError, - nil, - fmt.Errorf("getting Delivery Services to which server #%d is assigned that have no other servers: %w", id, err) + return fmt.Errorf("getting Delivery Services to which server #%d is assigned that have no other servers: %w", id, err) } if len(dsIDs) > 0 { prefix := fmt.Sprintf("setting server status to '%s' would leave Active Delivery Service", *status.Name) alertText := InvalidStatusForDeliveryServicesAlertText(prefix, original.Type, dsIDs) - return http.StatusConflict, errors.New(alertText), nil + return api.NewErrors(http.StatusConflict, errors.New(alertText), nil) } } - if userErr, sysErr, errCode = checkTypeChangeSafety(server, inf.Tx); userErr != nil || sysErr != nil { - return errCode, userErr, sysErr + if err = checkTypeChangeSafety(server, inf.Tx); err != nil { + return err } if server.XMPPID != nil && *server.XMPPID != "" && originalXMPPID != "" && *server.XMPPID != originalXMPPID { - return http.StatusBadRequest, errors.New("server cannot be updated due to requested XMPPID change. XMPIDD is immutable"), nil + return api.NewUserErrorString("server cannot be updated due to requested XMPPID change. XMPIDD is immutable") } userErr, sysErr, statusCode := api.CheckIfUnModified(inf.RequestHeaders(), inf.Tx, server.ID, "server") if userErr != nil || sysErr != nil { - return statusCode, userErr, sysErr + return api.NewErrors(statusCode, userErr, sysErr) } if server.CDN != "" { userErr, sysErr, statusCode = dbhelpers.CheckIfCurrentUserCanModifyCDN(inf.Tx.Tx, server.CDN, inf.User.UserName) if userErr != nil || sysErr != nil { - return statusCode, userErr, sysErr + return api.NewErrors(statusCode, userErr, sysErr) } } else { userErr, sysErr, statusCode = dbhelpers.CheckIfCurrentUserCanModifyCDNWithID(inf.Tx.Tx, int64(server.CDNID), inf.User.UserName) if userErr != nil || sysErr != nil { - return statusCode, userErr, sysErr + return api.NewErrors(statusCode, userErr, sysErr) } } if inf.Version.GreaterThanOrEqualTo(&api.Version{Major: 4}) { if err = dbhelpers.UpdateServerProfilesForV4(server.ID, server.Profiles, tx); err != nil { - userErr, sysErr, errCode := api.ParseDBError(err) - return errCode, userErr, sysErr + return inf.HandleDBError(err) } } else { if err = dbhelpers.UpdateServerProfileTableForV3(serverV3.ID, serverV3.ProfileID, (original.Profiles)[0], tx); err != nil { - return http.StatusInternalServerError, nil, fmt.Errorf("failed to update server_profile: %w", err) + return fmt.Errorf("failed to update server_profile: %w", err) } } - serverID, errCode, userErr, sysErr := updateServer(inf.Tx, server) - if userErr != nil || sysErr != nil { - return errCode, userErr, sysErr + serverID, err := updateServer(inf.Tx, server) + if err != nil { + return err } - if userErr, sysErr, errCode = deleteInterfaces(id, tx); userErr != nil || sysErr != nil { - return errCode, userErr, sysErr + if err = deleteInterfaces(id, tx); err != nil { + return err } - if userErr, sysErr, errCode = createInterfaces(id, server.Interfaces, tx); userErr != nil || sysErr != nil { - return errCode, userErr, sysErr + if err = createInterfaces(id, server.Interfaces, tx); err != nil { + return err } where := `WHERE s.id = $1` @@ -1526,12 +1520,12 @@ func Update(inf *api.Info) (int, error, error) { &server.StatusLastUpdated, ) if err != nil { - return http.StatusInternalServerError, nil, err + return err } serversInterfaces, err := dbhelpers.GetServersInterfaces([]int{server.ID}, inf.Tx.Tx) if err != nil { - return http.StatusInternalServerError, nil, err + return err } if interfacesMap, ok := serversInterfaces[server.ID]; ok { for _, intfc := range interfacesMap { @@ -1539,8 +1533,8 @@ func Update(inf *api.Info) (int, error, error) { } } - if userErr, sysErr, errCode = updateStatusLastUpdatedTime(id, &statusLastUpdatedTime, tx); userErr != nil || sysErr != nil { - return errCode, userErr, sysErr + if err = updateStatusLastUpdatedTime(id, &statusLastUpdatedTime, tx); err != nil { + return err } if inf.Version.GreaterThanOrEqualTo(&api.Version{Major: 5}) { inf.WriteSuccessResponse(server, "Server updated") @@ -1550,25 +1544,25 @@ func Update(inf *api.Info) (int, error, error) { downgraded := server.Downgrade() csp, err := dbhelpers.GetCommonServerPropertiesFromV4(downgraded, inf.Tx.Tx) if err != nil { - return http.StatusInternalServerError, nil, err + return err } serverV30, err := downgraded.ToServerV3FromV4(csp) if err != nil { - return http.StatusInternalServerError, nil, err + return err } inf.WriteSuccessResponse(serverV30, "Server updated") } inf.CreateChangeLog(fmt.Sprintf("SERVER: %s.%s, ID: %d, ACTION: updated", server.HostName, server.DomainName, server.ID)) - return http.StatusOK, nil, nil + return nil } -func updateServer(tx *sqlx.Tx, server tc.ServerV5) (int64, int, error, error) { +func updateServer(tx *sqlx.Tx, server tc.ServerV5) (int64, api.Errors) { rows, err := tx.NamedQuery(updateQuery, server) if err != nil { userErr, sysErr, errCode := api.ParseDBError(err) - return 0, errCode, userErr, sysErr + return 0, api.NewErrors(errCode, userErr, sysErr) } defer rows.Close() @@ -1606,22 +1600,22 @@ func updateServer(tx *sqlx.Tx, server tc.ServerV5) (int64, int, error, error) { &server.TypeID, &server.StatusLastUpdated, ); err != nil { - return 0, http.StatusNotFound, nil, fmt.Errorf("scanning lastUpdated from server insert: %w", err) + return 0, api.NewNotFoundError("scanning lastUpdated from server insert: %w", err) } rowsAffected++ } if rowsAffected < 1 { - return 0, http.StatusNotFound, fmt.Errorf("no server found with id %d", server.ID), nil + return 0, api.NewNotFoundError("no server found with id %d", server.ID) } if rowsAffected > 1 { - return 0, http.StatusInternalServerError, nil, fmt.Errorf("update for server #%d affected too many rows (%d)", server.ID, rowsAffected) + return 0, api.NewSystemErrorf("update for server #%d affected too many rows (%d)", server.ID, rowsAffected) } - return serverId, http.StatusOK, nil, nil + return serverId, nil } -func insertServerProfile(id int, pName []string, tx *sql.Tx) (error, error, int) { +func insertServerProfile(id int, pName []string, tx *sql.Tx) api.Errors { priority := make([]int, 0, len(pName)) for i, _ := range pName { priority = append(priority, i) @@ -1636,37 +1630,35 @@ func insertServerProfile(id int, pName []string, tx *sql.Tx) (error, error, int) ` if _, err := tx.Exec(insertQuery, id, pq.Array(pName), pq.Array(priority)); err != nil { - return api.ParseDBError(err) + userErr, sysErr, code := api.ParseDBError(err) + return api.NewErrors(code, userErr, sysErr) } - return nil, nil, http.StatusOK + return nil } -func createV3(inf *api.Info) (int, error, error) { +func createV3(inf *api.Info) error { var server tc.ServerV30 if err := inf.DecodeBody(&server); err != nil { - return http.StatusBadRequest, err, nil + return api.NewUserError(err) } if server.ID != nil { var prevID int err := inf.Tx.Tx.QueryRow("SELECT id from server where id = $1", server.ID).Scan(&prevID) if err != nil && !errors.Is(err, sql.ErrNoRows) { - return http.StatusInternalServerError, nil, fmt.Errorf("checking if server with id %d exists", *server.ID) + return fmt.Errorf("checking if server with id %d exists", *server.ID) } if prevID != 0 { - return http.StatusBadRequest, fmt.Errorf("server with id %d already exists. Please do not provide an id", *server.ID), nil + return api.NewUserErrorf("server with id %d already exists. Please do not provide an id", *server.ID) } } server.XMPPID = newUUID() - _, userErr, sysErr := validateV3(&server, inf.Tx.Tx) - if userErr != nil || sysErr != nil { - if sysErr != nil { - return http.StatusInternalServerError, userErr, sysErr - } - return http.StatusBadRequest, userErr, sysErr + _, err := validateV3(&server, inf.Tx.Tx) + if err != nil { + return err } currentTime := time.Now() @@ -1675,46 +1667,46 @@ func createV3(inf *api.Info) (int, error, error) { if server.CDNName != nil { userErr, sysErr, statusCode := dbhelpers.CheckIfCurrentUserCanModifyCDN(inf.Tx.Tx, *server.CDNName, inf.User.UserName) if userErr != nil || sysErr != nil { - return statusCode, userErr, sysErr + return api.NewErrors(statusCode, userErr, sysErr) } } else if server.CDNID != nil { userErr, sysErr, statusCode := dbhelpers.CheckIfCurrentUserCanModifyCDNWithID(inf.Tx.Tx, int64(*server.CDNID), inf.User.UserName) if userErr != nil || sysErr != nil { - return statusCode, userErr, sysErr + return api.NewErrors(statusCode, userErr, sysErr) } } - serverID, err := createServerV3(inf.Tx, server) - if err != nil { - userErr, sysErr, errCode := api.ParseDBError(err) - return errCode, userErr, sysErr + serverID, e := createServerV3(inf.Tx, server) + if e != nil { + userErr, sysErr, errCode := api.ParseDBError(e) + return api.NewErrors(errCode, userErr, sysErr) } - interfaces, err := tc.ToInterfacesV4(server.Interfaces, server.RouterHostName, server.RouterPortName) - if err != nil { - return http.StatusInternalServerError, nil, err + interfaces, e := tc.ToInterfacesV4(server.Interfaces, server.RouterHostName, server.RouterPortName) + if e != nil { + return api.NewSystemError(e) } - userErr, sysErr, errCode := createInterfaces(int(serverID), interfaces, inf.Tx.Tx) - if userErr != nil || sysErr != nil { - return errCode, userErr, sysErr + err = createInterfaces(int(serverID), interfaces, inf.Tx.Tx) + if err != nil { + return err } var origProfile string - err = inf.Tx.Tx.QueryRow("SELECT name from profile where id = $1", server.ProfileID).Scan(&origProfile) - if err != nil && err != sql.ErrNoRows { - return http.StatusInternalServerError, nil, fmt.Errorf("retreiving profile with id %d", *server.ProfileID) + e = inf.Tx.Tx.QueryRow("SELECT name from profile where id = $1", server.ProfileID).Scan(&origProfile) + if e != nil && !errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("retreiving profile with id %d", *server.ProfileID) } var origProfiles = []string{origProfile} - userErr, sysErr, statusCode := insertServerProfile(int(serverID), origProfiles, inf.Tx.Tx) - if userErr != nil || sysErr != nil { - return statusCode, userErr, sysErr + err = insertServerProfile(int(serverID), origProfiles, inf.Tx.Tx) + if err != nil { + return err } where := `WHERE s.id = $1` selquery := selectQuery + where var s4 tc.ServerV5 - err = inf.Tx.QueryRow(selquery, serverID).Scan( + e = inf.Tx.QueryRow(selquery, serverID).Scan( &s4.CacheGroup, &s4.CacheGroupID, &s4.CDNID, @@ -1753,43 +1745,40 @@ func createV3(inf *api.Info) (int, error, error) { &s4.XMPPPasswd, &s4.StatusLastUpdated, ) - if err != nil { - return http.StatusInternalServerError, nil, err + if e != nil { + return e } s4.Interfaces = interfaces downgraded := s4.Downgrade() - csp, err := dbhelpers.GetCommonServerPropertiesFromV4(downgraded, inf.Tx.Tx) - if err != nil { - return http.StatusInternalServerError, nil, err + csp, e := dbhelpers.GetCommonServerPropertiesFromV4(downgraded, inf.Tx.Tx) + if e != nil { + return e } - server, err = downgraded.ToServerV3FromV4(csp) - if err != nil { - return http.StatusInternalServerError, nil, err + server, e = downgraded.ToServerV3FromV4(csp) + if e != nil { + return e } inf.WriteCreatedResponse(server, "Server created", fmt.Sprintf("servers?id=%d", server.ID)) inf.CreateChangeLog(fmt.Sprintf("SERVER: %s.%s, ID: %d, ACTION: created", *server.HostName, *server.DomainName, *server.ID)) - return http.StatusCreated, nil, nil + return nil } -func createV5(inf *api.Info) (int, error, error) { +func createV5(inf *api.Info) error { var server tc.ServerV5 if err := inf.DecodeBody(&server); err != nil { - return http.StatusBadRequest, err, nil + return api.NewUserError(err) } server.XMPPID = newUUID() tmp := server.Downgrade() - _, userErr, sysErr := validateV4(&tmp, inf.Tx.Tx) - if userErr != nil || sysErr != nil { - if sysErr != nil { - return http.StatusInternalServerError, userErr, sysErr - } - return http.StatusBadRequest, userErr, sysErr + _, err := validateV4(&tmp, inf.Tx.Tx) + if err != nil { + return err } currentTime := time.Now() @@ -1798,35 +1787,35 @@ func createV5(inf *api.Info) (int, error, error) { if server.CDN != "" { userErr, sysErr, statusCode := dbhelpers.CheckIfCurrentUserCanModifyCDN(inf.Tx.Tx, server.CDN, inf.User.UserName) if userErr != nil || sysErr != nil { - return statusCode, userErr, sysErr + return api.NewErrors(statusCode, userErr, sysErr) } } else { userErr, sysErr, statusCode := dbhelpers.CheckIfCurrentUserCanModifyCDNWithID(inf.Tx.Tx, int64(server.CDNID), inf.User.UserName) if userErr != nil || sysErr != nil { - return statusCode, userErr, sysErr + return api.NewErrors(statusCode, userErr, sysErr) } } origProfiles := server.Profiles - serverID, err := createServerV5(inf.Tx, server) - if err != nil { - userErr, sysErr, errCode := api.ParseDBError(err) - return errCode, userErr, sysErr + serverID, e := createServerV5(inf.Tx, server) + if e != nil { + userErr, sysErr, errCode := api.ParseDBError(e) + return api.NewErrors(errCode, userErr, sysErr) } - userErr, sysErr, errCode := createInterfaces(int(serverID), server.Interfaces, inf.Tx.Tx) - if userErr != nil || sysErr != nil { - return errCode, userErr, sysErr + err = createInterfaces(int(serverID), server.Interfaces, inf.Tx.Tx) + if err != nil { + return err } - userErr, sysErr, statusCode := insertServerProfile(int(serverID), origProfiles, inf.Tx.Tx) - if userErr != nil || sysErr != nil { - return statusCode, userErr, sysErr + err = insertServerProfile(int(serverID), origProfiles, inf.Tx.Tx) + if err != nil { + return err } where := `WHERE s.id = $1` selquery := selectQuery + joinProfileV4 + where - err = inf.Tx.QueryRow(selquery, serverID).Scan( + e = inf.Tx.QueryRow(selquery, serverID).Scan( &server.CacheGroup, &server.CacheGroupID, &server.CDNID, @@ -1865,41 +1854,38 @@ func createV5(inf *api.Info) (int, error, error) { &server.XMPPPasswd, &server.StatusLastUpdated, ) - if err != nil { - return http.StatusInternalServerError, nil, err + if e != nil { + return e } - code, userErr, sysErr := inf.WriteCreatedResponse(server, "Server created", fmt.Sprintf("servers?id=%d", server.ID)) + inf.WriteCreatedResponse(server, "Server created", fmt.Sprintf("servers?id=%d", server.ID)) inf.CreateChangeLog(fmt.Sprintf("SERVER: %s.%s, ID: %d, ACTION: created", server.HostName, server.DomainName, server.ID)) - return code, userErr, sysErr + return nil } -func createV4(inf *api.Info) (int, error, error) { +func createV4(inf *api.Info) error { var server tc.ServerV40 if err := inf.DecodeBody(&server); err != nil { - return http.StatusBadRequest, err, nil + return api.NewUserError(err) } if server.ID != nil { var prevID int err := inf.Tx.Tx.QueryRow("SELECT id from server where id = $1", server.ID).Scan(&prevID) if err != nil && !errors.Is(err, sql.ErrNoRows) { - return http.StatusInternalServerError, nil, fmt.Errorf("checking if server with id %d exists", *server.ID) + return fmt.Errorf("checking if server with id %d exists", *server.ID) } if prevID != 0 { - return http.StatusBadRequest, fmt.Errorf("server with id %d already exists. Please do not provide an id", *server.ID), nil + return api.NewUserErrorf("server with id %d already exists. Please do not provide an id", *server.ID) } } server.XMPPID = newUUID() - _, userErr, sysErr := validateV4(&server, inf.Tx.Tx) - if userErr != nil || sysErr != nil { - if sysErr != nil { - return http.StatusInternalServerError, userErr, sysErr - } - return http.StatusBadRequest, userErr, sysErr + _, err := validateV4(&server, inf.Tx.Tx) + if err != nil { + return err } currentTime := time.Now() @@ -1908,36 +1894,36 @@ func createV4(inf *api.Info) (int, error, error) { if server.CDNName != nil { userErr, sysErr, statusCode := dbhelpers.CheckIfCurrentUserCanModifyCDN(inf.Tx.Tx, *server.CDNName, inf.User.UserName) if userErr != nil || sysErr != nil { - return statusCode, userErr, sysErr + return api.NewErrors(statusCode, userErr, sysErr) } } else if server.CDNID != nil { userErr, sysErr, statusCode := dbhelpers.CheckIfCurrentUserCanModifyCDNWithID(inf.Tx.Tx, int64(*server.CDNID), inf.User.UserName) if userErr != nil || sysErr != nil { - return statusCode, userErr, sysErr + return api.NewErrors(statusCode, userErr, sysErr) } } origProfiles := server.ProfileNames - serverID, err := createServerV4(inf.Tx, server) - if err != nil { - userErr, sysErr, errCode := api.ParseDBError(err) - return errCode, userErr, sysErr + serverID, e := createServerV4(inf.Tx, server) + if e != nil { + userErr, sysErr, errCode := api.ParseDBError(e) + return api.NewErrors(errCode, userErr, sysErr) } - userErr, sysErr, errCode := createInterfaces(int(serverID), server.Interfaces, inf.Tx.Tx) - if userErr != nil || sysErr != nil { - return errCode, userErr, sysErr + err = createInterfaces(int(serverID), server.Interfaces, inf.Tx.Tx) + if err != nil { + return err } - userErr, sysErr, statusCode := insertServerProfile(int(serverID), origProfiles, inf.Tx.Tx) - if userErr != nil || sysErr != nil { - return statusCode, userErr, sysErr + err = insertServerProfile(int(serverID), origProfiles, inf.Tx.Tx) + if err != nil { + return err } where := `WHERE s.id = $1` selquery := selectQuery + joinProfileV4 + where var srvr tc.ServerV5 - err = inf.Tx.QueryRow(selquery, serverID).Scan( + e = inf.Tx.QueryRow(selquery, serverID).Scan( &srvr.CacheGroup, &srvr.CacheGroupID, &srvr.CDNID, @@ -1976,16 +1962,16 @@ func createV4(inf *api.Info) (int, error, error) { &srvr.XMPPPasswd, &srvr.StatusLastUpdated, ) - if err != nil { - return http.StatusInternalServerError, nil, err + if e != nil { + return e } // TODO: Use returned values from SQL insert to ensure inserted values match srvr.Interfaces = server.Interfaces - code, userErr, sysErr := inf.WriteCreatedResponse(srvr.Downgrade(), "Server created", fmt.Sprintf("servers?id=%d", srvr.ID)) + inf.WriteCreatedResponse(srvr.Downgrade(), "Server created", fmt.Sprintf("servers?id=%d", srvr.ID)) inf.CreateChangeLog(fmt.Sprintf("SERVER: %s.%s, ID: %d, ACTION: created", srvr.HostName, srvr.DomainName, srvr.ID)) - return code, userErr, sysErr + return nil } func createServerV5(tx *sqlx.Tx, server tc.ServerV5) (int64, error) { @@ -2162,7 +2148,7 @@ func createServerV3(tx *sqlx.Tx, server tc.ServerV30) (int64, error) { } // Create is the handler for POST requests to /servers. -func Create(inf *api.Info) (int, error, error) { +func Create(inf *api.Info) error { switch inf.Version.Major { case 3: return createV3(inf) @@ -2226,69 +2212,68 @@ func getActiveDeliveryServicesThatOnlyHaveThisServerAssigned(id int, serverType } // Delete is the handler for DELETE requests to the /servers API endpoint. -func Delete(inf *api.Info) (int, error, error) { +func Delete(inf *api.Info) error { id := inf.IntParams["id"] tx := inf.Tx.Tx - serverInfo, exists, err := dbhelpers.GetServerInfo(id, tx) - if err != nil { - return http.StatusInternalServerError, nil, err + serverInfo, exists, e := dbhelpers.GetServerInfo(id, tx) + if e != nil { + return e } if !exists { - return http.StatusNotFound, fmt.Errorf("no server exists by id #%d", id), nil + return api.NewNotFoundError("no server exists by id #%d", id) } if dsIDs, err := getActiveDeliveryServicesThatOnlyHaveThisServerAssigned(id, serverInfo.Type, tx); err != nil { - return http.StatusInternalServerError, nil, fmt.Errorf("checking if server #%d is the last server assigned to any Delivery Services: %w", id, err) + return api.NewSystemErrorf("checking if server #%d is the last server assigned to any Delivery Services: %w", id, err) } else if len(dsIDs) > 0 { - return http.StatusConflict, fmt.Errorf("deleting server #%d would leave Active Delivery Service", id), nil + return api.NewErrors(http.StatusConflict, fmt.Errorf("deleting server #%d would leave Active Delivery Service", id), nil) } - servers, _, userErr, sysErr, errCode, _ := getServers(inf.RequestHeaders(), map[string]string{"id": inf.Params["id"]}, inf.Tx, inf.User, false, *inf.Version, inf.Config.RoleBasedPermissions) - if userErr != nil || sysErr != nil { - return errCode, userErr, sysErr + servers, _, _, err := getServers(inf.RequestHeaders(), map[string]string{"id": inf.Params["id"]}, inf.Tx, inf.User, false, *inf.Version, inf.Config.RoleBasedPermissions) + if err != nil { + return err } if len(servers) < 1 { - return http.StatusNotFound, fmt.Errorf("no server exists by id #%d", id), nil + return api.NewNotFoundError("no server exists by id #%d", id) } if len(servers) > 1 { - return http.StatusInternalServerError, nil, fmt.Errorf("there are somehow two servers with id %d - cannot delete", id) + return api.NewSystemErrorf("there are somehow two servers with id %d - cannot delete", id) } server := servers[0] if server.CDN != "" { - userErr, sysErr, errCode = dbhelpers.CheckIfCurrentUserCanModifyCDN(tx, server.CDN, inf.User.UserName) + userErr, sysErr, errCode := dbhelpers.CheckIfCurrentUserCanModifyCDN(tx, server.CDN, inf.User.UserName) if userErr != nil || sysErr != nil { - return errCode, userErr, sysErr + return api.NewErrors(errCode, userErr, sysErr) } } else { // when would this happen? - userErr, sysErr, errCode = dbhelpers.CheckIfCurrentUserCanModifyCDNWithID(tx, int64(server.CDNID), inf.User.UserName) + userErr, sysErr, errCode := dbhelpers.CheckIfCurrentUserCanModifyCDNWithID(tx, int64(server.CDNID), inf.User.UserName) if userErr != nil || sysErr != nil { - return errCode, userErr, sysErr + return api.NewErrors(errCode, userErr, sysErr) } } cacheGroupIds := []int{server.CacheGroupID} serverIds := []int{server.ID} - hasDSOnCDN, err := dbhelpers.CachegroupHasTopologyBasedDeliveryServicesOnCDN(tx, server.CacheGroupID, server.CDNID) - if err != nil { - return http.StatusInternalServerError, nil, err + hasDSOnCDN, e := dbhelpers.CachegroupHasTopologyBasedDeliveryServicesOnCDN(tx, server.CacheGroupID, server.CDNID) + if e != nil { + return e } CDNIDs := []int{} if hasDSOnCDN { CDNIDs = append(CDNIDs, server.CDNID) } if err := topology_validation.CheckForEmptyCacheGroups(inf.Tx, cacheGroupIds, CDNIDs, true, serverIds); err != nil { - return http.StatusBadRequest, fmt.Errorf("server is the last one in its cachegroup, which is used by a topology: %w", err), nil + return api.NewUserErrorf("server is the last one in its cachegroup, which is used by a topology: %w", err) } if result, err := tx.Exec(deleteServerQuery, id); err != nil { log.Errorf("Raw error: %v", err) - userErr, sysErr, errCode = api.ParseDBError(err) - return errCode, userErr, sysErr + return inf.HandleDBError(err) } else if rowsAffected, err := result.RowsAffected(); err != nil { - return http.StatusInternalServerError, nil, fmt.Errorf("getting rows affected by server delete: %w", err) + return fmt.Errorf("getting rows affected by server delete: %w", err) } else if rowsAffected != 1 { - return http.StatusInternalServerError, nil, fmt.Errorf("incorrect number of rows affected: %d", rowsAffected) + return fmt.Errorf("incorrect number of rows affected: %d", rowsAffected) } inf.CreateChangeLog(fmt.Sprintf("SERVER: %s.%s, ID: %d, ACTION: deleted", server.HostName, server.DomainName, server.ID)) @@ -2301,15 +2286,14 @@ func Delete(inf *api.Info) (int, error, error) { return inf.WriteSuccessResponse(downgraded, "Server deleted") } - csp, err := dbhelpers.GetCommonServerPropertiesFromV4(downgraded, tx) - if err != nil { - userErr, sysErr, errCode := api.ParseDBError(err) - return errCode, userErr, sysErr + csp, e := dbhelpers.GetCommonServerPropertiesFromV4(downgraded, tx) + if e != nil { + return inf.HandleDBError(e) } - serverv3, err := downgraded.ToServerV3FromV4(csp) - if err != nil { - return http.StatusInternalServerError, nil, err + serverv3, e := downgraded.ToServerV3FromV4(csp) + if e != nil { + return e } return inf.WriteSuccessResponse(serverv3, "Server deleted") } diff --git a/traffic_ops/traffic_ops_golang/server/servers_test.go b/traffic_ops/traffic_ops_golang/server/servers_test.go index 8b01eb71cd..8b6baa6ebe 100644 --- a/traffic_ops/traffic_ops_golang/server/servers_test.go +++ b/traffic_ops/traffic_ops_golang/server/servers_test.go @@ -148,13 +148,16 @@ func TestCheckTypeChangeSafety(t *testing.T) { ID: testServers[0].Server.ID, } - userErr, _, errCode := checkTypeChangeSafety(s, db.MustBegin()) - if errCode != 409 { - t.Errorf("Update servers: Expected error code of %v, but got %v", 409, errCode) + e := checkTypeChangeSafety(s, db.MustBegin()) + if e == nil { + t.Fatalf("Expected an error to occur") + } + if e.Code() != http.StatusConflict { + t.Errorf("Update servers: Expected error code of %d, but got %d", http.StatusConflict, e.Code()) } expectedErr := "server cdn can not be updated when it is currently assigned to delivery services" - if userErr == nil { - t.Errorf("Update expected error: %v, but got no error with status: %s", expectedErr, http.StatusText(errCode)) + if e.UserError() == nil { + t.Errorf("Update expected error: %s, but got no error with status: %s", expectedErr, http.StatusText(e.Code())) } } @@ -265,13 +268,13 @@ func TestGetServersByCachegroup(t *testing.T) { version := api.Version{Major: 4, Minor: 0} - servers, _, userErr, sysErr, errCode, _ := getServers(nil, v, db.MustBegin(), &user, false, version, false) - if userErr != nil || sysErr != nil { - t.Errorf("getServers expected: no errors, actual: %v %v with status: %s", userErr, sysErr, http.StatusText(errCode)) + servers, _, _, err := getServers(nil, v, db.MustBegin(), &user, false, version, false) + if err != nil { + t.Errorf("getServers expected: no errors, actual: %+v", err) } if len(servers) != 3 { - t.Errorf("getServers expected: len(servers) == 3, actual: %v", len(servers)) + t.Errorf("getServers expected: len(servers) == 3, actual: %d", len(servers)) } } @@ -381,10 +384,10 @@ func TestGetMidServers(t *testing.T) { user := auth.CurrentUser{} version := api.Version{Major: 4, Minor: 0} - servers, _, userErr, sysErr, errCode, _ := getServers(nil, v, db.MustBegin(), &user, false, version, false) + servers, _, _, err := getServers(nil, v, db.MustBegin(), &user, false, version, false) - if userErr != nil || sysErr != nil { - t.Errorf("getServers expected: no errors, actual: %v %v with status: %s", userErr, sysErr, http.StatusText(errCode)) + if err != nil { + t.Errorf("getServers expected: no errors, actual: %+v", err) } cols2 := []string{"cachegroup", "cachegroup_id", "cdn_id", "cdn_name", "domain_name", "guid", "host_name", @@ -496,10 +499,10 @@ func TestGetMidServers(t *testing.T) { mock.ExpectBegin() mock.ExpectQuery("SELECT").WillReturnRows(rows2) - mid, userErr, sysErr, errCode := getMidServers(serverIDs, serverMap, 0, 0, db.MustBegin(), false) + mid, err := getMidServers(serverIDs, serverMap, 0, 0, db.MustBegin(), false) - if userErr != nil || sysErr != nil { - t.Fatalf("getMidServers expected: no errors, actual: %v %v with status: %s", userErr, sysErr, http.StatusText(errCode)) + if err != nil { + t.Fatalf("getMidServers expected: no errors, actual: %+v", err) } if len(mid) != 1 { t.Fatalf("getMidServers expected: len(mid) == 1, actual: %v", len(mid)) @@ -570,9 +573,9 @@ func TestV3Validations(t *testing.T) { tx := db.MustBegin().Tx - _, err, _ = validateV3(&testServer, tx) + _, err = validateV3(&testServer, tx) if err != nil { - t.Errorf("Unexpected error validating test server: %v", err) + t.Errorf("Unexpected error validating test server: %#v", err) } testServer.Interfaces = []tc.ServerInterfaceInfo{} @@ -582,11 +585,11 @@ func TestV3Validations(t *testing.T) { mock.ExpectQuery("SELECT name, use_in_table").WillReturnRows(typeRows) mock.ExpectQuery("SELECT").WillReturnRows(cdnRows) - _, err, _ = validateV3(&testServer, tx) + _, err = validateV3(&testServer, tx) if err == nil { t.Errorf("Expected a server with no interfaces to be invalid") } else { - t.Logf("Got expected error validating server with no interfaces: %v", err) + t.Logf("Got expected error validating server with no interfaces: %+v", err) } testServer.Interfaces = nil @@ -596,11 +599,11 @@ func TestV3Validations(t *testing.T) { mock.ExpectQuery("SELECT name, use_in_table").WillReturnRows(typeRows) mock.ExpectQuery("SELECT").WillReturnRows(cdnRows) - _, err, _ = validateV3(&testServer, tx) + _, err = validateV3(&testServer, tx) if err == nil { t.Errorf("Expected a server with nil interfaces to be invalid") } else { - t.Logf("Got expected error validating server with nil interfaces: %v", err) + t.Logf("Got expected error validating server with nil interfaces: %+v", err) } badIface := goodInterface @@ -613,11 +616,11 @@ func TestV3Validations(t *testing.T) { mock.ExpectQuery("SELECT name, use_in_table").WillReturnRows(typeRows) mock.ExpectQuery("SELECT").WillReturnRows(cdnRows) - _, err, _ = validateV3(&testServer, tx) + _, err = validateV3(&testServer, tx) if err == nil { t.Errorf("Expected a server an MTU < 1280 to be invalid") } else { - t.Logf("Got expected error validating server with an MTU < 1280: %v", err) + t.Logf("Got expected error validating server with an MTU < 1280: %+v", err) } badIface.MTU = nil @@ -629,11 +632,11 @@ func TestV3Validations(t *testing.T) { mock.ExpectQuery("SELECT name, use_in_table").WillReturnRows(typeRows) mock.ExpectQuery("SELECT").WillReturnRows(cdnRows) - _, err, _ = validateV3(&testServer, tx) + _, err = validateV3(&testServer, tx) if err == nil { t.Errorf("Expected a server with no IP addresses to be invalid") } else { - t.Logf("Got expected error validating server with no IP addresses: %v", err) + t.Logf("Got expected error validating server with no IP addresses: %+v", err) } badIface.IPAddresses = nil @@ -644,11 +647,11 @@ func TestV3Validations(t *testing.T) { mock.ExpectQuery("SELECT name, use_in_table").WillReturnRows(typeRows) mock.ExpectQuery("SELECT").WillReturnRows(cdnRows) - _, err, _ = validateV3(&testServer, tx) + _, err = validateV3(&testServer, tx) if err == nil { t.Errorf("Expected a server with nil IP addresses to be invalid") } else { - t.Logf("Got expected error validating server with nil IP addresses: %v", err) + t.Logf("Got expected error validating server with nil IP addresses: %+v", err) } badIface = goodInterface @@ -665,11 +668,11 @@ func TestV3Validations(t *testing.T) { mock.ExpectQuery("SELECT name, use_in_table").WillReturnRows(typeRows) mock.ExpectQuery("SELECT").WillReturnRows(cdnRows) - _, err, _ = validateV3(&testServer, tx) + _, err = validateV3(&testServer, tx) if err == nil { t.Errorf("Expected a server with no service addresses to be invalid") } else { - t.Logf("Got expected error validating server with no service addresses: %v", err) + t.Logf("Got expected error validating server with no service addresses: %+v", err) } testServer.Interfaces = []tc.ServerInterfaceInfo{goodInterface, goodInterface} @@ -679,11 +682,11 @@ func TestV3Validations(t *testing.T) { mock.ExpectQuery("SELECT name, use_in_table").WillReturnRows(typeRows) mock.ExpectQuery("SELECT").WillReturnRows(cdnRows) - _, err, _ = validateV3(&testServer, tx) + _, err = validateV3(&testServer, tx) if err == nil { t.Errorf("Expected a server with too many interfaces with service addresses to be invalid") } else { - t.Logf("Got expected error validating server with too many interfaces with service addresses: %v", err) + t.Logf("Got expected error validating server with too many interfaces with service addresses: %+v", err) } badIface = goodInterface @@ -699,11 +702,11 @@ func TestV3Validations(t *testing.T) { mock.ExpectQuery("SELECT name, use_in_table").WillReturnRows(typeRows) mock.ExpectQuery("SELECT").WillReturnRows(cdnRows) - _, err, _ = validateV3(&testServer, tx) + _, err = validateV3(&testServer, tx) if err == nil { t.Errorf("Expected a server with no service addresses to be invalid") } else { - t.Logf("Got expected error validating server with no service addresses: %v", err) + t.Logf("Got expected error validating server with no service addresses: %+v", err) } } @@ -734,12 +737,9 @@ func TestUpdateStatusLastUpdatedTime(t *testing.T) { mock.ExpectExec("UPDATE").WithArgs(lastUpdated, 1).WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectCommit() - sysErr, _, code := updateStatusLastUpdatedTime(1, &lastUpdated, db.MustBegin().Tx) - if sysErr != nil { - t.Errorf("unable to update time, system error: %v", sysErr) - } - if code != http.StatusOK { - t.Errorf("updated time failed with status code:%d", code) + err = updateStatusLastUpdatedTime(1, &lastUpdated, db.MustBegin().Tx) + if err != nil { + t.Errorf("unable to update time, system error: %+v", err) } } @@ -767,15 +767,9 @@ func TestCreateInterfaces(t *testing.T) { WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectCommit() - usrErr, sysErr, code := createInterfaces(1, iface, db.MustBegin().Tx) - if usrErr != nil { - t.Errorf("unable to create interface, user error: %v", usrErr) - } - if sysErr != nil { - t.Errorf("unable to create interface, system error: %v", sysErr) - } - if code != http.StatusOK { - t.Errorf("unable to create interface, failed with status code:%d", code) + err = createInterfaces(1, iface, db.MustBegin().Tx) + if err != nil { + t.Errorf("unable to create interface, error: %+v", err) } } @@ -795,15 +789,9 @@ func TestDeleteInterfaces(t *testing.T) { mock.ExpectExec("DELETE FROM interface").WithArgs(1).WillReturnResult(sqlmock.NewResult(1, 1)) mock.ExpectCommit() - usrErr, sysErr, code := deleteInterfaces(1, db.MustBegin().Tx) - if usrErr != nil { - t.Errorf("unable to delete interface, user error: %v", usrErr) - } - if sysErr != nil { - t.Errorf("unable to delete interface, system error: %v", sysErr) - } - if code != http.StatusOK { - t.Errorf("unable to delete interface, failed with status code:%d", code) + err = deleteInterfaces(1, db.MustBegin().Tx) + if err != nil { + t.Errorf("unable to delete interface, error: %+v", err) } } @@ -823,15 +811,9 @@ func TestInsertServerProfile(t *testing.T) { mock.ExpectExec("INSERT INTO").WithArgs(1, pq.Array(profileName), pq.Array(priority)).WillReturnResult(sqlmock.NewResult(2, 2)) mock.ExpectCommit() - usrErr, sysErr, code := insertServerProfile(1, profileName, db.MustBegin().Tx) - if usrErr != nil { - t.Errorf("unable to insert profile, user error: %v", usrErr) - } - if sysErr != nil { - t.Errorf("unable to insert profile, system error: %v", sysErr) - } - if code != http.StatusOK { - t.Errorf("unable to insert profile, failed with status code:%d", code) + err = insertServerProfile(1, profileName, db.MustBegin().Tx) + if err != nil { + t.Errorf("unable to insert profile, error: %+v", err) } } @@ -1088,17 +1070,11 @@ func TestUpdateServer(t *testing.T) { WillReturnRows(rows) mock.ExpectCommit() - sid, code, usrErr, sysErr := updateServer(db.MustBegin(), server) - if usrErr != nil { - t.Errorf("unable to update v4 server, user error: %v", usrErr) - } - if sysErr != nil { - t.Errorf("unable to update v4 server, system error: %v", sysErr) + sid, err := updateServer(db.MustBegin(), server) + if err != nil { + t.Errorf("unable to update v4 server, error: %+v", err) } if sid != int64(server.ID) { t.Errorf("updated incorrect server, expected: %d, got: %d", server.ID, sid) } - if code != http.StatusOK { - t.Errorf("failed to update server with id: %d, expected: %d, got: %d", server.ID, http.StatusOK, code) - } }