diff --git a/internal/validation/suite.go b/internal/validation/suite.go index 3122695e..e158ad52 100644 --- a/internal/validation/suite.go +++ b/internal/validation/suite.go @@ -246,7 +246,7 @@ type KubernetesValidationNodeGroupOpts struct { MinNodeCount int MaxNodeCount int InstanceType string - DiskSizeGiB int + DiskSize v1.Bytes } type KubernetesValidationNetworkOpts struct { @@ -420,7 +420,7 @@ func RunKubernetesValidation(t *testing.T, config ProviderConfig, opts Kubernete MinNodeCount: opts.NodeGroupOpts.MinNodeCount, MaxNodeCount: opts.NodeGroupOpts.MaxNodeCount, InstanceType: opts.NodeGroupOpts.InstanceType, - DiskSizeGiB: opts.NodeGroupOpts.DiskSizeGiB, + DiskSize: opts.NodeGroupOpts.DiskSize, Tags: opts.Tags, }) require.NoError(t, err, "ValidateCreateKubernetesNodeGroup should pass") diff --git a/v1/bytes.go b/v1/bytes.go index f1a1895e..0c06595a 100644 --- a/v1/bytes.go +++ b/v1/bytes.go @@ -1,21 +1,29 @@ package v1 -var zeroBytes = Bytes{value: 0, unit: Byte} +import ( + "encoding/json" + "fmt" + "math" + "math/big" -// NewBytes creates a new Bytes with the given value and unit + "github.com/brevdev/cloud/internal/errors" +) + +var ( + zeroBytes = Bytes{value: 0, unit: Byte} + + ErrBytesInvalidUnit = errors.New("invalid unit") + ErrBytesNotAnInt64 = errors.New("byte count is not an int64") + ErrBytesNotAnInt32 = errors.New("byte count is not an int32") +) + +// NewBytes creates a new Bytes with the given value and unit. func NewBytes(value BytesValue, unit BytesUnit) Bytes { - if value < 0 { - return zeroBytes - } - return Bytes{ - value: value, - unit: unit, - } + return Bytes{value: value, unit: unit} } type ( BytesValue int64 - BytesUnit string ) // Bytes represents a number of some unit of bytes @@ -24,6 +32,13 @@ type Bytes struct { unit BytesUnit } +// bytesJSON is the JSON representation of a Bytes. 
This struct is maintained separately from the core Bytes +// struct to allow for unexported fields to be used in the MarshalJSON and UnmarshalJSON methods. +type bytesJSON struct { + Value int64 `json:"value"` + Unit string `json:"unit"` +} + // Value is the whole non-negative number of bytes of the specified unit func (b Bytes) Value() BytesValue { return b.value @@ -34,23 +49,165 @@ func (b Bytes) Unit() BytesUnit { return b.unit } -// ByteUnit is a unit of measurement for bytes -const ( - Byte BytesUnit = "B" +// ByteCount is the total number of bytes in the Bytes +func (b Bytes) ByteCount() *big.Int { + bytesByteCount := big.NewInt(0).SetInt64(int64(b.value)) + unitByteCount := big.NewInt(0).SetUint64(b.unit.byteCount) + + return big.NewInt(0).Mul(bytesByteCount, unitByteCount) +} + +// ByteCountInUnit is the number of bytes in the Bytes of the given unit. For example, if +// the Bytes is 1000 MB, then: +// +// 1000 MB -> B = 1000000 +// 1000 MB -> KB = 1000 +// 1000 MB -> MB = 1 +// 1000 MB -> GB = .001 +// +// etc. +func (b Bytes) ByteCountInUnit(unit BytesUnit) *big.Float { + if b.unit == unit { + // If the units are the same, return the value as a float + return big.NewFloat(0).SetInt64(int64(b.value)) + } + + bytesByteCount := big.NewFloat(0).SetInt(b.ByteCount()) + unitByteCount := big.NewFloat(0).SetUint64(unit.byteCount) + + return big.NewFloat(0).Quo(bytesByteCount, unitByteCount) +} + +// ByteCountInUnitInt64 attempts to convert the result of ByteCountInUnit to an int64. If this conversion would +// result in an overflow, it returns an ErrBytesNotAnInt64 error. If the byte count is not an integer, the value +// is truncated towards zero. 
+func (b Bytes) ByteCountInUnitInt64(unit BytesUnit) (int64, error) { + byteCount := b.ByteCountInUnit(unit) + + byteCountInt64, accuracy := byteCount.Int64() + if byteCountInt64 == math.MaxInt64 && accuracy == big.Below { + return 0, errors.WrapAndTrace(errors.Join(ErrBytesNotAnInt64, fmt.Errorf("byte count %v is greater than %d", byteCount, math.MaxInt64))) + } + if byteCountInt64 == math.MinInt64 && accuracy == big.Above { + return 0, errors.WrapAndTrace(errors.Join(ErrBytesNotAnInt64, fmt.Errorf("byte count %v is less than %d", byteCount, math.MinInt64))) + } + return byteCountInt64, nil +} + +// ByteCountInUnitInt32 attempts to convert the result of ByteCountInUnit to an int32. If this conversion would +// result in an overflow, it returns an ErrBytesNotAnInt32 error. 
+func (b Bytes) ByteCountInUnitInt32(unit BytesUnit) (int32, error) { + byteCountInt64, err := b.ByteCountInUnitInt64(unit) + if err != nil { + return 0, errors.WrapAndTrace(err) + } + if byteCountInt64 > math.MaxInt32 || byteCountInt64 < math.MinInt32 { + return 0, errors.WrapAndTrace(errors.Join(ErrBytesNotAnInt32, fmt.Errorf("byte count %v does not fit in an int32", byteCountInt64))) + } + return int32(byteCountInt64), nil //nolint:gosec // checked above +} + +// String returns the string representation of the Bytes +func (b Bytes) String() string { + return fmt.Sprintf("%d %s", b.value, b.unit) +} + +// MarshalJSON implements the json.Marshaler interface +func (b Bytes) MarshalJSON() ([]byte, error) { + return json.Marshal(bytesJSON{ + Value: int64(b.value), + Unit: b.unit.name, + }) +} + +// UnmarshalJSON implements the json.Unmarshaler interface +func (b *Bytes) UnmarshalJSON(data []byte) error { + var bytesJSON bytesJSON + if err := json.Unmarshal(data, &bytesJSON); err != nil { + return errors.WrapAndTrace(err) + } + + unit, err := stringToBytesUnit(bytesJSON.Unit) + if err != nil { + return errors.WrapAndTrace(err) + } + + newBytes := NewBytes(BytesValue(bytesJSON.Value), unit) + *b = newBytes + return nil +} + 
+// LessThan returns true if the Bytes is less than the other Bytes +func (b Bytes) LessThan(other Bytes) bool { + return b.ByteCount().Cmp(other.ByteCount()) < 0 +} + +// GreaterThan returns true if the Bytes is greater than the other Bytes +func (b Bytes) GreaterThan(other Bytes) bool { + return b.ByteCount().Cmp(other.ByteCount()) > 0 +} + +// Equal returns true if the Bytes is equal to the other Bytes +func (b Bytes) Equal(other Bytes) bool { + return b.ByteCount().Cmp(other.ByteCount()) == 0 +} + +// BytesUnit is a unit of measurement for bytes. Note for maintainers: this is defined as a struct rather than a +// type alias to ensure stronger compile-time type checking and to avoid the need for a validation function. +type BytesUnit struct { + name string + byteCount uint64 +} + +// String returns the string representation of the BytesUnit +func (u BytesUnit) String() string { + return u.name +} + +var ( + Byte = BytesUnit{name: "B", byteCount: 1} // Base 10 - Kilobyte BytesUnit = "KB" - Megabyte BytesUnit = "MB" - Gigabyte BytesUnit = "GB" - Terabyte BytesUnit = "TB" - Petabyte BytesUnit = "PB" - Exabyte BytesUnit = "EB" + Kilobyte = BytesUnit{name: "KB", byteCount: 1000} + Megabyte = BytesUnit{name: "MB", byteCount: 1000 * 1000} + Gigabyte = BytesUnit{name: "GB", byteCount: 1000 * 1000 * 1000} + Terabyte = BytesUnit{name: "TB", byteCount: 1000 * 1000 * 1000 * 1000} + Petabyte = BytesUnit{name: "PB", byteCount: 1000 * 1000 * 1000 * 1000 * 1000} + Exabyte = BytesUnit{name: "EB", byteCount: 1000 * 1000 * 1000 * 1000 * 1000 * 1000} // Base 2 - Kibibyte BytesUnit = "KiB" - Mebibyte BytesUnit = "MiB" - Gibibyte BytesUnit = "GiB" - Tebibyte BytesUnit = "TiB" - Pebibyte BytesUnit = "PiB" - Exbibyte BytesUnit = "EiB" + Kibibyte = BytesUnit{name: "KiB", byteCount: 1024} + Mebibyte = BytesUnit{name: "MiB", byteCount: 1024 * 1024} + Gibibyte = BytesUnit{name: "GiB", byteCount: 1024 * 1024 * 1024} + Tebibyte = BytesUnit{name: "TiB", byteCount: 1024 * 1024 * 1024 * 1024} + 
Pebibyte = BytesUnit{name: "PiB", byteCount: 1024 * 1024 * 1024 * 1024 * 1024} + Exbibyte = BytesUnit{name: "EiB", byteCount: 1024 * 1024 * 1024 * 1024 * 1024 * 1024} ) + +func stringToBytesUnit(unit string) (BytesUnit, error) { + switch unit { + case Byte.name: + return Byte, nil + case Kilobyte.name: + return Kilobyte, nil + case Megabyte.name: + return Megabyte, nil + case Gigabyte.name: + return Gigabyte, nil + case Terabyte.name: + return Terabyte, nil + case Petabyte.name: + return Petabyte, nil + case Kibibyte.name: + return Kibibyte, nil + case Mebibyte.name: + return Mebibyte, nil + case Gibibyte.name: + return Gibibyte, nil + case Tebibyte.name: return Tebibyte, nil + case Pebibyte.name: return Pebibyte, nil + case Exabyte.name: return Exabyte, nil + case Exbibyte.name: return Exbibyte, nil + } + return BytesUnit{}, errors.WrapAndTrace(errors.Join(ErrBytesInvalidUnit, fmt.Errorf("invalid unit: %s", unit))) +} diff --git a/v1/bytes_test.go b/v1/bytes_test.go new file mode 100644 index 00000000..eec238f6 --- /dev/null +++ b/v1/bytes_test.go @@ -0,0 +1,449 @@ +package v1 + +import ( + "encoding/json" + "errors" + "math/big" + "testing" +) + +func TestNewBytes(t *testing.T) { + tests := []struct { + name string + value BytesValue + unit BytesUnit + want Bytes + wantErr error + }{ + {name: "1000 B", value: 1000, unit: Byte, want: NewBytes(1000, Byte), wantErr: nil}, + {name: "1000 KB", value: 1000, unit: Kilobyte, want: NewBytes(1000, Kilobyte), wantErr: nil}, + {name: "1000 MB", value: 1000, unit: Megabyte, want: NewBytes(1000, Megabyte), wantErr: nil}, + {name: "1000 GB", value: 1000, unit: Gigabyte, want: NewBytes(1000, Gigabyte), wantErr: nil}, + {name: "1000 TB", value: 1000, unit: Terabyte, want: NewBytes(1000, Terabyte), wantErr: nil}, + {name: "1000 PB", value: 1000, unit: Petabyte, want: NewBytes(1000, Petabyte), wantErr: nil}, + {name: "1000 KiB", value: 1000, unit: Kibibyte, want: NewBytes(1000, Kibibyte), wantErr: nil}, + {name: "1000 MiB", value: 1000, unit: Mebibyte, want: NewBytes(1000, Mebibyte), wantErr: nil}, + {name: 
"1000 GiB", value: 1000, unit: Gibibyte, want: NewBytes(1000, Gibibyte), wantErr: nil}, + {name: "1000 TiB", value: 1000, unit: Tebibyte, want: NewBytes(1000, Tebibyte), wantErr: nil}, + {name: "1000 PiB", value: 1000, unit: Pebibyte, want: NewBytes(1000, Pebibyte), wantErr: nil}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + bytes := NewBytes(test.value, test.unit) + if !bytes.Equal(test.want) { + t.Errorf("NewBytes() = %v, want %v", bytes, test.want) + } + }) + } +} + +func TestBytesMarshalJSON(t *testing.T) { + tests := []struct { + name string + bytes Bytes + want string + wantErr error + }{ + {name: "1000 B", bytes: NewBytes(1000, Byte), want: `{"value":1000,"unit":"B"}`, wantErr: nil}, + {name: "1000 KB", bytes: NewBytes(1000, Kilobyte), want: `{"value":1000,"unit":"KB"}`, wantErr: nil}, + {name: "1000 MB", bytes: NewBytes(1000, Megabyte), want: `{"value":1000,"unit":"MB"}`, wantErr: nil}, + {name: "1000 GB", bytes: NewBytes(1000, Gigabyte), want: `{"value":1000,"unit":"GB"}`, wantErr: nil}, + {name: "1000 TB", bytes: NewBytes(1000, Terabyte), want: `{"value":1000,"unit":"TB"}`, wantErr: nil}, + {name: "1000 PB", bytes: NewBytes(1000, Petabyte), want: `{"value":1000,"unit":"PB"}`, wantErr: nil}, + {name: "1000 KiB", bytes: NewBytes(1000, Kibibyte), want: `{"value":1000,"unit":"KiB"}`, wantErr: nil}, + {name: "1000 MiB", bytes: NewBytes(1000, Mebibyte), want: `{"value":1000,"unit":"MiB"}`, wantErr: nil}, + {name: "1000 GiB", bytes: NewBytes(1000, Gibibyte), want: `{"value":1000,"unit":"GiB"}`, wantErr: nil}, + {name: "1000 TiB", bytes: NewBytes(1000, Tebibyte), want: `{"value":1000,"unit":"TiB"}`, wantErr: nil}, + {name: "1000 PiB", bytes: NewBytes(1000, Pebibyte), want: `{"value":1000,"unit":"PiB"}`, wantErr: nil}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + json, err := json.Marshal(test.bytes) + if test.wantErr != nil { + if err == nil { + t.Fatalf("json.Marshal() error = nil, want %v", 
test.wantErr) + } + if !errors.Is(err, test.wantErr) { + t.Fatalf("json.Marshal() error = %v, want %v", err, test.wantErr) + } + } else if string(json) != test.want { + t.Errorf("json.Marshal() = %v, want %v", string(json), test.want) + } + }) + } +} + +func TestBytesUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + json string + want Bytes + wantErr error + }{ + {name: "1000 B", json: `{"value":1000,"unit":"B"}`, want: NewBytes(1000, Byte), wantErr: nil}, + {name: "1000 KB", json: `{"value":1000,"unit":"KB"}`, want: NewBytes(1000, Kilobyte), wantErr: nil}, + {name: "1000 MB", json: `{"value":1000,"unit":"MB"}`, want: NewBytes(1000, Megabyte), wantErr: nil}, + {name: "1000 GB", json: `{"value":1000,"unit":"GB"}`, want: NewBytes(1000, Gigabyte), wantErr: nil}, + {name: "1000 TB", json: `{"value":1000,"unit":"TB"}`, want: NewBytes(1000, Terabyte), wantErr: nil}, + {name: "1000 PB", json: `{"value":1000,"unit":"PB"}`, want: NewBytes(1000, Petabyte), wantErr: nil}, + {name: "1000 KiB", json: `{"value":1000,"unit":"KiB"}`, want: NewBytes(1000, Kibibyte), wantErr: nil}, + {name: "1000 MiB", json: `{"value":1000,"unit":"MiB"}`, want: NewBytes(1000, Mebibyte), wantErr: nil}, + {name: "1000 GiB", json: `{"value":1000,"unit":"GiB"}`, want: NewBytes(1000, Gibibyte), wantErr: nil}, + {name: "1000 TiB", json: `{"value":1000,"unit":"TiB"}`, want: NewBytes(1000, Tebibyte), wantErr: nil}, + {name: "1000 PiB", json: `{"value":1000,"unit":"PiB"}`, want: NewBytes(1000, Pebibyte), wantErr: nil}, + + {name: "Empty unit", json: `{"value":1000,"unit":""}`, want: zeroBytes, wantErr: ErrBytesInvalidUnit}, + {name: "Invalid unit", json: `{"value":1000,"unit":"invalid"}`, want: zeroBytes, wantErr: ErrBytesInvalidUnit}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var bytes Bytes + err := json.Unmarshal([]byte(test.json), &bytes) + if test.wantErr != nil { + if err == nil { + t.Fatalf("json.Unmarshal() error = nil, want %v", test.wantErr) + 
} + if !errors.Is(err, test.wantErr) { + t.Fatalf("json.Unmarshal() error = %v, want %v", err, test.wantErr) + } + } else if !bytes.Equal(test.want) { + t.Errorf("json.Unmarshal() = %v, want %v", bytes, test.want) + } + }) + } +} + +func TestBytesString(t *testing.T) { + tests := []struct { + name string + bytes Bytes + want string + }{ + {name: "Byte", bytes: NewBytes(1000, Byte), want: "1000 B"}, + {name: "Kilobyte", bytes: NewBytes(1000, Kilobyte), want: "1000 KB"}, + {name: "Megabyte", bytes: NewBytes(1000, Megabyte), want: "1000 MB"}, + {name: "Gigabyte", bytes: NewBytes(1000, Gigabyte), want: "1000 GB"}, + {name: "Terabyte", bytes: NewBytes(1000, Terabyte), want: "1000 TB"}, + {name: "Petabyte", bytes: NewBytes(1000, Petabyte), want: "1000 PB"}, + {name: "Kibibyte", bytes: NewBytes(1000, Kibibyte), want: "1000 KiB"}, + {name: "Mebibyte", bytes: NewBytes(1000, Mebibyte), want: "1000 MiB"}, + {name: "Gibibyte", bytes: NewBytes(1000, Gibibyte), want: "1000 GiB"}, + {name: "Tebibyte", bytes: NewBytes(1000, Tebibyte), want: "1000 TiB"}, + {name: "Pebibyte", bytes: NewBytes(1000, Pebibyte), want: "1000 PiB"}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got := test.bytes.String() + if got != test.want { + t.Errorf("Bytes.String() = %v, want %v", got, test.want) + } + }) + } +} + +func TestBytesGetters(t *testing.T) { + // Test that Value() and Unit() work correctly and that Unit() can be used + b := NewBytes(1024, Gigabyte) + + if b.Value() != 1024 { + t.Errorf("Value() = %v, want 1024", b.Value()) + } + + if b.Unit() != Gigabyte { + t.Errorf("Unit() = %v, want Gigabyte", b.Unit()) + } + + // Test that we can get the string value from Unit() + if b.Unit().String() != "GB" { + t.Errorf("Unit().String() = %v, want GB", b.Unit().String()) + } +} + +func TestBytesEqual(t *testing.T) { + tests := []struct { + name string + bytes Bytes + other Bytes + want bool + }{ + {name: "1000 B == 1000 B", bytes: NewBytes(1000, Byte), other: 
NewBytes(1000, Byte), want: true}, + {name: "1000 B != 1000 KB", bytes: NewBytes(1000, Byte), other: NewBytes(1000, Kilobyte), want: false}, + {name: "1000 KB != 1000 B", bytes: NewBytes(1000, Kilobyte), other: NewBytes(1000, Byte), want: false}, + {name: "1000 KB == 1000 KB", bytes: NewBytes(1000, Kilobyte), other: NewBytes(1000, Kilobyte), want: true}, + {name: "1000 KB != 1000 MB", bytes: NewBytes(1000, Kilobyte), other: NewBytes(1000, Megabyte), want: false}, + {name: "1000 MB != 1000 KB", bytes: NewBytes(1000, Megabyte), other: NewBytes(1000, Kilobyte), want: false}, + {name: "1000 MB == 1000 MB", bytes: NewBytes(1000, Megabyte), other: NewBytes(1000, Megabyte), want: true}, + {name: "1000 MB != 1000 GB", bytes: NewBytes(1000, Megabyte), other: NewBytes(1000, Gigabyte), want: false}, + {name: "1000 GB != 1000 MB", bytes: NewBytes(1000, Gigabyte), other: NewBytes(1000, Megabyte), want: false}, + {name: "1000 GB == 1000 GB", bytes: NewBytes(1000, Gigabyte), other: NewBytes(1000, Gigabyte), want: true}, + {name: "1000 GB != 1000 TB", bytes: NewBytes(1000, Gigabyte), other: NewBytes(1000, Terabyte), want: false}, + {name: "1000 TB != 1000 GB", bytes: NewBytes(1000, Terabyte), other: NewBytes(1000, Gigabyte), want: false}, + {name: "1000 TB == 1000 TB", bytes: NewBytes(1000, Terabyte), other: NewBytes(1000, Terabyte), want: true}, + {name: "1000 TB != 1000 PB", bytes: NewBytes(1000, Terabyte), other: NewBytes(1000, Petabyte), want: false}, + {name: "1000 PB != 1000 TB", bytes: NewBytes(1000, Petabyte), other: NewBytes(1000, Terabyte), want: false}, + {name: "1000 PB == 1000 PB", bytes: NewBytes(1000, Petabyte), other: NewBytes(1000, Petabyte), want: true}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + // Test the Equal() method + got := test.bytes.Equal(test.other) + if got != test.want { + t.Errorf("Bytes.Equal() = %v, want %v", got, test.want) + } + // Test the == operator + got = (test.bytes == test.other) + if got != test.want { 
+ t.Errorf("Bytes == Bytes = %v, want %v", got, test.want) + } + }) + } +} + +func TestBytesLessThan(t *testing.T) { //nolint:dupl // test ok + tests := []struct { + name string + bytes Bytes + other Bytes + want bool + }{ + {name: "1000 B < 1000 KB", bytes: NewBytes(1000, Byte), other: NewBytes(1000, Kilobyte), want: true}, + {name: "1000 KB < 1000 B", bytes: NewBytes(1000, Kilobyte), other: NewBytes(1000, Byte), want: false}, + + {name: "1000 KB < 1000 MB", bytes: NewBytes(1000, Kilobyte), other: NewBytes(1000, Megabyte), want: true}, + {name: "1000 MB < 1000 KB", bytes: NewBytes(1000, Megabyte), other: NewBytes(1000, Kilobyte), want: false}, + + {name: "1000 MB < 1000 GB", bytes: NewBytes(1000, Megabyte), other: NewBytes(1000, Gigabyte), want: true}, + {name: "1000 GB < 1000 MB", bytes: NewBytes(1000, Gigabyte), other: NewBytes(1000, Megabyte), want: false}, + + {name: "1000 GB < 1000 TB", bytes: NewBytes(1000, Gigabyte), other: NewBytes(1000, Terabyte), want: true}, + {name: "1000 TB < 1000 GB", bytes: NewBytes(1000, Terabyte), other: NewBytes(1000, Gigabyte), want: false}, + + {name: "1000 TB < 1000 PB", bytes: NewBytes(1000, Terabyte), other: NewBytes(1000, Petabyte), want: true}, + {name: "1000 PB < 1000 TB", bytes: NewBytes(1000, Petabyte), other: NewBytes(1000, Terabyte), want: false}, + + {name: "1000 B < 1000 KiB", bytes: NewBytes(1000, Byte), other: NewBytes(1000, Kibibyte), want: true}, + {name: "1000 KiB < 1000 B", bytes: NewBytes(1000, Kibibyte), other: NewBytes(1000, Byte), want: false}, + + {name: "1000 KiB < 1000 MiB", bytes: NewBytes(1000, Kibibyte), other: NewBytes(1000, Mebibyte), want: true}, + {name: "1000 MiB < 1000 KiB", bytes: NewBytes(1000, Mebibyte), other: NewBytes(1000, Kibibyte), want: false}, + + {name: "1000 MiB < 1000 GiB", bytes: NewBytes(1000, Mebibyte), other: NewBytes(1000, Gibibyte), want: true}, + {name: "1000 GiB < 1000 MiB", bytes: NewBytes(1000, Gibibyte), other: NewBytes(1000, Mebibyte), want: false}, + + {name: "1000 GiB 
< 1000 TiB", bytes: NewBytes(1000, Gibibyte), other: NewBytes(1000, Tebibyte), want: true}, + {name: "1000 TiB < 1000 GiB", bytes: NewBytes(1000, Tebibyte), other: NewBytes(1000, Gibibyte), want: false}, + + {name: "1000 TiB < 1000 PiB", bytes: NewBytes(1000, Tebibyte), other: NewBytes(1000, Pebibyte), want: true}, + {name: "1000 PiB < 1000 TiB", bytes: NewBytes(1000, Pebibyte), other: NewBytes(1000, Tebibyte), want: false}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got := test.bytes.LessThan(test.other) + if got != test.want { + t.Errorf("Bytes.LessThan() = %v, want %v", got, test.want) + } + }) + } +} + +func TestBytesGreaterThan(t *testing.T) { //nolint:dupl // test ok + tests := []struct { + name string + bytes Bytes + other Bytes + want bool + }{ + {name: "1000 B > 1000 KB", bytes: NewBytes(1000, Byte), other: NewBytes(1000, Kilobyte), want: false}, + {name: "1000 KB > 1000 B", bytes: NewBytes(1000, Kilobyte), other: NewBytes(1000, Byte), want: true}, + + {name: "1000 KB > 1000 MB", bytes: NewBytes(1000, Kilobyte), other: NewBytes(1000, Megabyte), want: false}, + {name: "1000 MB > 1000 KB", bytes: NewBytes(1000, Megabyte), other: NewBytes(1000, Kilobyte), want: true}, + + {name: "1000 MB > 1000 GB", bytes: NewBytes(1000, Megabyte), other: NewBytes(1000, Gigabyte), want: false}, + {name: "1000 GB > 1000 MB", bytes: NewBytes(1000, Gigabyte), other: NewBytes(1000, Megabyte), want: true}, + + {name: "1000 GB > 1000 TB", bytes: NewBytes(1000, Gigabyte), other: NewBytes(1000, Terabyte), want: false}, + {name: "1000 TB > 1000 GB", bytes: NewBytes(1000, Terabyte), other: NewBytes(1000, Gigabyte), want: true}, + + {name: "1000 TB > 1000 PB", bytes: NewBytes(1000, Terabyte), other: NewBytes(1000, Petabyte), want: false}, + {name: "1000 PB > 1000 TB", bytes: NewBytes(1000, Petabyte), other: NewBytes(1000, Terabyte), want: true}, + + {name: "1000 B > 1000 KiB", bytes: NewBytes(1000, Byte), other: NewBytes(1000, Kibibyte), want: false}, 
+ {name: "1000 KiB > 1000 B", bytes: NewBytes(1000, Kibibyte), other: NewBytes(1000, Byte), want: true}, + + {name: "1000 KiB > 1000 MiB", bytes: NewBytes(1000, Kibibyte), other: NewBytes(1000, Mebibyte), want: false}, + {name: "1000 MiB > 1000 KiB", bytes: NewBytes(1000, Mebibyte), other: NewBytes(1000, Kibibyte), want: true}, + + {name: "1000 MiB > 1000 GiB", bytes: NewBytes(1000, Mebibyte), other: NewBytes(1000, Gibibyte), want: false}, + {name: "1000 GiB > 1000 MiB", bytes: NewBytes(1000, Gibibyte), other: NewBytes(1000, Mebibyte), want: true}, + + {name: "1000 GiB > 1000 TiB", bytes: NewBytes(1000, Gibibyte), other: NewBytes(1000, Tebibyte), want: false}, + {name: "1000 TiB > 1000 GiB", bytes: NewBytes(1000, Tebibyte), other: NewBytes(1000, Gibibyte), want: true}, + + {name: "1000 TiB > 1000 PiB", bytes: NewBytes(1000, Tebibyte), other: NewBytes(1000, Pebibyte), want: false}, + {name: "1000 PiB > 1000 TiB", bytes: NewBytes(1000, Pebibyte), other: NewBytes(1000, Tebibyte), want: true}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got := test.bytes.GreaterThan(test.other) + if got != test.want { + t.Errorf("Bytes.GreaterThan() = %v, want %v", got, test.want) + } + }) + } +} + +func TestBytesByteCountInUnit(t *testing.T) { + tests := []struct { + name string + bytes Bytes + unit BytesUnit + want *big.Float + }{ + {name: "1000 B -> B", bytes: NewBytes(1000, Byte), unit: Byte, want: big.NewFloat(1000)}, + {name: "1000 B -> KB", bytes: NewBytes(1000, Byte), unit: Kilobyte, want: big.NewFloat(1)}, + {name: "1000 B -> MB", bytes: NewBytes(1000, Byte), unit: Megabyte, want: big.NewFloat(0.001)}, + {name: "1000 B -> GB", bytes: NewBytes(1000, Byte), unit: Gigabyte, want: big.NewFloat(0.000001)}, + {name: "1000 B -> TB", bytes: NewBytes(1000, Byte), unit: Terabyte, want: big.NewFloat(0.000000001)}, + {name: "1000 B -> PB", bytes: NewBytes(1000, Byte), unit: Petabyte, want: big.NewFloat(0.000000000001)}, + {name: "1000 B -> KiB", bytes: 
NewBytes(1000, Byte), unit: Kibibyte, want: big.NewFloat(0.9765625)}, + {name: "1000 B -> MiB", bytes: NewBytes(1000, Byte), unit: Mebibyte, want: big.NewFloat(0.00095367431640625)}, + {name: "1000 B -> GiB", bytes: NewBytes(1000, Byte), unit: Gibibyte, want: big.NewFloat(0.0000009313225746154785)}, + {name: "1000 B -> TiB", bytes: NewBytes(1000, Byte), unit: Tebibyte, want: big.NewFloat(0.0000000009094947017729282)}, + {name: "1000 B -> PiB", bytes: NewBytes(1000, Byte), unit: Pebibyte, want: big.NewFloat(0.0000000000008881784197001252)}, + + {name: "1000 KB -> B", bytes: NewBytes(1000, Kilobyte), unit: Byte, want: big.NewFloat(1000000)}, + {name: "1000 KB -> KB", bytes: NewBytes(1000, Kilobyte), unit: Kilobyte, want: big.NewFloat(1000)}, + {name: "1000 KB -> MB", bytes: NewBytes(1000, Kilobyte), unit: Megabyte, want: big.NewFloat(1)}, + {name: "1000 KB -> GB", bytes: NewBytes(1000, Kilobyte), unit: Gigabyte, want: big.NewFloat(0.001)}, + {name: "1000 KB -> TB", bytes: NewBytes(1000, Kilobyte), unit: Terabyte, want: big.NewFloat(0.000001)}, + {name: "1000 KB -> PB", bytes: NewBytes(1000, Kilobyte), unit: Petabyte, want: big.NewFloat(0.000000001)}, + {name: "1000 KB -> KiB", bytes: NewBytes(1000, Kilobyte), unit: Kibibyte, want: big.NewFloat(976.5625)}, + {name: "1000 KB -> MiB", bytes: NewBytes(1000, Kilobyte), unit: Mebibyte, want: big.NewFloat(0.95367431640625)}, + {name: "1000 KB -> GiB", bytes: NewBytes(1000, Kilobyte), unit: Gibibyte, want: big.NewFloat(0.0009313225746154785)}, + {name: "1000 KB -> TiB", bytes: NewBytes(1000, Kilobyte), unit: Tebibyte, want: big.NewFloat(0.0000009094947017729282)}, + {name: "1000 KB -> PiB", bytes: NewBytes(1000, Kilobyte), unit: Pebibyte, want: big.NewFloat(0.0000000008881784197001252)}, + + {name: "1000 MB -> B", bytes: NewBytes(1000, Megabyte), unit: Byte, want: big.NewFloat(1000000000)}, + {name: "1000 MB -> KB", bytes: NewBytes(1000, Megabyte), unit: Kilobyte, want: big.NewFloat(1000000)}, + {name: "1000 MB -> MB", 
bytes: NewBytes(1000, Megabyte), unit: Megabyte, want: big.NewFloat(1000)}, + {name: "1000 MB -> GB", bytes: NewBytes(1000, Megabyte), unit: Gigabyte, want: big.NewFloat(1)}, + {name: "1000 MB -> TB", bytes: NewBytes(1000, Megabyte), unit: Terabyte, want: big.NewFloat(0.001)}, + {name: "1000 MB -> PB", bytes: NewBytes(1000, Megabyte), unit: Petabyte, want: big.NewFloat(0.000001)}, + {name: "1000 MB -> KiB", bytes: NewBytes(1000, Megabyte), unit: Kibibyte, want: big.NewFloat(976562.5)}, + {name: "1000 MB -> MiB", bytes: NewBytes(1000, Megabyte), unit: Mebibyte, want: big.NewFloat(953.67431640625)}, + {name: "1000 MB -> GiB", bytes: NewBytes(1000, Megabyte), unit: Gibibyte, want: big.NewFloat(0.9313225746154785)}, + {name: "1000 MB -> TiB", bytes: NewBytes(1000, Megabyte), unit: Tebibyte, want: big.NewFloat(0.0009094947017729282)}, + {name: "1000 MB -> PiB", bytes: NewBytes(1000, Megabyte), unit: Pebibyte, want: big.NewFloat(0.0000008881784197001252)}, + + {name: "1000 GB -> B", bytes: NewBytes(1000, Gigabyte), unit: Byte, want: big.NewFloat(1000000000000)}, + {name: "1000 GB -> KB", bytes: NewBytes(1000, Gigabyte), unit: Kilobyte, want: big.NewFloat(1000000000)}, + {name: "1000 GB -> MB", bytes: NewBytes(1000, Gigabyte), unit: Megabyte, want: big.NewFloat(1000000)}, + {name: "1000 GB -> GB", bytes: NewBytes(1000, Gigabyte), unit: Gigabyte, want: big.NewFloat(1000)}, + {name: "1000 GB -> TB", bytes: NewBytes(1000, Gigabyte), unit: Terabyte, want: big.NewFloat(1)}, + {name: "1000 GB -> PB", bytes: NewBytes(1000, Gigabyte), unit: Petabyte, want: big.NewFloat(0.001)}, + {name: "1000 GB -> KiB", bytes: NewBytes(1000, Gigabyte), unit: Kibibyte, want: big.NewFloat(976562500)}, + {name: "1000 GB -> MiB", bytes: NewBytes(1000, Gigabyte), unit: Mebibyte, want: big.NewFloat(953674.31640625)}, + {name: "1000 GB -> GiB", bytes: NewBytes(1000, Gigabyte), unit: Gibibyte, want: big.NewFloat(931.3225746154785)}, + {name: "1000 GB -> TiB", bytes: NewBytes(1000, Gigabyte), unit: 
Tebibyte, want: big.NewFloat(0.9094947017729282)}, + {name: "1000 GB -> PiB", bytes: NewBytes(1000, Gigabyte), unit: Pebibyte, want: big.NewFloat(0.0008881784197001252)}, + + {name: "1000 TB -> B", bytes: NewBytes(1000, Terabyte), unit: Byte, want: big.NewFloat(1000000000000000)}, + {name: "1000 TB -> KB", bytes: NewBytes(1000, Terabyte), unit: Kilobyte, want: big.NewFloat(1000000000000)}, + {name: "1000 TB -> MB", bytes: NewBytes(1000, Terabyte), unit: Megabyte, want: big.NewFloat(1000000000)}, + {name: "1000 TB -> GB", bytes: NewBytes(1000, Terabyte), unit: Gigabyte, want: big.NewFloat(1000000)}, + {name: "1000 TB -> TB", bytes: NewBytes(1000, Terabyte), unit: Terabyte, want: big.NewFloat(1000)}, + {name: "1000 TB -> PB", bytes: NewBytes(1000, Terabyte), unit: Petabyte, want: big.NewFloat(1)}, + {name: "1000 TB -> KiB", bytes: NewBytes(1000, Terabyte), unit: Kibibyte, want: big.NewFloat(976562500000)}, + {name: "1000 TB -> MiB", bytes: NewBytes(1000, Terabyte), unit: Mebibyte, want: big.NewFloat(953674316.40625)}, + {name: "1000 TB -> GiB", bytes: NewBytes(1000, Terabyte), unit: Gibibyte, want: big.NewFloat(931322.5746154785)}, + {name: "1000 TB -> TiB", bytes: NewBytes(1000, Terabyte), unit: Tebibyte, want: big.NewFloat(909.4947017729282)}, + {name: "1000 TB -> PiB", bytes: NewBytes(1000, Terabyte), unit: Pebibyte, want: big.NewFloat(0.8881784197001252)}, + + {name: "1000 PB -> B", bytes: NewBytes(1000, Petabyte), unit: Byte, want: big.NewFloat(1000000000000000000)}, + {name: "1000 PB -> KB", bytes: NewBytes(1000, Petabyte), unit: Kilobyte, want: big.NewFloat(1000000000000000)}, + {name: "1000 PB -> MB", bytes: NewBytes(1000, Petabyte), unit: Megabyte, want: big.NewFloat(1000000000000)}, + {name: "1000 PB -> GB", bytes: NewBytes(1000, Petabyte), unit: Gigabyte, want: big.NewFloat(1000000000)}, + {name: "1000 PB -> TB", bytes: NewBytes(1000, Petabyte), unit: Terabyte, want: big.NewFloat(1000000)}, + {name: "1000 PB -> PB", bytes: NewBytes(1000, Petabyte), 
unit: Petabyte, want: big.NewFloat(1000)}, + {name: "1000 PB -> KiB", bytes: NewBytes(1000, Petabyte), unit: Kibibyte, want: big.NewFloat(976562500000000)}, + {name: "1000 PB -> MiB", bytes: NewBytes(1000, Petabyte), unit: Mebibyte, want: big.NewFloat(953674316406.25)}, + {name: "1000 PB -> GiB", bytes: NewBytes(1000, Petabyte), unit: Gibibyte, want: big.NewFloat(931322574.6154785)}, + {name: "1000 PB -> TiB", bytes: NewBytes(1000, Petabyte), unit: Tebibyte, want: big.NewFloat(909494.7017729282)}, + {name: "1000 PB -> PiB", bytes: NewBytes(1000, Petabyte), unit: Pebibyte, want: big.NewFloat(888.1784197001252)}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got := test.bytes.ByteCountInUnit(test.unit) + if got.Cmp(test.want) != 0 { + t.Errorf("Bytes.ByteCountInUnit() = %v, want %v", got, test.want) + } + }) + } +} + +func TestBytesByteCountInUnitInt64(t *testing.T) { + tests := []struct { + name string + bytes Bytes + unit BytesUnit + want int64 + wantErr error + }{ + {name: "2048 MiB -> GiB", bytes: NewBytes(2048, Mebibyte), unit: Gibibyte, want: 2, wantErr: nil}, + {name: "2048 EiB -> B", bytes: NewBytes(2048, Exbibyte), unit: Byte, want: 0, wantErr: ErrBytesNotAnInt64}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, err := test.bytes.ByteCountInUnitInt64(test.unit) + if err != nil { + if test.wantErr != nil { + if !errors.Is(err, test.wantErr) { + t.Errorf("Bytes.ByteCountInUnitInt64() = %v, want %v", err, test.wantErr) + } + } else { + t.Errorf("Bytes.ByteCountInUnitInt64() = %v, want %v", err, test.wantErr) + } + } else if got != test.want { + t.Errorf("Bytes.ByteCountInUnitInt64() = %v, want %v", got, test.want) + } + }) + } +} + +func TestBytesByteCountInUnitInt32(t *testing.T) { + tests := []struct { + name string + bytes Bytes + unit BytesUnit + want int32 + wantErr error + }{ + {name: "2048 MiB -> GiB", bytes: NewBytes(2048, Mebibyte), unit: Gibibyte, want: 2, wantErr: nil}, + {name: "2048 
EiB -> B", bytes: NewBytes(2048, Exbibyte), unit: Byte, want: 0, wantErr: ErrBytesNotAnInt64}, + {name: "2048 EiB -> KB", bytes: NewBytes(2048, Exbibyte), unit: Kilobyte, want: 0, wantErr: ErrBytesNotAnInt32}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, err := test.bytes.ByteCountInUnitInt32(test.unit) + if err != nil { + if test.wantErr != nil { + if !errors.Is(err, test.wantErr) { + t.Errorf("Bytes.ByteCountInUnitInt32() = %v, want %v", err, test.wantErr) + } + } else { + t.Errorf("Bytes.ByteCountInUnitInt32() = %v, want %v", err, test.wantErr) + } + } else if got != test.want { + t.Errorf("Bytes.ByteCountInUnitInt32() = %v, want %v", got, test.want) + } + }) + } +} diff --git a/v1/kubernetes.go b/v1/kubernetes.go index d436778d..e2632833 100644 --- a/v1/kubernetes.go +++ b/v1/kubernetes.go @@ -2,11 +2,26 @@ package v1 import ( "context" + "encoding/json" "fmt" "github.com/brevdev/cloud/internal/errors" ) +var ( + ErrRefIDRequired = errors.New("refID is required") + ErrNameRequired = errors.New("name is required") + ErrNodeGroupInvalidStatus = errors.New("invalid node group status") + ErrClusterInvalidStatus = errors.New("invalid cluster status") + ErrClusterUserClusterNameRequired = errors.New("clusterName is required") + ErrClusterUserClusterCertificateAuthorityDataBase64Required = errors.New("clusterCertificateAuthorityDataBase64 is required") + ErrClusterUserClusterServerURLRequired = errors.New("clusterServerURL is required") + ErrClusterUserUsernameRequired = errors.New("username is required") + ErrClusterUserUserClientCertificateDataBase64Required = errors.New("userClientCertificateDataBase64 is required") + ErrClusterUserUserClientKeyDataBase64Required = errors.New("userClientKeyDataBase64 is required") + ErrClusterUserKubeconfigBase64Required = errors.New("kubeconfigBase64 is required") +) + // Cluster represents the complete specification of a Brev Kubernetes cluster. 
type Cluster struct { // The ID assigned by the cloud provider to the cluster. @@ -52,16 +67,125 @@ type Cluster struct { tags Tags } -type ClusterStatus string +// clusterJSON is the JSON representation of a Cluster. This struct is maintained separately from the core Cluster +// struct to allow for unexported fields to be used in the MarshalJSON and UnmarshalJSON methods. +type clusterJSON struct { + ID string `json:"id"` + Name string `json:"name"` + RefID string `json:"refID"` + Provider string `json:"provider"` + Cloud string `json:"cloud"` + Location string `json:"location"` + VPCID string `json:"vpcID"` + SubnetIDs []string `json:"subnetIDs"` + KubernetesVersion string `json:"kubernetesVersion"` + Status string `json:"status"` + APIEndpoint string `json:"apiEndpoint"` + ClusterCACertificateBase64 string `json:"clusterCACertificateBase64"` + NodeGroups []*NodeGroup `json:"nodeGroups"` + Tags map[string]string `json:"tags"` +} + +// MarshalJSON implements the json.Marshaler interface +func (c *Cluster) MarshalJSON() ([]byte, error) { + subnetIDs := make([]string, len(c.subnetIDs)) + for i, subnetID := range c.subnetIDs { + subnetIDs[i] = string(subnetID) + } + + return json.Marshal(clusterJSON{ + ID: string(c.id), + Name: c.name, + RefID: c.refID, + Provider: c.provider, + Cloud: c.cloud, + Location: c.location, + VPCID: string(c.vpcID), + SubnetIDs: subnetIDs, + KubernetesVersion: c.kubernetesVersion, + Status: c.status.value, + APIEndpoint: c.apiEndpoint, + ClusterCACertificateBase64: c.clusterCACertificateBase64, + NodeGroups: c.nodeGroups, + Tags: c.tags, + }) +} + +// UnmarshalJSON implements the json.Unmarshaler interface +func (c *Cluster) UnmarshalJSON(data []byte) error { + var clusterJSON clusterJSON + if err := json.Unmarshal(data, &clusterJSON); err != nil { + return errors.WrapAndTrace(err) + } + + subnetIDs := make([]CloudProviderResourceID, len(clusterJSON.SubnetIDs)) + for i, subnetID := range clusterJSON.SubnetIDs { + subnetIDs[i] = 
CloudProviderResourceID(subnetID) + } + + status, err := stringToClusterStatus(clusterJSON.Status) + if err != nil { + return errors.WrapAndTrace(err) + } -const ( - ClusterStatusUnknown ClusterStatus = "unknown" - ClusterStatusPending ClusterStatus = "pending" - ClusterStatusAvailable ClusterStatus = "available" - ClusterStatusDeleting ClusterStatus = "deleting" - ClusterStatusFailed ClusterStatus = "failed" + newCluster, err := NewCluster(ClusterSettings{ + ID: CloudProviderResourceID(clusterJSON.ID), + Name: clusterJSON.Name, + RefID: clusterJSON.RefID, + Provider: clusterJSON.Provider, + Cloud: clusterJSON.Cloud, + Location: clusterJSON.Location, + VPCID: CloudProviderResourceID(clusterJSON.VPCID), + SubnetIDs: subnetIDs, + KubernetesVersion: clusterJSON.KubernetesVersion, + Status: status, + APIEndpoint: clusterJSON.APIEndpoint, + ClusterCACertificateBase64: clusterJSON.ClusterCACertificateBase64, + NodeGroups: clusterJSON.NodeGroups, + Tags: clusterJSON.Tags, + }) + if err != nil { + return errors.WrapAndTrace(err) + } + + *c = *newCluster + return nil +} + +// ClusterStatus represents the status of a Kubernetes cluster. Note for maintainers: this is defined as a struct +// rather than a type alias to ensure stronger compile-time type checking and to avoid the need for a validation function. 
+type ClusterStatus struct { + value string +} + +var ( + ClusterStatusUnknown = ClusterStatus{value: "unknown"} + ClusterStatusPending = ClusterStatus{value: "pending"} + ClusterStatusAvailable = ClusterStatus{value: "available"} + ClusterStatusDeleting = ClusterStatus{value: "deleting"} + ClusterStatusFailed = ClusterStatus{value: "failed"} ) +func (s ClusterStatus) String() string { + return s.value +} + +func stringToClusterStatus(status string) (ClusterStatus, error) { + switch status { + case ClusterStatusUnknown.value: + return ClusterStatusUnknown, nil + case ClusterStatusPending.value: + return ClusterStatusPending, nil + case ClusterStatusAvailable.value: + return ClusterStatusAvailable, nil + case ClusterStatusDeleting.value: + return ClusterStatusDeleting, nil + case ClusterStatusFailed.value: + return ClusterStatusFailed, nil + } + return ClusterStatusUnknown, errors.Join(ErrClusterInvalidStatus, fmt.Errorf("invalid status: %s", status)) +} + func (c *Cluster) GetID() CloudProviderResourceID { return c.id } @@ -169,13 +293,10 @@ func (s *ClusterSettings) setDefaults() { func (s *ClusterSettings) validate() error { var errs []error if s.RefID == "" { - errs = append(errs, fmt.Errorf("refID is required")) + errs = append(errs, ErrRefIDRequired) } if s.Name == "" { - errs = append(errs, fmt.Errorf("name is required")) - } - if s.Status == "" { - errs = append(errs, fmt.Errorf("status is required")) + errs = append(errs, ErrNameRequired) } return errors.WrapAndTrace(errors.Join(errs...)) } @@ -226,7 +347,7 @@ type NodeGroup struct { instanceType string // The disk size of the nodes in the node group. - diskSizeGiB int + diskSize Bytes // The status of the node group. status NodeGroupStatus @@ -235,16 +356,100 @@ type NodeGroup struct { tags Tags } -type NodeGroupStatus string +// nodeGroupJSON is the JSON representation of a NodeGroup. 
This struct is maintained separately from the core NodeGroup +// struct to allow for unexported fields to be used in the MarshalJSON and UnmarshalJSON methods. +type nodeGroupJSON struct { + Name string `json:"name"` + RefID string `json:"refID"` + ID string `json:"id"` + MinNodeCount int `json:"minNodeCount"` + MaxNodeCount int `json:"maxNodeCount"` + InstanceType string `json:"instanceType"` + DiskSize Bytes `json:"diskSize"` + Status string `json:"status"` + Tags map[string]string `json:"tags"` +} + +// MarshalJSON implements the json.Marshaler interface +func (n *NodeGroup) MarshalJSON() ([]byte, error) { + return json.Marshal(nodeGroupJSON{ + Name: n.name, + RefID: n.refID, + ID: string(n.id), + MinNodeCount: n.minNodeCount, + MaxNodeCount: n.maxNodeCount, + InstanceType: n.instanceType, + DiskSize: n.diskSize, + Status: n.status.value, + Tags: n.tags, + }) +} + +// UnmarshalJSON implements the json.Unmarshaler interface +func (n *NodeGroup) UnmarshalJSON(data []byte) error { + var nodeGroupJSON nodeGroupJSON + if err := json.Unmarshal(data, &nodeGroupJSON); err != nil { + return errors.WrapAndTrace(err) + } + + status, err := stringToNodeGroupStatus(nodeGroupJSON.Status) + if err != nil { + return errors.WrapAndTrace(err) + } + + newNodeGroup, err := NewNodeGroup(NodeGroupSettings{ + Name: nodeGroupJSON.Name, + RefID: nodeGroupJSON.RefID, + ID: CloudProviderResourceID(nodeGroupJSON.ID), + MinNodeCount: nodeGroupJSON.MinNodeCount, + MaxNodeCount: nodeGroupJSON.MaxNodeCount, + InstanceType: nodeGroupJSON.InstanceType, + DiskSize: nodeGroupJSON.DiskSize, + Status: status, + Tags: nodeGroupJSON.Tags, + }) + if err != nil { + return errors.WrapAndTrace(err) + } -const ( - NodeGroupStatusUnknown NodeGroupStatus = "unknown" - NodeGroupStatusPending NodeGroupStatus = "pending" - NodeGroupStatusAvailable NodeGroupStatus = "available" - NodeGroupStatusDeleting NodeGroupStatus = "deleting" - NodeGroupStatusFailed NodeGroupStatus = "failed" + *n = *newNodeGroup + return 
nil +} + +// NodeGroupStatus represents the status of a Kubernetes node group. Note for maintainers: this is defined as a struct +// rather than a type alias to ensure stronger compile-time type checking and to avoid the need for a validation function. +type NodeGroupStatus struct { + value string +} + +func (s NodeGroupStatus) String() string { + return s.value +} + +var ( + NodeGroupStatusUnknown = NodeGroupStatus{value: "unknown"} + NodeGroupStatusPending = NodeGroupStatus{value: "pending"} + NodeGroupStatusAvailable = NodeGroupStatus{value: "available"} + NodeGroupStatusDeleting = NodeGroupStatus{value: "deleting"} + NodeGroupStatusFailed = NodeGroupStatus{value: "failed"} ) +func stringToNodeGroupStatus(status string) (NodeGroupStatus, error) { + switch status { + case NodeGroupStatusUnknown.value: + return NodeGroupStatusUnknown, nil + case NodeGroupStatusPending.value: + return NodeGroupStatusPending, nil + case NodeGroupStatusAvailable.value: + return NodeGroupStatusAvailable, nil + case NodeGroupStatusDeleting.value: + return NodeGroupStatusDeleting, nil + case NodeGroupStatusFailed.value: + return NodeGroupStatusFailed, nil + } + return NodeGroupStatusUnknown, errors.Join(ErrNodeGroupInvalidStatus, fmt.Errorf("invalid status: %s", status)) +} + func (n *NodeGroup) GetName() string { return n.name } @@ -269,8 +474,8 @@ func (n *NodeGroup) GetInstanceType() string { return n.instanceType } -func (n *NodeGroup) GetDiskSizeGiB() int { - return n.diskSizeGiB +func (n *NodeGroup) GetDiskSize() Bytes { + return n.diskSize } func (n *NodeGroup) GetStatus() NodeGroupStatus { @@ -302,7 +507,7 @@ type NodeGroupSettings struct { InstanceType string // The disk size of the nodes in the node group. - DiskSizeGiB int + DiskSize Bytes // The status of the node group. 
Status NodeGroupStatus @@ -317,13 +522,10 @@ func (s *NodeGroupSettings) setDefaults() { func (s *NodeGroupSettings) validate() error { var errs []error if s.RefID == "" { - errs = append(errs, fmt.Errorf("refID is required")) + errs = append(errs, ErrRefIDRequired) } if s.Name == "" { - errs = append(errs, fmt.Errorf("name is required")) - } - if s.Status == "" { - errs = append(errs, fmt.Errorf("status is required")) + errs = append(errs, ErrNameRequired) } return errors.WrapAndTrace(errors.Join(errs...)) } @@ -342,7 +544,7 @@ func NewNodeGroup(settings NodeGroupSettings) (*NodeGroup, error) { minNodeCount: settings.MinNodeCount, maxNodeCount: settings.MaxNodeCount, instanceType: settings.InstanceType, - diskSizeGiB: settings.DiskSizeGiB, + diskSize: settings.DiskSize, status: settings.Status, tags: settings.Tags, }, nil @@ -430,25 +632,25 @@ func (s *ClusterUserSettings) setDefaults() { func (s *ClusterUserSettings) validate() error { var errs []error if s.ClusterName == "" { - errs = append(errs, fmt.Errorf("clusterName is required")) + errs = append(errs, ErrClusterUserClusterNameRequired) } if s.ClusterCertificateAuthorityDataBase64 == "" { - errs = append(errs, fmt.Errorf("clusterCertificateAuthorityDataBase64 is required")) + errs = append(errs, ErrClusterUserClusterCertificateAuthorityDataBase64Required) } if s.ClusterServerURL == "" { - errs = append(errs, fmt.Errorf("clusterServerURL is required")) + errs = append(errs, ErrClusterUserClusterServerURLRequired) } if s.Username == "" { - errs = append(errs, fmt.Errorf("username is required")) + errs = append(errs, ErrClusterUserUsernameRequired) } if s.UserClientCertificateDataBase64 == "" { - errs = append(errs, fmt.Errorf("userClientCertificateDataBase64 is required")) + errs = append(errs, ErrClusterUserUserClientCertificateDataBase64Required) } if s.UserClientKeyDataBase64 == "" { - errs = append(errs, fmt.Errorf("userClientKeyDataBase64 is required")) + errs = append(errs, 
ErrClusterUserUserClientKeyDataBase64Required) } if s.KubeconfigBase64 == "" { - errs = append(errs, fmt.Errorf("kubeconfigBase64 is required")) + errs = append(errs, ErrClusterUserKubeconfigBase64Required) } return errors.WrapAndTrace(errors.Join(errs...)) } @@ -524,7 +726,7 @@ type CreateNodeGroupArgs struct { MinNodeCount int MaxNodeCount int InstanceType string - DiskSizeGiB int + DiskSize Bytes Tags Tags } diff --git a/v1/kubernetes_test.go b/v1/kubernetes_test.go new file mode 100644 index 00000000..900d5250 --- /dev/null +++ b/v1/kubernetes_test.go @@ -0,0 +1,347 @@ +package v1 + +import ( + "encoding/json" + "errors" + "regexp" + "testing" +) + +var reWhitespace = regexp.MustCompile(`\s+`) + +func TestClusterMarshalJSON(t *testing.T) { + cluster := &Cluster{ + id: "test-id", + name: "test-name", + refID: "test-refID", + provider: "test-provider", + cloud: "test-cloud", + location: "test-location", + vpcID: "test-vpcID", + subnetIDs: []CloudProviderResourceID{"test-subnetID"}, + kubernetesVersion: "test-kubernetesVersion", + status: ClusterStatusAvailable, + apiEndpoint: "test-apiEndpoint", + clusterCACertificateBase64: "test-clusterCACertificateBase64", + nodeGroups: []*NodeGroup{ + { + name: "test-nodeGroupName", + refID: "test-nodeGroupRefID", + id: "test-nodeGroupID", + minNodeCount: 1, + maxNodeCount: 2, + instanceType: "test-instanceType", + diskSize: NewBytes(10, Gibibyte), + status: NodeGroupStatusAvailable, + tags: Tags{ + "test-nodeGroupTagName": "test-nodeGroupTagValue", + }, + }, + }, + tags: Tags{ + "test-clusterTagName": "test-clusterTagValue", + }, + } + + expectedJSON := `{ + "id": "test-id", + "name": "test-name", + "refID": "test-refID", + "provider": "test-provider", + "cloud": "test-cloud", + "location": "test-location", + "vpcID": "test-vpcID", + "subnetIDs": ["test-subnetID"], + "kubernetesVersion": "test-kubernetesVersion", + "status": "available", + "apiEndpoint": "test-apiEndpoint", + "clusterCACertificateBase64": 
"test-clusterCACertificateBase64", + "nodeGroups": [ + { + "name": "test-nodeGroupName", + "refID": "test-nodeGroupRefID", + "id": "test-nodeGroupID", + "minNodeCount": 1, + "maxNodeCount": 2, + "instanceType": "test-instanceType", + "diskSize": { + "value": 10, + "unit": "GiB" + }, + "status": "available", + "tags": { + "test-nodeGroupTagName": "test-nodeGroupTagValue" + } + } + ], + "tags": { + "test-clusterTagName": "test-clusterTagValue" + } + }` + expectedJSON = reWhitespace.ReplaceAllString(expectedJSON, "") + + clusterJSON, err := cluster.MarshalJSON() + if err != nil { + t.Fatalf("Failed to marshal node group: %v", err) + } + + if string(clusterJSON) != expectedJSON { + t.Fatalf("Cluster JSON = %s, want %s", string(clusterJSON), expectedJSON) + } +} + +func TestClusterUnmarshalJSON_InvalidJSON(t *testing.T) { + tests := []struct { + name string + json string + wantErr error + }{ + {name: "invalid status", json: `{"status":"invalid"}`, wantErr: ErrClusterInvalidStatus}, + {name: "refID is required", json: `{"name":"test-name", "status":"available"}`, wantErr: ErrRefIDRequired}, + {name: "name is required", json: `{"refID":"test-refID", "status":"available"}`, wantErr: ErrNameRequired}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var cluster Cluster + err := json.Unmarshal([]byte(test.json), &cluster) + if err == nil { + t.Fatalf("Expected error, got nil") + } + if !errors.Is(err, test.wantErr) { + t.Fatalf("Expected error, got %v", err) + } + }) + } +} + +func TestClusterUnmarshalJSON(t *testing.T) { //nolint:gocyclo,funlen // test ok + clusterJSON := `{ + "id": "test-id", + "name": "test-name", + "refID": "test-refID", + "provider": "test-provider", + "cloud": "test-cloud", + "location": "test-location", + "vpcID": "test-vpcID", + "subnetIDs": ["test-subnetID"], + "kubernetesVersion": "test-kubernetesVersion", + "status": "available", + "apiEndpoint": "test-apiEndpoint", + "clusterCACertificateBase64": 
"test-clusterCACertificateBase64", + "nodeGroups": [ + { + "name": "test-nodeGroupName", + "refID": "test-nodeGroupRefID", + "id": "test-nodeGroupID", + "minNodeCount": 1, + "maxNodeCount": 2, + "instanceType": "test-instanceType", + "diskSize": { + "value": 10, + "unit": "GiB" + }, + "status": "available", + "tags": { + "test-nodeGroupTagName": "test-nodeGroupTagValue" + } + } + ], + "tags": { + "test-clusterTagName": "test-clusterTagValue" + } + }` + var cluster Cluster + if err := json.Unmarshal([]byte(clusterJSON), &cluster); err != nil { + t.Fatalf("Failed to unmarshal cluster: %v", err) + } + + if cluster.id != "test-id" { + t.Fatalf("Cluster ID = %s, want %s", cluster.id, "test-id") + } + if cluster.name != "test-name" { + t.Fatalf("Cluster Name = %s, want %s", cluster.name, "test-name") + } + if cluster.refID != "test-refID" { + t.Fatalf("Cluster RefID = %s, want %s", cluster.refID, "test-refID") + } + if cluster.provider != "test-provider" { + t.Fatalf("Cluster Provider = %s, want %s", cluster.provider, "test-provider") + } + if cluster.cloud != "test-cloud" { + t.Fatalf("Cluster Cloud = %s, want %s", cluster.cloud, "test-cloud") + } + if cluster.location != "test-location" { + t.Fatalf("Cluster Location = %s, want %s", cluster.location, "test-location") + } + if cluster.vpcID != "test-vpcID" { + t.Fatalf("Cluster VPCID = %s, want %s", cluster.vpcID, "test-vpcID") + } + if cluster.kubernetesVersion != "test-kubernetesVersion" { + t.Fatalf("Cluster KubernetesVersion = %s, want %s", cluster.kubernetesVersion, "test-kubernetesVersion") + } + if cluster.status != ClusterStatusAvailable { + t.Fatalf("Cluster Status = %s, want %s", cluster.status, "available") + } + if cluster.apiEndpoint != "test-apiEndpoint" { + t.Fatalf("Cluster APIEndpoint = %s, want %s", cluster.apiEndpoint, "test-apiEndpoint") + } + if cluster.clusterCACertificateBase64 != "test-clusterCACertificateBase64" { + t.Fatalf("Cluster ClusterCACertificateBase64 = %s, want %s", 
cluster.clusterCACertificateBase64, "test-clusterCACertificateBase64") + } + if len(cluster.nodeGroups) != 1 { + t.Fatalf("Cluster NodeGroups = %d, want %d", len(cluster.nodeGroups), 1) + } + if cluster.nodeGroups[0].name != "test-nodeGroupName" { + t.Fatalf("Cluster NodeGroup Name = %s, want %s", cluster.nodeGroups[0].name, "test-nodeGroupName") + } + if cluster.nodeGroups[0].refID != "test-nodeGroupRefID" { + t.Fatalf("Cluster NodeGroup RefID = %s, want %s", cluster.nodeGroups[0].refID, "test-nodeGroupRefID") + } + if cluster.nodeGroups[0].id != "test-nodeGroupID" { + t.Fatalf("Cluster NodeGroup ID = %s, want %s", cluster.nodeGroups[0].id, "test-nodeGroupID") + } + if cluster.nodeGroups[0].minNodeCount != 1 { + t.Fatalf("Cluster NodeGroup MinNodeCount = %d, want %d", cluster.nodeGroups[0].minNodeCount, 1) + } + if cluster.nodeGroups[0].maxNodeCount != 2 { + t.Fatalf("Cluster NodeGroup MaxNodeCount = %d, want %d", cluster.nodeGroups[0].maxNodeCount, 2) + } + if cluster.nodeGroups[0].instanceType != "test-instanceType" { + t.Fatalf("Cluster NodeGroup InstanceType = %s, want %s", cluster.nodeGroups[0].instanceType, "test-instanceType") + } + if !cluster.nodeGroups[0].diskSize.Equal(NewBytes(10, Gibibyte)) { + t.Fatalf("Cluster NodeGroup DiskSize = %s, want %s", cluster.nodeGroups[0].diskSize, "10 GiB") + } + if len(cluster.nodeGroups[0].tags) != 1 { + t.Fatalf("Cluster NodeGroup Tags = %d, want %d", len(cluster.nodeGroups[0].tags), 1) + } + if cluster.nodeGroups[0].tags["test-nodeGroupTagName"] != "test-nodeGroupTagValue" { + t.Fatalf("Cluster NodeGroup Tag = %s, want %s", cluster.nodeGroups[0].tags["test-nodeGroupTagName"], "test-nodeGroupTagValue") + } + if len(cluster.tags) != 1 { + t.Fatalf("Cluster Tags = %d, want %d", len(cluster.tags), 1) + } + if cluster.tags["test-clusterTagName"] != "test-clusterTagValue" { + t.Fatalf("Cluster Tag = %s, want %s", cluster.tags["test-clusterTagName"], "test-clusterTagValue") + } +} + +func TestNodeGroupMarshalJSON(t 
*testing.T) { + nodeGroup := &NodeGroup{ + name: "test-nodeGroupName", + refID: "test-nodeGroupRefID", + id: "test-nodeGroupID", + minNodeCount: 1, + maxNodeCount: 2, + instanceType: "test-instanceType", + diskSize: NewBytes(10, Gibibyte), + status: NodeGroupStatusAvailable, + tags: Tags{ + "test-tagName": "test-tagValue", + }, + } + expectedJSON := `{ + "name": "test-nodeGroupName", + "refID": "test-nodeGroupRefID", + "id": "test-nodeGroupID", + "minNodeCount": 1, + "maxNodeCount": 2, + "instanceType": "test-instanceType", + "diskSize": { + "value": 10, + "unit": "GiB" + }, + "status": "available", + "tags": { + "test-tagName": "test-tagValue" + } + }` + expectedJSON = reWhitespace.ReplaceAllString(expectedJSON, "") + + nodeGroupJSON, err := nodeGroup.MarshalJSON() + if err != nil { + t.Fatalf("Failed to marshal node group: %v", err) + } + + if string(nodeGroupJSON) != expectedJSON { + t.Fatalf("NodeGroup JSON = %s, want %s", string(nodeGroupJSON), expectedJSON) + } +} + +func TestNodeGroupUnmarshalJSON(t *testing.T) { + nodeGroupJSON := `{ + "name": "test-nodeGroupName", + "refID": "test-nodeGroupRefID", + "id": "test-nodeGroupID", + "minNodeCount": 1, + "maxNodeCount": 2, + "instanceType": "test-instanceType", + "diskSize": { + "value": 10, + "unit": "GiB" + }, + "status": "available", + "tags": { + "test-tagName": "test-tagValue" + } + }` + var nodeGroup NodeGroup + if err := json.Unmarshal([]byte(nodeGroupJSON), &nodeGroup); err != nil { + t.Fatalf("Failed to unmarshal node group: %v", err) + } + + if nodeGroup.name != "test-nodeGroupName" { + t.Fatalf("NodeGroup Name = %s, want %s", nodeGroup.name, "test-nodeGroupName") + } + if nodeGroup.refID != "test-nodeGroupRefID" { + t.Fatalf("NodeGroup RefID = %s, want %s", nodeGroup.refID, "test-nodeGroupRefID") + } + if nodeGroup.id != "test-nodeGroupID" { + t.Fatalf("NodeGroup ID = %s, want %s", nodeGroup.id, "test-nodeGroupID") + } + if nodeGroup.minNodeCount != 1 { + t.Fatalf("NodeGroup MinNodeCount = %d, want 
%d", nodeGroup.minNodeCount, 1) + } + if nodeGroup.maxNodeCount != 2 { + t.Fatalf("NodeGroup MaxNodeCount = %d, want %d", nodeGroup.maxNodeCount, 2) + } + if nodeGroup.instanceType != "test-instanceType" { + t.Fatalf("NodeGroup InstanceType = %s, want %s", nodeGroup.instanceType, "test-instanceType") + } + if !nodeGroup.diskSize.Equal(NewBytes(10, Gibibyte)) { + t.Fatalf("NodeGroup DiskSize = %s, want %s", nodeGroup.diskSize, "10 GiB") + } + if nodeGroup.status != NodeGroupStatusAvailable { + t.Fatalf("NodeGroup Status = %s, want %s", nodeGroup.status, "available") + } + if len(nodeGroup.tags) != 1 { + t.Fatalf("NodeGroup Tags = %d, want %d", len(nodeGroup.tags), 1) + } + if nodeGroup.tags["test-tagName"] != "test-tagValue" { + t.Fatalf("NodeGroup Tag = %s, want %s", nodeGroup.tags["test-tagName"], "test-tagValue") + } +} + +func TestNodeGroupUnmarshalJSON_InvalidJSON(t *testing.T) { + tests := []struct { + name string + json string + wantErr error + }{ + {name: "invalid status", json: `{"status":"invalid"}`, wantErr: ErrNodeGroupInvalidStatus}, + {name: "refID is required", json: `{"name":"test-nodeGroupName", "status":"available"}`, wantErr: ErrRefIDRequired}, + {name: "name is required", json: `{"refID":"test-nodeGroupRefID", "status":"available"}`, wantErr: ErrNameRequired}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var nodeGroup NodeGroup + err := json.Unmarshal([]byte(test.json), &nodeGroup) + if err == nil { + t.Fatalf("Expected error, got nil") + } + if !errors.Is(err, test.wantErr) { + t.Fatalf("Expected error, got %v", err) + } + }) + } +} diff --git a/v1/kubernetes_validation.go b/v1/kubernetes_validation.go index f28a1a3a..ab7d6fd1 100644 --- a/v1/kubernetes_validation.go +++ b/v1/kubernetes_validation.go @@ -84,8 +84,8 @@ func ValidateCreateKubernetesNodeGroup(ctx context.Context, client CloudMaintain if nodeGroup.GetInstanceType() != attrs.InstanceType { return nil, fmt.Errorf("node group instanceType does not match 
create args: '%s' != '%s'", nodeGroup.GetInstanceType(), attrs.InstanceType) } - if nodeGroup.GetDiskSizeGiB() != attrs.DiskSizeGiB { - return nil, fmt.Errorf("node group diskSizeGiB does not match create args: '%d' != '%d'", nodeGroup.GetDiskSizeGiB(), attrs.DiskSizeGiB) + if !nodeGroup.GetDiskSize().Equal(attrs.DiskSize) { + return nil, fmt.Errorf("node group diskSize does not match create args: '%s' != '%s'", nodeGroup.GetDiskSize(), attrs.DiskSize) } return nodeGroup, nil @@ -121,8 +121,8 @@ func ValidateClusterNodeGroups(ctx context.Context, client CloudMaintainKubernet if clusterNodeGroup.GetInstanceType() != nodeGroup.GetInstanceType() { return fmt.Errorf("cluster node group instanceType does not match create args: '%s' != '%s'", clusterNodeGroup.GetInstanceType(), nodeGroup.GetInstanceType()) } - if clusterNodeGroup.GetDiskSizeGiB() != nodeGroup.GetDiskSizeGiB() { - return fmt.Errorf("cluster node group diskSizeGiB does not match create args: '%d' != '%d'", clusterNodeGroup.GetDiskSizeGiB(), nodeGroup.GetDiskSizeGiB()) + if !clusterNodeGroup.GetDiskSize().Equal(nodeGroup.GetDiskSize()) { + return fmt.Errorf("cluster node group diskSize does not match create args: '%s' != '%s'", clusterNodeGroup.GetDiskSize(), nodeGroup.GetDiskSize()) } for key, value := range nodeGroup.GetTags() { tagValue, ok := clusterNodeGroup.GetTags()[key] diff --git a/v1/providers/aws/kubernetes.go b/v1/providers/aws/kubernetes.go index 113b4f34..3936c903 100644 --- a/v1/providers/aws/kubernetes.go +++ b/v1/providers/aws/kubernetes.go @@ -25,6 +25,8 @@ import ( ) var ( + minDiskSize = v1.NewBytes(20, v1.Gibibyte) + errUsernameIsRequired = fmt.Errorf("username is required") errRoleIsRequired = fmt.Errorf("role is required") errClusterIDIsRequired = fmt.Errorf("cluster ID is required") @@ -34,7 +36,7 @@ var ( errNodeGroupMaxNodeCountMustBeGreaterThan0 = fmt.Errorf("node group maxNodeCount must be greater than 0") errNodeGroupMaxNodeCountMustBeGreaterThanOrEqualToMinNodeCount = 
fmt.Errorf("node group maxNodeCount must be greater than or equal to minNodeCount") errNodeGroupInstanceTypeIsRequired = fmt.Errorf("node group instanceType is required") - errNodeGroupDiskSizeGiBMustBeGreaterThanOrEqualTo20 = fmt.Errorf("node group diskSizeGiB must be greater than or equal to 20") + errNodeGroupDiskSizeGiBMustBeGreaterThanOrEqualToMinDiskSize = fmt.Errorf("node group diskSizeGiB must be greater than or equal to %v", minDiskSize) errNodeGroupDiskSizeGiBMustBeLessThanOrEqualToMaxInt32 = fmt.Errorf("node group diskSizeGiB must be less than or equal to %d", math.MaxInt32) errNodeGroupMaxNodeCountMustBeLessThanOrEqualToMaxInt32 = fmt.Errorf("node group maxNodeCount must be less than or equal to %d", math.MaxInt32) errNodeGroupMinNodeCountMustBeLessThanOrEqualToMaxInt32 = fmt.Errorf("node group minNodeCount must be less than or equal to %d", math.MaxInt32) @@ -402,7 +404,7 @@ func parseEKSNodeGroup(eksNodeGroup *ekstypes.Nodegroup) (*v1.NodeGroup, error) MinNodeCount: int(*eksNodeGroup.ScalingConfig.MinSize), MaxNodeCount: int(*eksNodeGroup.ScalingConfig.MaxSize), InstanceType: eksNodeGroup.InstanceTypes[0], // todo: handle multiple instance types - DiskSizeGiB: int(*eksNodeGroup.DiskSize), + DiskSize: v1.NewBytes(v1.BytesValue(*eksNodeGroup.DiskSize), v1.Gibibyte), Status: parseEKSNodeGroupStatus(eksNodeGroup.Status), Tags: v1.Tags(eksNodeGroup.Tags), }) @@ -471,6 +473,13 @@ func (c *AWSClient) CreateNodeGroup(ctx context.Context, args v1.CreateNodeGroup v1.Field{Key: "clusterName", Value: cluster.GetName()}, v1.Field{Key: "nodeGroupName", Value: args.Name}, ) + + // AWS expects the disk size in GiB, so we need to convert the disk size to GiB + diskSizeGiB, err := args.DiskSize.ByteCountInUnitInt32(v1.Gibibyte) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + output, err := eksClient.CreateNodegroup(ctx, &eks.CreateNodegroupInput{ ClusterName: aws.String(cluster.GetName()), NodegroupName: aws.String(args.Name), @@ -479,7 +488,7 @@ func (c 
*AWSClient) CreateNodeGroup(ctx context.Context, args v1.CreateNodeGroup MinSize: aws.Int32(int32(args.MinNodeCount)), //nolint:gosec // checked in input validation MaxSize: aws.Int32(int32(args.MaxNodeCount)), //nolint:gosec // checked in input validation }, - DiskSize: aws.Int32(int32(args.DiskSizeGiB)), //nolint:gosec // checked in input validation + DiskSize: aws.Int32(diskSizeGiB), Subnets: subnetIDs, InstanceTypes: []string{ args.InstanceType, @@ -511,10 +520,10 @@ func validateCreateNodeGroupArgs(args v1.CreateNodeGroupArgs) error { if args.InstanceType == "" { errs = append(errs, errNodeGroupInstanceTypeIsRequired) } - if args.DiskSizeGiB < 20 { - errs = append(errs, errNodeGroupDiskSizeGiBMustBeGreaterThanOrEqualTo20) + if args.DiskSize.LessThan(minDiskSize) { + errs = append(errs, errNodeGroupDiskSizeGiBMustBeGreaterThanOrEqualToMinDiskSize) } - if args.DiskSizeGiB > math.MaxInt32 { + if args.DiskSize.GreaterThan(v1.NewBytes(math.MaxInt32, v1.Gibibyte)) { errs = append(errs, errNodeGroupDiskSizeGiBMustBeLessThanOrEqualToMaxInt32) } if args.MaxNodeCount > math.MaxInt32 { diff --git a/v1/providers/aws/kubernetes_unit_test.go b/v1/providers/aws/kubernetes_unit_test.go index 80a1e702..d81654c5 100644 --- a/v1/providers/aws/kubernetes_unit_test.go +++ b/v1/providers/aws/kubernetes_unit_test.go @@ -26,7 +26,7 @@ func TestValidateCreateNodeGroupArgs(t *testing.T) { //nolint:funlen // test ok MinNodeCount: 1, MaxNodeCount: 3, InstanceType: "t3.medium", - DiskSizeGiB: 20, + DiskSize: v1.NewBytes(20, v1.Gibibyte), ClusterID: "cluster-123", }, expectError: nil, @@ -39,7 +39,7 @@ func TestValidateCreateNodeGroupArgs(t *testing.T) { //nolint:funlen // test ok MinNodeCount: 0, MaxNodeCount: 3, InstanceType: "t3.medium", - DiskSizeGiB: 20, + DiskSize: v1.NewBytes(20, v1.Gibibyte), }, expectError: errNodeGroupMinNodeCountMustBeGreaterThan0, }, @@ -51,7 +51,7 @@ func TestValidateCreateNodeGroupArgs(t *testing.T) { //nolint:funlen // test ok MinNodeCount: 1, MaxNodeCount: 
0, InstanceType: "t3.medium", - DiskSizeGiB: 20, + DiskSize: v1.NewBytes(20, v1.Gibibyte), }, expectError: errNodeGroupMaxNodeCountMustBeGreaterThan0, }, @@ -63,7 +63,7 @@ func TestValidateCreateNodeGroupArgs(t *testing.T) { //nolint:funlen // test ok MinNodeCount: 5, MaxNodeCount: 3, InstanceType: "t3.medium", - DiskSizeGiB: 20, + DiskSize: v1.NewBytes(20, v1.Gibibyte), }, expectError: errNodeGroupMaxNodeCountMustBeGreaterThanOrEqualToMinNodeCount, }, @@ -75,7 +75,7 @@ func TestValidateCreateNodeGroupArgs(t *testing.T) { //nolint:funlen // test ok MinNodeCount: 1, MaxNodeCount: 3, InstanceType: "", - DiskSizeGiB: 20, + DiskSize: v1.NewBytes(20, v1.Gibibyte), }, expectError: errNodeGroupInstanceTypeIsRequired, }, @@ -87,9 +87,9 @@ func TestValidateCreateNodeGroupArgs(t *testing.T) { //nolint:funlen // test ok MinNodeCount: 1, MaxNodeCount: 3, InstanceType: "t3.medium", - DiskSizeGiB: 10, + DiskSize: v1.NewBytes(10, v1.Gibibyte), }, - expectError: errNodeGroupDiskSizeGiBMustBeGreaterThanOrEqualTo20, + expectError: errNodeGroupDiskSizeGiBMustBeGreaterThanOrEqualToMinDiskSize, }, { name: "disk size exceeds max int32", @@ -99,7 +99,7 @@ func TestValidateCreateNodeGroupArgs(t *testing.T) { //nolint:funlen // test ok MinNodeCount: 1, MaxNodeCount: 3, InstanceType: "t3.medium", - DiskSizeGiB: math.MaxInt32 + 1, + DiskSize: v1.NewBytes(math.MaxInt32+1, v1.Gibibyte), }, expectError: errNodeGroupDiskSizeGiBMustBeLessThanOrEqualToMaxInt32, }, @@ -111,7 +111,7 @@ func TestValidateCreateNodeGroupArgs(t *testing.T) { //nolint:funlen // test ok MinNodeCount: 1, MaxNodeCount: math.MaxInt32 + 1, InstanceType: "t3.medium", - DiskSizeGiB: 20, + DiskSize: v1.NewBytes(20, v1.Gibibyte), }, expectError: errNodeGroupMaxNodeCountMustBeLessThanOrEqualToMaxInt32, }, @@ -123,7 +123,7 @@ func TestValidateCreateNodeGroupArgs(t *testing.T) { //nolint:funlen // test ok MinNodeCount: math.MaxInt32 + 1, MaxNodeCount: math.MaxInt32 + 2, InstanceType: "t3.medium", - DiskSizeGiB: 20, + DiskSize: 
v1.NewBytes(20, v1.Gibibyte), }, expectError: errNodeGroupMinNodeCountMustBeLessThanOrEqualToMaxInt32, }, @@ -614,8 +614,8 @@ func TestParseEKSNodeGroup(t *testing.T) { //nolint:gocognit // test ok if result.GetInstanceType() != tt.nodeGroup.InstanceTypes[0] { t.Errorf("expected instance type %s, got %s", tt.nodeGroup.InstanceTypes[0], result.GetInstanceType()) } - if result.GetDiskSizeGiB() != int(*tt.nodeGroup.DiskSize) { - t.Errorf("expected disk size %d, got %d", *tt.nodeGroup.DiskSize, result.GetDiskSizeGiB()) + if !result.GetDiskSize().Equal(v1.NewBytes(v1.BytesValue(*tt.nodeGroup.DiskSize), v1.Gibibyte)) { + t.Errorf("expected disk size %s, got %s", v1.NewBytes(v1.BytesValue(*tt.nodeGroup.DiskSize), v1.Gibibyte), result.GetDiskSize()) } if result.GetStatus() != parseEKSNodeGroupStatus(tt.nodeGroup.Status) { t.Errorf("expected status %v, got %v", parseEKSNodeGroupStatus(tt.nodeGroup.Status), result.GetStatus()) diff --git a/v1/providers/aws/validation_kubernetes_test.go b/v1/providers/aws/validation_kubernetes_test.go index 67a6c3a7..1d534302 100644 --- a/v1/providers/aws/validation_kubernetes_test.go +++ b/v1/providers/aws/validation_kubernetes_test.go @@ -55,7 +55,7 @@ func TestAWSKubernetesValidation(t *testing.T) { Name: name, RefID: name, InstanceType: "t3.medium", - DiskSizeGiB: 20, + DiskSize: v1.NewBytes(20, v1.Gibibyte), MinNodeCount: 1, MaxNodeCount: 1, }, diff --git a/v1/providers/nebius/kubernetes.go b/v1/providers/nebius/kubernetes.go index ef763fc1..68c310bf 100644 --- a/v1/providers/nebius/kubernetes.go +++ b/v1/providers/nebius/kubernetes.go @@ -23,6 +23,8 @@ import ( ) var ( + minDiskSize = v1.NewBytes(v1.BytesValue(64), v1.Gibibyte) + errVPCHasNoPublicSubnets = fmt.Errorf("VPC must have at least one public subnet with a CIDR block larger than /24") errVPCHasNoPrivateSubnets = fmt.Errorf("VPC must have at least one private subnet with a CIDR block larger than /24") errNoSubnetIDsSpecifiedForVPC = fmt.Errorf("no subnet IDs specified for VPC") 
@@ -33,7 +35,7 @@ var ( errNodeGroupMinNodeCountMustBeGreaterThan0 = fmt.Errorf("node group minNodeCount must be greater than 0") errNodeGroupMaxNodeCountMustBeGreaterThan0 = fmt.Errorf("node group maxNodeCount must be greater than 0") errNodeGroupMaxNodeCountMustBeGreaterThanOrEqualToMinNodeCount = fmt.Errorf("node group maxNodeCount must be greater than or equal to minNodeCount") - errNodeGroupDiskSizeGiBMustBeGreaterThanOrEqualTo64 = fmt.Errorf("node group diskSizeGiB must be greater than or equal to 64") + errNodeGroupDiskSizeMustBeGreaterThanOrEqualToMin = fmt.Errorf("node group diskSize must be greater than or equal to %v", minDiskSize) errNodeGroupInstanceTypeIsRequired = fmt.Errorf("node group instanceType is required") errUsernameIsRequired = fmt.Errorf("username is required") @@ -408,6 +410,12 @@ func (c *NebiusClient) CreateNodeGroup(ctx context.Context, args v1.CreateNodeGr labels[labelBrevRefID] = args.RefID labels[labelCreatedBy] = labelBrevCloudSDK + // Nebius expects the disk size in GiB, so we need to convert the disk size to GiB + diskSizeGiB, err := args.DiskSize.ByteCountInUnitInt64(v1.Gibibyte) + if err != nil { + return nil, errors.WrapAndTrace(err) + } + // create the node groups createNodeGroupOperation, err := nebiusNodeGroupService.Create(ctx, &nebiusmk8s.CreateNodeGroupRequest{ Metadata: &nebiuscommon.ResourceMetadata{ @@ -435,7 +443,7 @@ func (c *NebiusClient) CreateNodeGroup(ctx context.Context, args v1.CreateNodeGr BootDisk: &nebiusmk8s.DiskSpec{ Type: nebiusmk8s.DiskSpec_NETWORK_SSD, Size: &nebiusmk8s.DiskSpec_SizeGibibytes{ - SizeGibibytes: int64(args.DiskSizeGiB), + SizeGibibytes: diskSizeGiB, }, }, }, @@ -452,7 +460,7 @@ func (c *NebiusClient) CreateNodeGroup(ctx context.Context, args v1.CreateNodeGr MinNodeCount: args.MinNodeCount, MaxNodeCount: args.MaxNodeCount, InstanceType: args.InstanceType, - DiskSizeGiB: args.DiskSizeGiB, + DiskSize: args.DiskSize, Status: v1.NodeGroupStatusPending, Tags: args.Tags, }) @@ -478,8 +486,8 @@ 
func validateCreateNodeGroupArgs(args v1.CreateNodeGroupArgs) error { if args.MaxNodeCount < args.MinNodeCount { return errNodeGroupMaxNodeCountMustBeGreaterThanOrEqualToMinNodeCount } - if args.DiskSizeGiB < 64 { - return errNodeGroupDiskSizeGiBMustBeGreaterThanOrEqualTo64 + if args.DiskSize.LessThan(minDiskSize) { + return errNodeGroupDiskSizeMustBeGreaterThanOrEqualToMin } if args.InstanceType == "" { return errNodeGroupInstanceTypeIsRequired @@ -515,7 +523,7 @@ func parseNebiusNodeGroup(nodeGroup *nebiusmk8s.NodeGroup) (*v1.NodeGroup, error MinNodeCount: int(nodeGroup.Spec.GetAutoscaling().MinNodeCount), MaxNodeCount: int(nodeGroup.Spec.GetAutoscaling().MaxNodeCount), InstanceType: nodeGroup.Spec.Template.Resources.Platform + "." + nodeGroup.Spec.Template.Resources.GetPreset(), - DiskSizeGiB: int(nodeGroup.Spec.Template.BootDisk.GetSizeGibibytes()), + DiskSize: v1.NewBytes(v1.BytesValue(nodeGroup.Spec.Template.BootDisk.GetSizeGibibytes()), v1.Gibibyte), Status: parseNebiusNodeGroupStatus(nodeGroup.Status), Tags: v1.Tags(nodeGroup.Metadata.Labels), }) diff --git a/v1/providers/nebius/kubernetes_unit_test.go b/v1/providers/nebius/kubernetes_unit_test.go index 65db6317..a978eb8a 100644 --- a/v1/providers/nebius/kubernetes_unit_test.go +++ b/v1/providers/nebius/kubernetes_unit_test.go @@ -23,7 +23,7 @@ func TestValidateCreateNodeGroupArgs(t *testing.T) { //nolint:funlen // test ok MinNodeCount: 1, MaxNodeCount: 3, InstanceType: "cpu-d3.4vcpu-16gb", - DiskSizeGiB: 64, + DiskSize: v1.NewBytes(64, v1.Gibibyte), ClusterID: "cluster-123", }, expectError: nil, @@ -36,7 +36,7 @@ func TestValidateCreateNodeGroupArgs(t *testing.T) { //nolint:funlen // test ok MinNodeCount: 1, MaxNodeCount: 3, InstanceType: "cpu-d3.4vcpu-16gb", - DiskSizeGiB: 64, + DiskSize: v1.NewBytes(64, v1.Gibibyte), }, expectError: errNodeGroupNameIsRequired, }, @@ -48,7 +48,7 @@ func TestValidateCreateNodeGroupArgs(t *testing.T) { //nolint:funlen // test ok MinNodeCount: 1, MaxNodeCount: 3, 
InstanceType: "cpu-d3.4vcpu-16gb", - DiskSizeGiB: 64, + DiskSize: v1.NewBytes(64, v1.Gibibyte), }, expectError: errNodeGroupRefIDIsRequired, }, @@ -60,7 +60,7 @@ func TestValidateCreateNodeGroupArgs(t *testing.T) { //nolint:funlen // test ok MinNodeCount: 0, MaxNodeCount: 3, InstanceType: "cpu-d3.4vcpu-16gb", - DiskSizeGiB: 64, + DiskSize: v1.NewBytes(64, v1.Gibibyte), }, expectError: errNodeGroupMinNodeCountMustBeGreaterThan0, }, @@ -72,7 +72,7 @@ func TestValidateCreateNodeGroupArgs(t *testing.T) { //nolint:funlen // test ok MinNodeCount: 1, MaxNodeCount: 0, InstanceType: "cpu-d3.4vcpu-16gb", - DiskSizeGiB: 64, + DiskSize: v1.NewBytes(64, v1.Gibibyte), }, expectError: errNodeGroupMaxNodeCountMustBeGreaterThan0, }, @@ -84,7 +84,7 @@ func TestValidateCreateNodeGroupArgs(t *testing.T) { //nolint:funlen // test ok MinNodeCount: 5, MaxNodeCount: 3, InstanceType: "cpu-d3.4vcpu-16gb", - DiskSizeGiB: 64, + DiskSize: v1.NewBytes(64, v1.Gibibyte), }, expectError: errNodeGroupMaxNodeCountMustBeGreaterThanOrEqualToMinNodeCount, }, @@ -96,9 +96,9 @@ func TestValidateCreateNodeGroupArgs(t *testing.T) { //nolint:funlen // test ok MinNodeCount: 1, MaxNodeCount: 3, InstanceType: "cpu-d3.4vcpu-16gb", - DiskSizeGiB: 32, + DiskSize: v1.NewBytes(32, v1.Gibibyte), }, - expectError: errNodeGroupDiskSizeGiBMustBeGreaterThanOrEqualTo64, + expectError: errNodeGroupDiskSizeMustBeGreaterThanOrEqualToMin, }, { name: "missing instance type", @@ -108,7 +108,7 @@ func TestValidateCreateNodeGroupArgs(t *testing.T) { //nolint:funlen // test ok MinNodeCount: 1, MaxNodeCount: 3, InstanceType: "", - DiskSizeGiB: 64, + DiskSize: v1.NewBytes(64, v1.Gibibyte), }, expectError: errNodeGroupInstanceTypeIsRequired, }, diff --git a/v1/providers/nebius/validation_kubernetes_test.go b/v1/providers/nebius/validation_kubernetes_test.go index 8b9fbf02..4403dcd1 100644 --- a/v1/providers/nebius/validation_kubernetes_test.go +++ b/v1/providers/nebius/validation_kubernetes_test.go @@ -52,7 +52,7 @@ func
TestKubernetesValidation(t *testing.T) { Name: name, RefID: name, InstanceType: "cpu-d3.4vcpu-16gb", - DiskSizeGiB: 64, + DiskSize: v1.NewBytes(64, v1.Gibibyte), MinNodeCount: 1, MaxNodeCount: 1, },