diff --git a/backend/controllers/do_nothing.go b/backend/controllers/do_nothing.go index a8695dca95..7cc04fce58 100644 --- a/backend/controllers/do_nothing.go +++ b/backend/controllers/do_nothing.go @@ -125,7 +125,10 @@ func (c *doNothingExample) SyncOnce(ctx context.Context, keyObj any) error { func (c *doNothingExample) queueAllHCPClusters(ctx context.Context) { logger := utils.LoggerFromContext(ctx) - allSubscriptions := c.cosmosClient.ListAllSubscriptionDocs() + allSubscriptions, err := c.cosmosClient.Subscriptions().List(ctx, nil) + if err != nil { + logger.Error("unable to list subscriptions", "error", err) + } for subscriptionID := range allSubscriptions.Items(ctx) { allHCPClusters, err := c.cosmosClient.HCPClusters(subscriptionID, "").List(ctx, nil) if err != nil { diff --git a/backend/operations_scanner.go b/backend/operations_scanner.go index c94cc0fe5a..6da9b582f7 100644 --- a/backend/operations_scanner.go +++ b/backend/operations_scanner.go @@ -366,7 +366,12 @@ func (s *OperationsScanner) collectSubscriptions(ctx context.Context, logger *sl var subscriptions []string - iterator := s.dbClient.ListAllSubscriptionDocs() + iterator, err := s.dbClient.Subscriptions().List(ctx, nil) + if err != nil { + s.recordOperationError(ctx, collectSubscriptionsLabel, err) + logger.Error(fmt.Sprintf("Error creating iterator: %v", err.Error())) + return + } subscriptionStates := map[arm.SubscriptionState]int{} for subscriptionState := range arm.ListSubscriptionStates() { @@ -382,8 +387,7 @@ func (s *OperationsScanner) collectSubscriptions(ctx context.Context, logger *sl } span.SetAttributes(tracing.ProcessedItemsKey.Int(len(subscriptions))) - err := iterator.GetError() - if err != nil { + if err := iterator.GetError(); err != nil { s.recordOperationError(ctx, collectSubscriptionsLabel, err) logger.Error(fmt.Sprintf("Error while paging through Cosmos query results: %v", err.Error())) return diff --git a/frontend/pkg/frontend/cluster.go b/frontend/pkg/frontend/cluster.go index becde12e5e..14b2d61361 100644 --- a/frontend/pkg/frontend/cluster.go +++ b/frontend/pkg/frontend/cluster.go @@ -275,7 +275,7 @@ func (f *Frontend) createHCPCluster(writer http.ResponseWriter, request *http.Re return err } - subscription, err := f.dbClient.GetSubscriptionDoc(ctx, resourceID.SubscriptionID) + subscription, err := f.dbClient.Subscriptions().Get(ctx, resourceID.SubscriptionID) if err != nil { return err } diff --git a/frontend/pkg/frontend/frontend.go b/frontend/pkg/frontend/frontend.go index cf44b25bce..96cf446f07 100644 --- a/frontend/pkg/frontend/frontend.go +++ b/frontend/pkg/frontend/frontend.go @@ -445,7 +445,7 @@ func (f *Frontend) ArmSubscriptionGet(writer http.ResponseWriter, request *http. subscriptionID := request.PathValue(PathSegmentSubscriptionID) - subscription, err := f.dbClient.GetSubscriptionDoc(ctx, subscriptionID) + subscription, err := f.dbClient.Subscriptions().Get(ctx, subscriptionID) if database.IsResponseError(err, http.StatusNotFound) { return arm.NewResourceNotFoundError(resourceID) } @@ -468,23 +468,27 @@ func (f *Frontend) ArmSubscriptionPut(writer http.ResponseWriter, request *http. if err != nil { return utils.TrackError(err) } + subscriptionID := request.PathValue(PathSegmentSubscriptionID) - var subscription arm.Subscription - err = json.Unmarshal(body, &subscription) + var requestSubscription arm.Subscription + err = json.Unmarshal(body, &requestSubscription) if err != nil { return arm.NewInvalidRequestContentError(err) } + requestSubscription.ResourceID, err = arm.ToSubscriptionResourceID(subscriptionID) + if err != nil { + return utils.TrackError(err) + } - validationErrs := validation.ValidateSubscriptionCreate(ctx, &subscription) + validationErrs := validation.ValidateSubscriptionCreate(ctx, &requestSubscription) if err := arm.CloudErrorFromFieldErrors(validationErrs); err != nil { return utils.TrackError(err) } - subscriptionID := request.PathValue(PathSegmentSubscriptionID) - - _, err = f.dbClient.GetSubscriptionDoc(ctx, subscriptionID) + var resultingSubscription *arm.Subscription + existingSubscription, err := f.dbClient.Subscriptions().Get(ctx, subscriptionID) if database.IsResponseError(err, http.StatusNotFound) { - err = f.dbClient.CreateSubscriptionDoc(ctx, subscriptionID, &subscription) + resultingSubscription, err = f.dbClient.Subscriptions().Create(ctx, &requestSubscription, nil) if err != nil { return utils.TrackError(err) } @@ -492,32 +496,29 @@ func (f *Frontend) ArmSubscriptionPut(writer http.ResponseWriter, request *http. } else if err != nil { return utils.TrackError(err) } else { - updated, err := f.dbClient.UpdateSubscriptionDoc(ctx, subscriptionID, func(updateSubscription *arm.Subscription) bool { - messages := getSubscriptionDifferences(updateSubscription, &subscription) - for _, message := range messages { - logger.Info(message) - } - - *updateSubscription = subscription - - return len(messages) > 0 - }) - if err != nil { - return utils.TrackError(err) + messages := getSubscriptionDifferences(existingSubscription, &requestSubscription) + for _, message := range messages { + logger.Info(message) } - if updated { + if len(messages) > 0 { + resultingSubscription, err = f.dbClient.Subscriptions().Replace(ctx, &requestSubscription, nil) + if err != nil { + return utils.TrackError(err) + } logger.Info(fmt.Sprintf("updated document for subscription %s", subscriptionID)) + } else { + resultingSubscription = existingSubscription } } // Clean up resources if subscription is deleted. - if subscription.State == arm.SubscriptionStateDeleted { + if resultingSubscription.State == arm.SubscriptionStateDeleted { if err := f.DeleteAllResourcesInSubscription(ctx, subscriptionID); err != nil { return utils.TrackError(err) } } - _, err = arm.WriteJSONResponse(writer, http.StatusOK, subscription) + _, err = arm.WriteJSONResponse(writer, http.StatusOK, resultingSubscription) if err != nil { return utils.TrackError(err) } @@ -531,7 +532,7 @@ func (f *Frontend) ArmDeploymentPreflight(writer http.ResponseWriter, request *h ctx := request.Context() logger := utils.LoggerFromContext(ctx) - subscription, err := f.dbClient.GetSubscriptionDoc(ctx, subscriptionID) + subscription, err := f.dbClient.Subscriptions().Get(ctx, subscriptionID) if err != nil { return err } diff --git a/frontend/pkg/frontend/frontend_test.go b/frontend/pkg/frontend/frontend_test.go index 1aeef07489..b285e78489 100644 --- a/frontend/pkg/frontend/frontend_test.go +++ b/frontend/pkg/frontend/frontend_test.go @@ -94,6 +94,7 @@ func TestSubscriptionsGET(t *testing.T) { t.Run(test.name, func(t *testing.T) { ctrl := gomock.NewController(t) mockDBClient := mocks.NewMockDBClient(ctrl) + mockSubscriptionCRUD := mocks.NewMockSubscriptionCRUD(ctrl) reg := prometheus.NewRegistry() f := NewFrontend( @@ -108,7 +109,10 @@ func TestSubscriptionsGET(t *testing.T) { // ArmSubscriptionGet. mockDBClient.EXPECT(). - GetSubscriptionDoc(gomock.Any(), gomock.Any()). + Subscriptions(). + Return(mockSubscriptionCRUD) + mockSubscriptionCRUD.EXPECT(). + Get(gomock.Any(), gomock.Any()). Return(getMockDBDoc(test.subDoc)). Times(1) @@ -117,7 +121,7 @@ func TestSubscriptionsGET(t *testing.T) { if test.subDoc != nil { subs[api.TestSubscriptionID] = test.subDoc } - ts := newHTTPServer(f, ctrl, mockDBClient, subs) + ts := newHTTPServer(f, ctrl, mockDBClient, mockSubscriptionCRUD, subs) rs, err := ts.Client().Get(ts.URL + api.TestSubscriptionResourceID + "?api-version=" + arm.SubscriptionAPIVersion) require.NoError(t, err) @@ -173,6 +177,7 @@ func TestSubscriptionsPUT(t *testing.T) { Properties: nil, }, subDoc: &arm.Subscription{ + ResourceID: api.Must(arm.ToSubscriptionResourceID(api.TestSubscriptionID)), State: arm.SubscriptionStateRegistered, RegistrationDate: api.Ptr(time.Now().String()), Properties: nil, @@ -196,6 +201,7 @@ func TestSubscriptionsPUT(t *testing.T) { }, }, subDoc: &arm.Subscription{ + ResourceID: api.Must(arm.ToSubscriptionResourceID(api.TestSubscriptionID)), State: arm.SubscriptionStateRegistered, RegistrationDate: api.Ptr(time.Now().String()), Properties: nil, @@ -251,6 +257,7 @@ func TestSubscriptionsPUT(t *testing.T) { t.Run(test.name, func(t *testing.T) { ctrl := gomock.NewController(t) mockDBClient := mocks.NewMockDBClient(ctrl) + mockSubscriptionCRUD := mocks.NewMockSubscriptionCRUD(ctrl) reg := prometheus.NewRegistry() f := NewFrontend( @@ -271,21 +278,33 @@ func TestSubscriptionsPUT(t *testing.T) { mockDBClient.EXPECT(). GetLockClient(). MaxTimes(1) + if test.expectedStatusCode != http.StatusBadRequest { // ArmSubscriptionPut mockDBClient.EXPECT(). - GetSubscriptionDoc(gomock.Any(), gomock.Any()). + Subscriptions(). + Return(mockSubscriptionCRUD) + mockSubscriptionCRUD.EXPECT(). + Get(gomock.Any(), gomock.Any()). Return(getMockDBDoc(test.subDoc)) // ArmSubscriptionPut if test.subDoc == nil { mockDBClient.EXPECT(). - CreateSubscriptionDoc(gomock.Any(), gomock.Any(), gomock.Any()) - } else { + Subscriptions(). + Return(mockSubscriptionCRUD) + mockSubscriptionCRUD.EXPECT(). + Create(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, subscription *arm.Subscription, options *azcosmos.ItemOptions) (*arm.Subscription, error) { + return subscription, nil + }) + } else if test.expectUpdated { mockDBClient.EXPECT(). - UpdateSubscriptionDoc(gomock.Any(), gomock.Any(), gomock.Any()). - DoAndReturn(func(ctx context.Context, subscriptionID string, callback func(updateSubscription *arm.Subscription) bool) (bool, error) { - updated := callback(test.subDoc) - assert.Equal(t, test.expectUpdated, updated) + Subscriptions(). + Return(mockSubscriptionCRUD) + mockSubscriptionCRUD.EXPECT(). + Replace(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, updated *arm.Subscription, options *azcosmos.ItemOptions) (*arm.Subscription, error) { + assert.True(t, test.expectUpdated) return updated, nil }) } @@ -295,7 +314,7 @@ func TestSubscriptionsPUT(t *testing.T) { if test.subDoc != nil { subs[api.TestSubscriptionID] = test.subDoc } - ts := newHTTPServer(f, ctrl, mockDBClient, subs) + ts := newHTTPServer(f, ctrl, mockDBClient, mockSubscriptionCRUD, subs) urlPath := test.urlPath + "?api-version=" + arm.SubscriptionAPIVersion req, err := http.NewRequest(http.MethodPut, ts.URL+urlPath, bytes.NewReader(body)) @@ -465,6 +484,7 @@ func TestDeploymentPreflight(t *testing.T) { ctrl := gomock.NewController(t) mockDBClient := mocks.NewMockDBClient(ctrl) + mockSubscriptionCRUD := mocks.NewMockSubscriptionCRUD(ctrl) reg := prometheus.NewRegistry() f := NewFrontend( @@ -479,7 +499,11 @@ func TestDeploymentPreflight(t *testing.T) { // MiddlewareValidateSubscriptionState and MetricsMiddleware mockDBClient.EXPECT(). - GetSubscriptionDoc(gomock.Any(), api.TestSubscriptionID). + Subscriptions(). + Return(mockSubscriptionCRUD). + MaxTimes(2) + mockSubscriptionCRUD.EXPECT(). + Get(gomock.Any(), api.TestSubscriptionID). Return(&arm.Subscription{ State: arm.SubscriptionStateRegistered, }, nil). @@ -490,7 +514,7 @@ func TestDeploymentPreflight(t *testing.T) { State: arm.SubscriptionStateRegistered, }, } - ts := newHTTPServer(f, ctrl, mockDBClient, subs) + ts := newHTTPServer(f, ctrl, mockDBClient, mockSubscriptionCRUD, subs) resource, err := json.Marshal(&test.resource) require.NoError(t, err) @@ -603,6 +627,7 @@ func TestRequestAdminCredential(t *testing.T) { mockCSClient := mocks.NewMockClusterServiceClientSpec(ctrl) mockOperationCRUD := mocks.NewMockOperationCRUD(ctrl) mockClusterCRUD := mocks.NewMockHCPClusterCRUD(ctrl) + mockSubscriptionCRUD := mocks.NewMockSubscriptionCRUD(ctrl) f := NewFrontend( api.NewTestLogger(), @@ -616,7 +641,10 @@ func TestRequestAdminCredential(t *testing.T) { // MiddlewareValidateSubscriptionState and MetricsMiddleware mockDBClient.EXPECT(). - GetSubscriptionDoc(gomock.Any(), api.TestSubscriptionID). + Subscriptions(). + Return(mockSubscriptionCRUD) + mockSubscriptionCRUD.EXPECT(). + Get(gomock.Any(), api.TestSubscriptionID). Return(&arm.Subscription{ State: arm.SubscriptionStateRegistered, }, nil). @@ -708,7 +736,7 @@ func TestRequestAdminCredential(t *testing.T) { State: arm.SubscriptionStateRegistered, }, } - ts := newHTTPServer(f, ctrl, mockDBClient, subs) + ts := newHTTPServer(f, ctrl, mockDBClient, mockSubscriptionCRUD, subs) url := ts.URL + requestPath + "?api-version=" + api.TestAPIVersion resp, err := ts.Client().Post(url, "", nil) @@ -770,6 +798,7 @@ func TestRevokeCredentials(t *testing.T) { mockCSClient := mocks.NewMockClusterServiceClientSpec(ctrl) mockOperationCRUD := mocks.NewMockOperationCRUD(ctrl) mockClusterCRUD := mocks.NewMockHCPClusterCRUD(ctrl) + mockSubscriptionCRUD := mocks.NewMockSubscriptionCRUD(ctrl) f := NewFrontend( api.NewTestLogger(), @@ -783,7 +812,10 @@ func TestRevokeCredentials(t *testing.T) { // MiddlewareValidateSubscriptionState and MetricsMiddleware mockDBClient.EXPECT(). - GetSubscriptionDoc(gomock.Any(), api.TestSubscriptionID). + Subscriptions(). + Return(mockSubscriptionCRUD) + mockSubscriptionCRUD.EXPECT(). + Get(gomock.Any(), api.TestSubscriptionID). Return(&arm.Subscription{ State: arm.SubscriptionStateRegistered, }, nil). @@ -915,7 +947,7 @@ func TestRevokeCredentials(t *testing.T) { State: arm.SubscriptionStateRegistered, }, } - ts := newHTTPServer(f, ctrl, mockDBClient, subs) + ts := newHTTPServer(f, ctrl, mockDBClient, mockSubscriptionCRUD, subs) url := ts.URL + requestPath + "?api-version=" + api.TestAPIVersion resp, err := ts.Client().Post(url, "", nil) @@ -998,7 +1030,7 @@ func assertHTTPMetrics(t *testing.T, r prometheus.Gatherer, subscription *arm.Su // newHTTPServer returns a test HTTP server. When a mock DB client is provided, // the subscription collector will be bootstrapped with the provided // subscription documents. -func newHTTPServer(f *Frontend, ctrl *gomock.Controller, mockDBClient *mocks.MockDBClient, subs map[string]*arm.Subscription) *httptest.Server { +func newHTTPServer(f *Frontend, ctrl *gomock.Controller, mockDBClient *mocks.MockDBClient, mockSubscriptionCRUD *mocks.MockSubscriptionCRUD, subs map[string]*arm.Subscription) *httptest.Server { ts := httptest.NewUnstartedServer(f.server.Handler) ts.Config.BaseContext = f.server.BaseContext ts.Start() @@ -1013,8 +1045,11 @@ func newHTTPServer(f *Frontend, ctrl *gomock.Controller, mockDBClient *mocks.Moc Return(nil) mockDBClient.EXPECT(). - ListAllSubscriptionDocs(). - Return(mockIter). + Subscriptions(). + Return(mockSubscriptionCRUD) + mockSubscriptionCRUD.EXPECT(). + List(gomock.Any(), gomock.Any()). + Return(mockIter, nil). Times(1) // The initialization of the subscriptions collector is normally part of diff --git a/frontend/pkg/frontend/middleware_validatesubscription.go b/frontend/pkg/frontend/middleware_validatesubscription.go index 08e6fad82a..1259c80c51 100644 --- a/frontend/pkg/frontend/middleware_validatesubscription.go +++ b/frontend/pkg/frontend/middleware_validatesubscription.go @@ -60,7 +60,7 @@ func (h *middlewareValidateSubscriptionState) handleRequest(w http.ResponseWrite // TODO: Ideally, we don't want to have to hit the database in this middleware // Currently, we are using the database to retrieve the subscription's tenantID and state - subscription, err := h.dbClient.GetSubscriptionDoc(ctx, subscriptionId) + subscription, err := h.dbClient.Subscriptions().Get(ctx, subscriptionId) if err != nil { arm.WriteError( w, http.StatusBadRequest, diff --git a/frontend/pkg/frontend/middleware_validatesubscription_test.go b/frontend/pkg/frontend/middleware_validatesubscription_test.go index 1d31d225ee..7051780f7b 100644 --- a/frontend/pkg/frontend/middleware_validatesubscription_test.go +++ b/frontend/pkg/frontend/middleware_validatesubscription_test.go @@ -181,6 +181,7 @@ func TestMiddlewareValidateSubscription(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ctrl := gomock.NewController(t) mockDBClient := mocks.NewMockDBClient(ctrl) + mockSubscriptionCRUD := mocks.NewMockSubscriptionCRUD(ctrl) var subscription *arm.Subscription @@ -211,7 +212,10 @@ func TestMiddlewareValidateSubscription(t *testing.T) { if tt.requestPath == defaultRequestPath { request.SetPathValue(PathSegmentSubscriptionID, subscriptionId) mockDBClient.EXPECT(). - GetSubscriptionDoc(gomock.Any(), subscriptionId). + Subscriptions(). + Return(mockSubscriptionCRUD) + mockSubscriptionCRUD.EXPECT(). + Get(gomock.Any(), subscriptionId). Return(getMockDBDoc(subscription)) // defined in frontend_test.go } diff --git a/frontend/pkg/metrics/metrics.go b/frontend/pkg/metrics/metrics.go index 5df02a219e..11a5b99c03 100644 --- a/frontend/pkg/metrics/metrics.go +++ b/frontend/pkg/metrics/metrics.go @@ -151,7 +151,10 @@ func (sc *SubscriptionCollector) refresh(ctx context.Context, logger *slog.Logge func (sc *SubscriptionCollector) updateCache(ctx context.Context) error { subscriptions := make(map[string]subscription) - iter := sc.dbClient.ListAllSubscriptionDocs() + iter, err := sc.dbClient.Subscriptions().List(ctx, nil) + if err != nil { + return utils.TrackError(err) + } for id, sub := range iter.Items(ctx) { subscriptions[id] = subscription{ id: id, @@ -160,7 +163,7 @@ func (sc *SubscriptionCollector) updateCache(ctx context.Context) error { } } if err := iter.GetError(); err != nil { - return err + return utils.TrackError(err) } sc.mtx.Lock() diff --git a/frontend/pkg/metrics/metrics_test.go b/frontend/pkg/metrics/metrics_test.go index c508b1c8f8..98f5fb977c 100644 --- a/frontend/pkg/metrics/metrics_test.go +++ b/frontend/pkg/metrics/metrics_test.go @@ -45,6 +45,7 @@ func TestSubscriptionCollector(t *testing.T) { ctrl := gomock.NewController(t) mockDBClient := mocks.NewMockDBClient(ctrl) + mockSubscriptionCRUD := mocks.NewMockSubscriptionCRUD(ctrl) r := prometheus.NewPedanticRegistry() collector := NewSubscriptionCollector(r, mockDBClient, "test") @@ -59,9 +60,11 @@ func TestSubscriptionCollector(t *testing.T) { Return(nil) mockDBClient.EXPECT(). - ListAllSubscriptionDocs(). - Return(mockIter). - Times(1) + Subscriptions(). + Return(mockSubscriptionCRUD) + mockSubscriptionCRUD.EXPECT(). + List(gomock.Any(), gomock.Any()). + Return(mockIter, nil).Times(1) collector.refresh(context.Background(), logger) assertMetrics(t, r, 5, `# HELP frontend_subscription_collector_failed_syncs_total Total number of failed syncs for the Subscription collector. @@ -85,9 +88,11 @@ frontend_subscription_collector_last_sync 1 GetError(). Return(errors.New("db error")) mockDBClient.EXPECT(). - ListAllSubscriptionDocs(). - Return(mockIter). - Times(1) + Subscriptions(). + Return(mockSubscriptionCRUD) + mockSubscriptionCRUD.EXPECT(). + List(gomock.Any(), gomock.Any()). + Return(mockIter, nil).Times(1) collector.refresh(context.Background(), logger) @@ -112,9 +117,11 @@ frontend_subscription_collector_last_sync 0 GetError(). Return(nil) mockDBClient.EXPECT(). - ListAllSubscriptionDocs(). - Return(mockIter). - Times(1) + Subscriptions(). + Return(mockSubscriptionCRUD) + mockSubscriptionCRUD.EXPECT(). + List(gomock.Any(), gomock.Any()). + Return(mockIter, nil).Times(1) collector.refresh(context.Background(), logger) diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/00-load-old-subscription/01-subscription.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/00-load-old-subscription/01-subscription.json new file mode 100644 index 0000000000..0cfd4e7010 --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/00-load-old-subscription/01-subscription.json @@ -0,0 +1,15 @@ +{ + "_attachments": "attachments/", + "_etag": "\"00000000-0000-0000-6f81-5c29040501dc\"", + "_rid": "LKgdAIiT-BgDAAAAAAAAAA==", + "_self": "dbs/LKgdAA==/colls/LKgdAIiT-Bg=/docs/LKgdAIiT-BgDAAAAAAAAAA==/", + "_ts": 1765995430, + "id": "b3fa2ee1-d1b3-4eaa-beae-7dd4003b4987", + "partitionKey": "b3fa2ee1-d1b3-4eaa-beae-7dd4003b4987", + "properties": { + "properties": null, + "registrationDate": "2025-12-17T18:16:37+00:00", + "state": "Registered" + }, + "resourceType": "microsoft.resources/subscriptions" +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/01-replace-old/00-key.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/01-replace-old/00-key.json new file mode 100644 index 0000000000..31ec5e3696 --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/01-replace-old/00-key.json @@ -0,0 +1,3 @@ +{ + "parentResourceID": "" +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/01-replace-old/subscription.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/01-replace-old/subscription.json new file mode 100644 index 0000000000..1fa557a933 --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/01-replace-old/subscription.json @@ -0,0 +1,8 @@ +{ + "resourceId": "/subscriptions/b3fa2ee1-d1b3-4eaa-beae-7dd4003b4987", + "state": "Registered", + "registrationDate": "2025-12-17T18:16:37+00:00", + "properties": { + "tenantId": "value" + } +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/02-get-old/00-key.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/02-get-old/00-key.json new file mode 100644 index 0000000000..31ec5e3696 --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/02-get-old/00-key.json @@ -0,0 +1,3 @@ +{ + "parentResourceID": "" +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/02-get-old/subscription.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/02-get-old/subscription.json new file mode 100644 index 0000000000..1fa557a933 --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/02-get-old/subscription.json @@ -0,0 +1,8 @@ +{ + "resourceId": "/subscriptions/b3fa2ee1-d1b3-4eaa-beae-7dd4003b4987", + "state": "Registered", + "registrationDate": "2025-12-17T18:16:37+00:00", + "properties": { + "tenantId": "value" + } +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/03-create-new/00-key.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/03-create-new/00-key.json new file mode 100644 index 0000000000..31ec5e3696 --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/03-create-new/00-key.json @@ -0,0 +1,3 @@ +{ + "parentResourceID": "" +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/03-create-new/new-subscription.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/03-create-new/new-subscription.json new file mode 100644 index 0000000000..a2357c992a --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/03-create-new/new-subscription.json @@ -0,0 +1,8 @@ +{ + "resourceId": "/subscriptions/bb1db6bc-e0f9-47d9-b9c2-2575101c2f69", + "state": "Registered", + "registrationDate": "2025-12-17T18:16:37+00:00", + "properties": { + "tenantId": "other" + } +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/04-list-all/00-key.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/04-list-all/00-key.json new file mode 100644 index 0000000000..31ec5e3696 --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/04-list-all/00-key.json @@ -0,0 +1,3 @@ +{ + "parentResourceID": "" +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/04-list-all/new-subscription.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/04-list-all/new-subscription.json new file mode 100644 index 0000000000..a2357c992a --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/04-list-all/new-subscription.json @@ -0,0 +1,8 @@ +{ + "resourceId": "/subscriptions/bb1db6bc-e0f9-47d9-b9c2-2575101c2f69", + "state": "Registered", + "registrationDate": "2025-12-17T18:16:37+00:00", + "properties": { + "tenantId": "other" + } +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/04-list-all/old-subscription.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/04-list-all/old-subscription.json new file mode 100644 index 0000000000..1fa557a933 --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/mutate-old-via-new/04-list-all/old-subscription.json @@ -0,0 +1,8 @@ +{ + "resourceId": "/subscriptions/b3fa2ee1-d1b3-4eaa-beae-7dd4003b4987", + "state": "Registered", + "registrationDate": "2025-12-17T18:16:37+00:00", + "properties": { + "tenantId": "value" + } +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-new/00-load-new/new-subscription.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-new/00-load-new/new-subscription.json new file mode 100644 index 0000000000..1c2dbba001 --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-new/00-load-new/new-subscription.json @@ -0,0 +1,16 @@ +{ + "_attachments": "attachments/", + "_etag": "\"00000000-0000-0000-6f82-a69a340501dc\"", + "_rid": "zDEpAKTEnZ4DAAAAAAAAAA==", + "_self": "dbs/zDEpAA==/colls/zDEpAKTEnZ4=/docs/zDEpAKTEnZ4DAAAAAAAAAA==/", + "_ts": 1765995984, + "id": "ddfbdeeb-89a1-4a9a-9469-2895f63e2d82", + "partitionKey": "ddfbdeeb-89a1-4a9a-9469-2895f63e2d82", + "properties": { + "properties": null, + "registrationDate": "2025-12-17T18:16:37+00:00", + "resourceId": "/subscriptions/ddfbdeeb-89a1-4a9a-9469-2895f63e2d82", + "state": "Registered" + }, + "resourceType": "Microsoft.Resources/subscriptions" +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-new/01-list-new/00-key.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-new/01-list-new/00-key.json new file mode 100644 index 0000000000..31ec5e3696 --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-new/01-list-new/00-key.json @@ -0,0 +1,3 @@ +{ + "parentResourceID": "" +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-new/01-list-new/subscription.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-new/01-list-new/subscription.json new file mode 100644 index 0000000000..56af75179a --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-new/01-list-new/subscription.json @@ -0,0 +1,6 @@ +{ + "resourceId": "/subscriptions/ddfbdeeb-89a1-4a9a-9469-2895f63e2d82", + "state": "Registered", + "registrationDate": "2025-12-17T18:16:37+00:00", + "properties": null +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-new/02-get-new/00-key.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-new/02-get-new/00-key.json new file mode 100644 index 0000000000..31ec5e3696 --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-new/02-get-new/00-key.json @@ -0,0 +1,3 @@ +{ + "parentResourceID": "" +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-new/02-get-new/subscription.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-new/02-get-new/subscription.json new file mode 100644 index 0000000000..56af75179a --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-new/02-get-new/subscription.json @@ -0,0 +1,6 @@ +{ + "resourceId": "/subscriptions/ddfbdeeb-89a1-4a9a-9469-2895f63e2d82", + "state": "Registered", + "registrationDate": "2025-12-17T18:16:37+00:00", + "properties": null +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-old/00-load-old-subscription/01-subscription.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-old/00-load-old-subscription/01-subscription.json new file mode 100644 index 0000000000..0cfd4e7010 --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-old/00-load-old-subscription/01-subscription.json @@ -0,0 +1,15 @@ +{ + "_attachments": "attachments/", + "_etag": "\"00000000-0000-0000-6f81-5c29040501dc\"", + "_rid": "LKgdAIiT-BgDAAAAAAAAAA==", + "_self": "dbs/LKgdAA==/colls/LKgdAIiT-Bg=/docs/LKgdAIiT-BgDAAAAAAAAAA==/", + "_ts": 1765995430, + "id": "b3fa2ee1-d1b3-4eaa-beae-7dd4003b4987", + "partitionKey": "b3fa2ee1-d1b3-4eaa-beae-7dd4003b4987", + "properties": { + "properties": null, + "registrationDate": "2025-12-17T18:16:37+00:00", + "state": "Registered" + }, + "resourceType": "microsoft.resources/subscriptions" +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-old/01-list-old/00-key.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-old/01-list-old/00-key.json new file mode 100644 index 0000000000..31ec5e3696 --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-old/01-list-old/00-key.json @@ -0,0 +1,3 @@ +{ + "parentResourceID": "" +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-old/01-list-old/subscription.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-old/01-list-old/subscription.json new file mode 100644 index 0000000000..5a5f32c380 --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-old/01-list-old/subscription.json @@ -0,0 +1,6 @@ +{ + "resourceId": "/subscriptions/b3fa2ee1-d1b3-4eaa-beae-7dd4003b4987", + "state": "Registered", + "registrationDate": "2025-12-17T18:16:37+00:00", + "properties": null +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-old/02-get-old/00-key.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-old/02-get-old/00-key.json new file mode 100644 index 0000000000..31ec5e3696 --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-old/02-get-old/00-key.json @@ -0,0 +1,3 @@ +{ + "parentResourceID": "" +} \ No newline at end of file diff --git a/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-old/02-get-old/subscription.json b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-old/02-get-old/subscription.json new file mode 100644 index 0000000000..5a5f32c380 --- /dev/null +++ b/frontend/test/simulate/artifacts/DatabaseCRUD/SubscriptionCRUD/read-old/02-get-old/subscription.json @@ -0,0 +1,6 @@ +{ + "resourceId": "/subscriptions/b3fa2ee1-d1b3-4eaa-beae-7dd4003b4987", + "state": "Registered", + "registrationDate": "2025-12-17T18:16:37+00:00", + "properties": null +} \ No newline at end of file diff --git a/internal/api/arm/subscription.go b/internal/api/arm/subscription.go index a5938a7ccc..0d9e55eb2a 100644 --- a/internal/api/arm/subscription.go +++ b/internal/api/arm/subscription.go @@ -16,17 +16,33 @@ package arm import ( "iter" + "path" "slices" + "strings" "k8s.io/apimachinery/pkg/util/sets" azcorearm "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" ) +// CosmosData contains the information that persisted resources must have for us to support CRUD against them. +// These are not (currently) all stored in the same place in our various types. +type CosmosData struct { + CosmosUID string + PartitionKey string + ItemID *azcorearm.ResourceID +} + // SubscriptionAPIVersion is the system API version for the subscription endpoint. const SubscriptionAPIVersion = "2.0" +func ToSubscriptionResourceID(subscriptionName string) (*azcorearm.ResourceID, error) { + return azcorearm.ParseResourceID(strings.ToLower(path.Join("/subscriptions", subscriptionName))) +} + type Subscription struct { + ResourceID *azcorearm.ResourceID `json:"resourceId,omitempty"` + // The resource provider contract gives an example RegistrationDate // in RFC1123 format but does not explicitly state a required format // so we leave it a plain string. @@ -39,6 +55,18 @@ type Subscription struct { LastUpdated int `json:"-"` } +func (o *Subscription) GetCosmosData() CosmosData { + return CosmosData{ + CosmosUID: o.ResourceID.Name, // this is compatible with preexisting code + PartitionKey: strings.ToLower(o.ResourceID.Name), + ItemID: o.ResourceID, + } +} + +func (o *Subscription) SetCosmosDocumentData(cosmosUID string) { + panic("unsupported") +} + // GetValidTypes returns the valid resource types for a Subscription. func (s Subscription) GetValidTypes() []string { return []string{azcorearm.SubscriptionResourceType.String()} diff --git a/internal/api/types_cosmosdata.go b/internal/api/types_cosmosdata.go index e7e1c30f4a..6962760b60 100644 --- a/internal/api/types_cosmosdata.go +++ b/internal/api/types_cosmosdata.go @@ -15,7 +15,7 @@ package api import ( - azcorearm "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/ARO-HCP/internal/api/arm" ) type CosmosPersistable interface { @@ -25,8 +25,4 @@ type CosmosPersistable interface { // CosmosData contains the information that persisted resources must have for us to support CRUD against them. // These are not (currently) all stored in the same place in our various types. -type CosmosData struct { - CosmosUID string - PartitionKey string - ItemID *azcorearm.ResourceID -} +type CosmosData = arm.CosmosData diff --git a/internal/database/convert_any.go b/internal/database/convert_any.go index 6fc8a744c1..efb745f8e8 100644 --- a/internal/database/convert_any.go +++ b/internal/database/convert_any.go @@ -18,6 +18,7 @@ import ( "fmt" "github.com/Azure/ARO-HCP/internal/api" + "github.com/Azure/ARO-HCP/internal/api/arm" "github.com/Azure/ARO-HCP/internal/utils" ) @@ -40,6 +41,9 @@ func CosmosToInternal[InternalAPIType, CosmosAPIType any](obj *CosmosAPIType) (* case *Operation: internalObj, err = CosmosToInternalOperation(cosmosObj) + case *Subscription: + internalObj, err = CosmosToInternalSubscription(cosmosObj) + case *TypedDocument: var expectedObj InternalAPIType switch castObj := any(expectedObj).(type) { @@ -83,6 +87,9 @@ func InternalToCosmos[InternalAPIType, CosmosAPIType any](obj *InternalAPIType) case *api.Operation: cosmosObj, err = InternalToCosmosOperation(internalObj) + case *arm.Subscription: + cosmosObj, err = InternalToCosmosSubscription(internalObj) + case *TypedDocument: var expectedObj CosmosAPIType switch castObj := any(expectedObj).(type) { diff --git a/internal/database/convert_subscription.go b/internal/database/convert_subscription.go new file mode 100644 index 0000000000..dc689c3234 --- /dev/null +++ b/internal/database/convert_subscription.go @@ -0,0 +1,71 @@ +// Copyright 2025 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "fmt" + "path" + "strings" + + azcorearm "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + + "github.com/Azure/ARO-HCP/internal/api/arm" + "github.com/Azure/ARO-HCP/internal/utils" +) + +func InternalToCosmosSubscription(internalObj *arm.Subscription) (*Subscription, error) { + if internalObj == nil { + return nil, nil + } + + if len(internalObj.ResourceID.Name) == 0 { + return nil, fmt.Errorf("invalid resource id: %q", internalObj.ResourceID.String()) + } + + cosmosObj := &Subscription{ + TypedDocument: TypedDocument{ + BaseDocument: BaseDocument{ + ID: strings.ToLower(internalObj.ResourceID.Name), + }, + PartitionKey: strings.ToLower(internalObj.ResourceID.Name), + ResourceType: internalObj.ResourceID.ResourceType.String(), + }, + InternalState: SubscriptionProperties{ + Subscription: *internalObj, + }, + } + + return cosmosObj, nil +} + +func CosmosToInternalSubscription(cosmosObj *Subscription) (*arm.Subscription, error) { + if cosmosObj == nil { + return nil, nil + } + + tempInternalAPI := cosmosObj.InternalState.Subscription + internalObj := &tempInternalAPI + + // some pieces of data are stored on the ResourceDocument, so we need to restore that data + // this allows us to read old data until we migrate all existing data + resourceID, err := azcorearm.ParseResourceID(path.Join("/subscriptions", cosmosObj.ID)) + if err != nil { + return nil, utils.TrackError(err) + } + internalObj.ResourceID = resourceID + internalObj.LastUpdated = cosmosObj.CosmosTimestamp + + return internalObj, nil +} diff --git a/internal/database/crud_helpers.go b/internal/database/crud_helpers.go index 453910ced7..714aba5d1b 100644 --- a/internal/database/crud_helpers.go +++ b/internal/database/crud_helpers.go @@ -126,21 +126,35 @@ func list[InternalAPIType, CosmosAPIType any](ctx context.Context, containerClie if strings.ToLower(partitionKeyString) != partitionKeyString { return nil, fmt.Errorf("partitionKeyString must be lowercase, not: %q", partitionKeyString) } + if prefix == nil && resourceType == nil { + return nil, fmt.Errorf("prefix or resource type is required") + } - query := "SELECT * FROM c WHERE STARTSWITH(c.properties.resourceId, @prefix, true)" - + query := "" queryOptions := azcosmos.QueryOptions{ PageSizeHint: -1, - QueryParameters: []azcosmos.QueryParameter{ - { - Name: "@prefix", - Value: prefix.String() + "/", + } + if prefix == nil { + query = "SELECT * FROM c" + } else { + query = "SELECT * FROM c WHERE STARTSWITH(c.properties.resourceId, @prefix, true)" + queryOptions = azcosmos.QueryOptions{ + PageSizeHint: -1, + QueryParameters: []azcosmos.QueryParameter{ + { + Name: "@prefix", + Value: prefix.String() + "/", + }, }, - }, + } } if resourceType != nil { - query += " AND STRINGEQUALS(c.resourceType, @resourceType, true)" + if prefix == nil { + query += " WHERE STRINGEQUALS(c.resourceType, @resourceType, true)" + } else { + query += " AND STRINGEQUALS(c.resourceType, @resourceType, true)" + } queryParameter := azcosmos.QueryParameter{ Name: "@resourceType", Value: resourceType.String(), @@ -171,8 +185,14 @@ func list[InternalAPIType, CosmosAPIType any](ctx context.Context, containerClie } queryOptions.ContinuationToken = options.ContinuationToken } + var partitionKey azcosmos.PartitionKey + if len(partitionKeyString) > 0 { + partitionKey = azcosmos.NewPartitionKeyString(partitionKeyString) + } else { + partitionKey = azcosmos.NewPartitionKey() + } - pager := containerClient.NewQueryItemsPager(query, azcosmos.NewPartitionKeyString(partitionKeyString), &queryOptions) + pager := containerClient.NewQueryItemsPager(query, partitionKey, &queryOptions) if options != nil && ptr.Deref(options.PageSizeHint, -1) > 0 { return newQueryResourcesSinglePageIterator[InternalAPIType, CosmosAPIType](pager), nil @@ -186,7 +206,7 @@ func list[InternalAPIType, CosmosAPIType any](ctx context.Context, containerClie func serializeItem[InternalAPIType, CosmosAPIType any](newObj *InternalAPIType) (string, string, []byte, error) { cosmosPersistable, ok := any(newObj).(api.CosmosPersistable) if !ok { - return "", "", nil, fmt.Errorf("type %T does not implement ResourceProperties interface", newObj) + return "", "", nil, fmt.Errorf("type %T does not implement CosmosPersistable interface", newObj) } cosmosData := cosmosPersistable.GetCosmosData() diff --git a/internal/database/crud_subscription.go b/internal/database/crud_subscription.go new file mode 100644 index 0000000000..1334d02031 --- /dev/null +++ b/internal/database/crud_subscription.go @@ -0,0 +1,110 @@ +// Copyright 2025 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "context" + "fmt" + "strings" + + azcorearm "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" + + "github.com/Azure/ARO-HCP/internal/api/arm" + "github.com/Azure/ARO-HCP/internal/utils" +) + +type SubscriptionCRUD interface { + ResourceCRUD[arm.Subscription] +} + +type subscriptionCRUD struct { + containerClient *azcosmos.ContainerClient +} + +var _ SubscriptionCRUD = (*subscriptionCRUD)(nil) + +func NewSubscriptionCRUD(containerClient *azcosmos.ContainerClient) SubscriptionCRUD { + return &subscriptionCRUD{ + containerClient: containerClient, + } +} + +func (d *subscriptionCRUD) GetByID(ctx context.Context, cosmosID string) (*arm.Subscription, error) { + // for subscriptions, the cosmosID IS the partitionKey (at least for now) + if strings.ToLower(cosmosID) != cosmosID { + return nil, fmt.Errorf("cosmosID must be lowercase, not: %q", cosmosID) + } + partitionKey := strings.ToLower(cosmosID) + + return getByItemID[arm.Subscription, Subscription](ctx, d.containerClient, partitionKey, cosmosID) +} + +func (d *subscriptionCRUD) Get(ctx context.Context, resourceName string) (*arm.Subscription, error) { + // notice that this function will not work until we rewrite all records so that subscriptions contain a resourceID + // first attempt to use the old way. + byID, byIDErr := d.GetByID(ctx, resourceName) + if byIDErr == nil { + return byID, nil + } + logger := utils.LoggerFromContext(ctx) + logger.Info("record has been migrated, trying new lookup") + + // for subscriptions, the resourceName IS the partitionKey (at least for now). + completeResourceID, err := arm.ToSubscriptionResourceID(resourceName) + if err != nil { + return nil, fmt.Errorf("failed to make ResourceID path for '%s': %w", resourceName, err) + } + partitionKey := strings.ToLower(completeResourceID.SubscriptionID) + + return get[arm.Subscription, Subscription](ctx, d.containerClient, partitionKey, completeResourceID) +} + +func (d *subscriptionCRUD) List(ctx context.Context, options *DBClientListResourceDocsOptions) (DBClientIterator[arm.Subscription], error) { + // prefix is intentionally nil so that we don't have a resource prefix until after we've written all reords with a resourceID + var prefix *azcorearm.ResourceID + // list must be across all partitions + partitionKey := "" + + return list[arm.Subscription, Subscription](ctx, d.containerClient, partitionKey, &azcorearm.SubscriptionResourceType, prefix, options, false) +} + +func (d *subscriptionCRUD) AddCreateToTransaction(ctx context.Context, transaction DBTransaction, newObj *arm.Subscription, opts *azcosmos.TransactionalBatchItemOptions) (string, error) { + return addCreateToTransaction[arm.Subscription, Subscription](ctx, transaction, newObj, opts) +} + +func (d *subscriptionCRUD) AddReplaceToTransaction(ctx context.Context, transaction DBTransaction, newObj *arm.Subscription, opts *azcosmos.TransactionalBatchItemOptions) (string, error) { + return addReplaceToTransaction[arm.Subscription, Subscription](ctx, transaction, newObj, opts) +} + +func (d *subscriptionCRUD) Create(ctx context.Context, newObj *arm.Subscription, options *azcosmos.ItemOptions) (*arm.Subscription, error) { + partitionKey := strings.ToLower(newObj.ResourceID.SubscriptionID) + return create[arm.Subscription, Subscription](ctx, d.containerClient, partitionKey, newObj, options) +} + +func (d *subscriptionCRUD) Replace(ctx context.Context, newObj *arm.Subscription, options *azcosmos.ItemOptions) (*arm.Subscription, error) { + partitionKey := strings.ToLower(newObj.ResourceID.SubscriptionID) + return replace[arm.Subscription, Subscription](ctx, d.containerClient, partitionKey, newObj, options) +} + +func (d *subscriptionCRUD) Delete(ctx context.Context, resourceName string) error { + completeResourceID, err := arm.ToSubscriptionResourceID(resourceName) + if err != nil { + return fmt.Errorf("failed to make ResourceID path for '%s': %w", resourceName, err) + } + partitionKey := strings.ToLower(completeResourceID.SubscriptionID) + + return deleteResource(ctx, d.containerClient, partitionKey, completeResourceID) +} diff --git a/internal/database/database.go b/internal/database/database.go index fd1ac5d612..e1e55d1e9f 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -28,7 +28,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azidentity" "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" - "github.com/Azure/ARO-HCP/internal/api/arm" "github.com/Azure/ARO-HCP/internal/utils" ) @@ -118,32 +117,7 @@ type DBClient interface { // to end users via ARM. They must also survive the thing they are deleting, so they live under a subscription directly. Operations(subscriptionID string) OperationCRUD - // GetSubscriptionDoc retrieves a subscription document from the "Resources" container. - GetSubscriptionDoc(ctx context.Context, subscriptionID string) (*arm.Subscription, error) - - // CreateSubscriptionDoc creates a new subscription document in the "Resources" container. - CreateSubscriptionDoc(ctx context.Context, subscriptionID string, subscription *arm.Subscription) error - - // UpdateSubscriptionDoc updates a subscription document in the "Resources" container by first - // fetching the document and passing it to the provided callback for modifications to be applied. - // It then attempts to replace the existing document with the modified document an an "etag" - // precondition. Upon a precondition failure the function repeats for a limited number of times - // before giving up. - // - // The callback function should return true if modifications were applied, signaling to proceed - // with the document replacement. The boolean return value reflects this: returning true if the - // document was successfully replaced, or false with or without an error to indicate no change. - UpdateSubscriptionDoc(ctx context.Context, subscriptionID string, callback func(*arm.Subscription) bool) (bool, error) - - // ListAllSubscriptionDocs() returns an iterator that searches for all subscription documents in - // the "Resources" container. Since the "Resources" container is partitioned by subscription ID, - // there will only be one subscription document per logical partition. Thus, this method enables - // iterating over all the logical partitions in the "Resources" container. - // - // Note that ListAllSubscriptionDocs does not perform the search, but merely prepares an iterator - // to do so. Hence the lack of a Context argument. The search is performed by calling Items() on - // the iterator in a ranged for loop. - ListAllSubscriptionDocs() DBClientIterator[arm.Subscription] + Subscriptions() SubscriptionCRUD } var _ DBClient = &cosmosDBClient{} @@ -284,107 +258,6 @@ func (d *cosmosDBClient) PatchBillingDoc(ctx context.Context, resourceID *azcore return nil } -func (d *cosmosDBClient) getSubscriptionDoc(ctx context.Context, subscriptionID string) (*TypedDocument, *arm.Subscription, error) { - // Make sure lookup keys are lowercase. - subscriptionID = strings.ToLower(subscriptionID) - - pk := NewPartitionKey(subscriptionID) - - response, err := d.resources.ReadItem(ctx, pk, subscriptionID, nil) - if err != nil { - return nil, nil, fmt.Errorf("failed to read Subscriptions container item for '%s': %w", subscriptionID, err) - } - - typedDoc, innerDoc, err := typedDocumentUnmarshal[arm.Subscription](response.Value) - if err != nil { - return nil, nil, fmt.Errorf("failed to unmarshal Subscriptions container item for '%s': %w", subscriptionID, err) - } - - // Expose the "_ts" field for metics reporting. - innerDoc.LastUpdated = typedDoc.CosmosTimestamp - - return typedDoc, innerDoc, nil -} - -func (d *cosmosDBClient) GetSubscriptionDoc(ctx context.Context, subscriptionID string) (*arm.Subscription, error) { - _, innerDoc, err := d.getSubscriptionDoc(ctx, subscriptionID) - return innerDoc, err -} - -func (d *cosmosDBClient) CreateSubscriptionDoc(ctx context.Context, subscriptionID string, subscription *arm.Subscription) error { - typedDoc := newTypedDocument(subscriptionID, azcorearm.SubscriptionResourceType) - typedDoc.ID = strings.ToLower(subscriptionID) - - data, err := typedDocumentMarshal(typedDoc, subscription) - if err != nil { - return fmt.Errorf("failed to marshal Subscriptions container item for '%s': %w", subscriptionID, err) - } - - _, err = d.resources.CreateItem(ctx, typedDoc.getPartitionKey(), data, nil) - if err != nil { - return fmt.Errorf("failed to create Subscriptions container item for '%s': %w", subscriptionID, err) - } - - return nil -} - -func (d *cosmosDBClient) UpdateSubscriptionDoc(ctx context.Context, subscriptionID string, callback func(*arm.Subscription) bool) (bool, error) { - var err error - - options := &azcosmos.ItemOptions{} - - for try := 0; try < 5; try++ { - var typedDoc *TypedDocument - var innerDoc *arm.Subscription - var data []byte - - typedDoc, innerDoc, err = d.getSubscriptionDoc(ctx, subscriptionID) - if err != nil { - return false, err - } - - if !callback(innerDoc) { - return false, nil - } - - data, err = typedDocumentMarshal(typedDoc, innerDoc) - if err != nil { - return false, fmt.Errorf("failed to marshal Subscriptions container item for '%s': %w", subscriptionID, err) - } - - options.IfMatchEtag = &typedDoc.CosmosETag - _, err = d.resources.ReplaceItem(ctx, typedDoc.getPartitionKey(), typedDoc.ID, data, options) - if err == nil { - return true, nil - } - - var responseError *azcore.ResponseError - err = fmt.Errorf("failed to replace Subscriptions container item for '%s': %w", subscriptionID, err) - if !errors.As(err, &responseError) || responseError.StatusCode != http.StatusPreconditionFailed { - return false, err - } - } - - return false, err -} - -func (d *cosmosDBClient) ListAllSubscriptionDocs() DBClientIterator[arm.Subscription] { - const query = "SELECT * FROM c WHERE STRINGEQUALS(c.resourceType, @resourceType, true)" - opt := azcosmos.QueryOptions{ - QueryParameters: []azcosmos.QueryParameter{ - { - Name: "@resourceType", - Value: azcorearm.SubscriptionResourceType.String(), - }, - }, - } - - // Empty partition key triggers a cross-partition query. - pager := d.resources.NewQueryItemsPager(query, azcosmos.NewPartitionKey(), &opt) - - return newQueryItemsIterator[arm.Subscription](pager) -} - func (d *cosmosDBClient) HCPClusters(subscriptionID, resourceGroupName string) HCPClusterCRUD { return NewHCPClusterCRUD(d.resources, subscriptionID, resourceGroupName) } @@ -393,6 +266,10 @@ func (d *cosmosDBClient) Operations(subscriptionID string) OperationCRUD { return NewOperationCRUD(d.resources, subscriptionID) } +func (d *cosmosDBClient) Subscriptions() SubscriptionCRUD { + return NewSubscriptionCRUD(d.resources) +} + func (d *cosmosDBClient) UntypedCRUD(parentResourceID azcorearm.ResourceID) (UntypedResourceCRUD, error) { return NewUntypedCRUD(d.resources, parentResourceID), nil } diff --git a/internal/database/types_externalauth.go b/internal/database/types_externalauth.go index 2fa2d4f9ed..8585933c0d 100644 --- a/internal/database/types_externalauth.go +++ b/internal/database/types_externalauth.go @@ -17,8 +17,6 @@ package database import ( "fmt" - azcorearm "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" - "github.com/Azure/ARO-HCP/internal/api" ) @@ -51,7 +49,3 @@ func (o *ExternalAuth) ValidateResourceType() error { func (o *ExternalAuth) GetTypedDocument() *TypedDocument { return &o.TypedDocument } - -func (o *ExternalAuth) SetResourceID(newResourceID *azcorearm.ResourceID) { - o.ResourceDocument.SetResourceID(newResourceID) -} diff --git a/internal/database/types_hcpcluster.go b/internal/database/types_hcpcluster.go index 9fadb202b1..95167b1e7a 100644 --- a/internal/database/types_hcpcluster.go +++ b/internal/database/types_hcpcluster.go @@ -17,8 +17,6 @@ package database import ( "fmt" - azcorearm "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" - "github.com/Azure/ARO-HCP/internal/api" ) @@ -51,7 +49,3 @@ func (o *HCPCluster) ValidateResourceType() error { func (o *HCPCluster) GetTypedDocument() *TypedDocument { return &o.TypedDocument } - -func (o *HCPCluster) SetResourceID(newResourceID *azcorearm.ResourceID) { - o.ResourceDocument.SetResourceID(newResourceID) -} diff --git a/internal/database/types_nodepool.go b/internal/database/types_nodepool.go index 9f2aca86c3..7f75492e77 100644 --- a/internal/database/types_nodepool.go +++ b/internal/database/types_nodepool.go @@ -17,8 +17,6 @@ package database import ( "fmt" - azcorearm "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" - "github.com/Azure/ARO-HCP/internal/api" ) @@ -51,7 +49,3 @@ func (o *NodePool) ValidateResourceType() error { func (o *NodePool) GetTypedDocument() *TypedDocument { return &o.TypedDocument } - -func (o *NodePool) SetResourceID(newResourceID *azcorearm.ResourceID) { - o.ResourceDocument.SetResourceID(newResourceID) -} diff --git a/internal/database/types_subscription.go b/internal/database/types_subscription.go new file mode 100644 index 0000000000..5d39a139df --- /dev/null +++ b/internal/database/types_subscription.go @@ -0,0 +1,46 @@ +// Copyright 2025 Microsoft Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package database + +import ( + "fmt" + + azcorearm "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + + "github.com/Azure/ARO-HCP/internal/api/arm" +) + +type Subscription struct { + TypedDocument `json:",inline"` + + InternalState SubscriptionProperties `json:"properties"` +} + +var _ ResourceProperties = &Subscription{} + +type SubscriptionProperties struct { + arm.Subscription `json:",inline"` +} + +func (o *Subscription) ValidateResourceType() error { + if o.ResourceType != azcorearm.SubscriptionResourceType.String() { + return fmt.Errorf("invalid resource type: %s", o.ResourceType) + } + return nil +} + +func (o *Subscription) GetTypedDocument() *TypedDocument { + return &o.TypedDocument +} diff --git a/internal/database/types_typeddocument.go b/internal/database/types_typeddocument.go index c78207dacc..af8e5ebd02 100644 --- a/internal/database/types_typeddocument.go +++ b/internal/database/types_typeddocument.go @@ -16,29 +16,8 @@ package database import ( "encoding/json" - "fmt" - "reflect" - "strings" - - azcorearm "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" - "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" ) -// typedDocumentError signifies a mismatched Type field and Properties type -// when attempting to unmarshal JSON-encoded data. -type typedDocumentError struct { - invalidType string - propertiesType string -} - -func (e typedDocumentError) Error() string { - if e.invalidType == "" { - return "missing type" - } - - return fmt.Sprintf("invalid type '%s' for %s", e.invalidType, e.propertiesType) -} - // TypedDocument is a BaseDocument with a ResourceType field to // help distinguish heterogeneous items in a Cosmos DB container. // The Properties field can be unmarshalled to any type that @@ -57,83 +36,3 @@ var ( func (td *TypedDocument) GetTypedDocument() *TypedDocument { return td } - -// newTypedDocument returns a TypedDocument from a ResourceType. -func newTypedDocument(partitionKey string, resourceType azcorearm.ResourceType) *TypedDocument { - return &TypedDocument{ - BaseDocument: newBaseDocument(), - PartitionKey: strings.ToLower(partitionKey), - ResourceType: strings.ToLower(resourceType.String()), - } -} - -// getPartitionKey returns an azcosmos.PartitionKey. -func (td *TypedDocument) getPartitionKey() azcosmos.PartitionKey { - return azcosmos.NewPartitionKeyString(td.PartitionKey) -} - -// validateType validates the type field against the given properties type. -// If type validation fails, validateType returns a typedDocumentError. -func (td *TypedDocument) validateType(properties DocumentProperties) error { - for _, t := range properties.GetValidTypes() { - if strings.EqualFold(td.ResourceType, t) { - return nil - } - } - - propertiesType := reflect.TypeOf(properties) - if propertiesType.Kind() == reflect.Pointer { - propertiesType = propertiesType.Elem() - } - - return &typedDocumentError{ - invalidType: td.ResourceType, - propertiesType: propertiesType.Name(), - } -} - -// typedDocumentMarshal returns the JSON encoding of typedDoc with innerDoc -// as the properties value. First, however, typedDocumentMarshal validates -// the type field in typeDoc against innerDoc to ensure compatibility. If -// validation fails, typedDocumentMarshal returns a typedDocumentError. -func typedDocumentMarshal[T DocumentProperties](typedDoc *TypedDocument, innerDoc *T) ([]byte, error) { - err := typedDoc.validateType(*innerDoc) - if err != nil { - return nil, err - } - - data, err := json.Marshal(innerDoc) - if err != nil { - return nil, err - } - - typedDoc.Properties = data - - return json.Marshal(typedDoc) -} - -// typedDocumentUnmarshal parses JSON-encoded data into a TypedDocument, -// validates the type field against the type parameter T, and then parses -// the JSON-encoded properties data into an instance of type parameter T. -// If validation fails, typedDocumentUnmarshal returns a typedDocumentError. -func typedDocumentUnmarshal[T DocumentProperties](data []byte) (*TypedDocument, *T, error) { - var typedDoc TypedDocument - var innerDoc T - - err := json.Unmarshal(data, &typedDoc) - if err != nil { - return nil, nil, err - } - - err = typedDoc.validateType(innerDoc) - if err != nil { - return nil, nil, err - } - - err = json.Unmarshal(typedDoc.Properties, &innerDoc) - if err != nil { - return nil, nil, err - } - - return &typedDoc, &innerDoc, nil -} diff --git a/internal/database/types_typeddocument_test.go b/internal/database/types_typeddocument_test.go deleted file mode 100644 index a0b08ecff1..0000000000 --- a/internal/database/types_typeddocument_test.go +++ /dev/null @@ -1,111 +0,0 @@ -// Copyright 2025 Microsoft Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package database - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" -) - -const testResourceType = "test" -const testPropertiesValue = "foo" - -type testProperties struct { - Value string -} - -func (p testProperties) GetValidTypes() []string { - return []string{testResourceType} -} - -func TestTypedDocumentMarshal(t *testing.T) { - tests := []struct { - name string - typedDoc *TypedDocument - err string - }{ - { - name: "successful marshal", - typedDoc: &TypedDocument{ - ResourceType: testResourceType, - }, - err: "", - }, - { - name: "missing resource type", - typedDoc: &TypedDocument{}, - err: "missing type", - }, - { - name: "invalid resource type", - typedDoc: &TypedDocument{ - ResourceType: "invalid", - }, - err: "invalid type 'invalid' for testProperties", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - innerDoc := &testProperties{testPropertiesValue} - data, err := typedDocumentMarshal[testProperties](tt.typedDoc, innerDoc) - - if tt.err != "" { - assert.EqualError(t, err, tt.err) - } else if assert.NoError(t, err) { - assert.NotEmpty(t, data) - } - }) - } -} - -func TestTypedDocumentUnmarshal(t *testing.T) { - tests := []struct { - name string - data string - err string - }{ - { - name: "successful unmarshal", - data: fmt.Sprintf("{\"resourceType\": \"%s\", \"properties\": {\"value\": \"%s\"}}", testResourceType, testPropertiesValue), - err: "", - }, - { - name: "missing resource type", - data: fmt.Sprintf("{\"properties\": {\"value\": \"%s\"}}", testPropertiesValue), - err: "missing type", - }, - { - name: "invalid resource type", - data: fmt.Sprintf("{\"resourceType\": \"invalid\", \"properties\": {\"value\": \"%s\"}}", testPropertiesValue), - err: "invalid type 'invalid' for testProperties", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - typedDoc, innerDoc, err := typedDocumentUnmarshal[testProperties]([]byte(tt.data)) - - if tt.err != "" { - assert.EqualError(t, err, tt.err) - } else if assert.NoError(t, err) { - assert.Equal(t, testResourceType, typedDoc.ResourceType) - assert.Equal(t, testPropertiesValue, innerDoc.Value) - } - }) - } -} diff --git a/internal/database/util.go b/internal/database/util.go deleted file mode 100644 index b2e7a927cd..0000000000 --- a/internal/database/util.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2025 Microsoft Corporation -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package database - -import ( - "context" - - "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" - "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" -) - -type queryItemsIterator[T DocumentProperties] struct { - pager *runtime.Pager[azcosmos.QueryItemsResponse] - singlePage bool - continuationToken string - err error -} - -// newqueryItemsIterator is a failable push iterator for a paged query response. -func newQueryItemsIterator[T DocumentProperties](pager *runtime.Pager[azcosmos.QueryItemsResponse]) DBClientIterator[T] { - return &queryItemsIterator[T]{pager: pager} -} - -// Items returns a push iterator that can be used directly in for/range loops. -// If an error occurs during paging, iteration stops and the error is recorded. -func (iter *queryItemsIterator[T]) Items(ctx context.Context) DBClientIteratorItem[T] { - return func(yield func(string, *T) bool) { - for iter.pager.More() { - response, err := iter.pager.NextPage(ctx) - if err != nil { - iter.err = err - return - } - if iter.singlePage && response.ContinuationToken != nil { - iter.continuationToken = *response.ContinuationToken - } - for _, item := range response.Items { - typedDoc, innerDoc, err := typedDocumentUnmarshal[T](item) - if err != nil { - iter.err = err - return - } - - if !yield(typedDoc.ID, innerDoc) { - return - } - } - if iter.singlePage { - return - } - } - } -} - -// GetContinuationToken returns a continuation token that can be used to obtain -// the next page of results. This is only set when the iterator was created with -// NewQueryItemsSinglePageIterator and additional items are available. -func (iter queryItemsIterator[T]) GetContinuationToken() string { - return iter.continuationToken -} - -// GetError returns any error that occurred during iteration. Call this after the -// for/range loop that calls Items() to check if iteration completed successfully. -func (iter queryItemsIterator[T]) GetError() error { - return iter.err -} diff --git a/internal/mocks/crud_subscription.go b/internal/mocks/crud_subscription.go new file mode 100644 index 0000000000..8fc8872bf9 --- /dev/null +++ b/internal/mocks/crud_subscription.go @@ -0,0 +1,356 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ../database/crud_subscription.go +// +// Generated by this command: +// +// mockgen-v0.5.0 -typed -source=../database/crud_subscription.go -destination=crud_subscription.go -package mocks github.com/Azure/ARO-HCP/internal/database SubscriptionCRUD +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + azcosmos "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" + gomock "go.uber.org/mock/gomock" + + arm "github.com/Azure/ARO-HCP/internal/api/arm" + database "github.com/Azure/ARO-HCP/internal/database" +) + +// MockSubscriptionCRUD is a mock of SubscriptionCRUD interface. +type MockSubscriptionCRUD struct { + ctrl *gomock.Controller + recorder *MockSubscriptionCRUDMockRecorder + isgomock struct{} +} + +// MockSubscriptionCRUDMockRecorder is the mock recorder for MockSubscriptionCRUD. +type MockSubscriptionCRUDMockRecorder struct { + mock *MockSubscriptionCRUD +} + +// NewMockSubscriptionCRUD creates a new mock instance. +func NewMockSubscriptionCRUD(ctrl *gomock.Controller) *MockSubscriptionCRUD { + mock := &MockSubscriptionCRUD{ctrl: ctrl} + mock.recorder = &MockSubscriptionCRUDMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSubscriptionCRUD) EXPECT() *MockSubscriptionCRUDMockRecorder { + return m.recorder +} + +// AddCreateToTransaction mocks base method. +func (m *MockSubscriptionCRUD) AddCreateToTransaction(ctx context.Context, transaction database.DBTransaction, newObj *arm.Subscription, opts *azcosmos.TransactionalBatchItemOptions) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddCreateToTransaction", ctx, transaction, newObj, opts) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddCreateToTransaction indicates an expected call of AddCreateToTransaction. +func (mr *MockSubscriptionCRUDMockRecorder) AddCreateToTransaction(ctx, transaction, newObj, opts any) *MockSubscriptionCRUDAddCreateToTransactionCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddCreateToTransaction", reflect.TypeOf((*MockSubscriptionCRUD)(nil).AddCreateToTransaction), ctx, transaction, newObj, opts) + return &MockSubscriptionCRUDAddCreateToTransactionCall{Call: call} +} + +// MockSubscriptionCRUDAddCreateToTransactionCall wrap *gomock.Call +type MockSubscriptionCRUDAddCreateToTransactionCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSubscriptionCRUDAddCreateToTransactionCall) Return(arg0 string, arg1 error) *MockSubscriptionCRUDAddCreateToTransactionCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSubscriptionCRUDAddCreateToTransactionCall) Do(f func(context.Context, database.DBTransaction, *arm.Subscription, *azcosmos.TransactionalBatchItemOptions) (string, error)) *MockSubscriptionCRUDAddCreateToTransactionCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSubscriptionCRUDAddCreateToTransactionCall) DoAndReturn(f func(context.Context, database.DBTransaction, *arm.Subscription, *azcosmos.TransactionalBatchItemOptions) (string, error)) *MockSubscriptionCRUDAddCreateToTransactionCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// AddReplaceToTransaction mocks base method. +func (m *MockSubscriptionCRUD) AddReplaceToTransaction(ctx context.Context, transaction database.DBTransaction, newObj *arm.Subscription, opts *azcosmos.TransactionalBatchItemOptions) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddReplaceToTransaction", ctx, transaction, newObj, opts) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AddReplaceToTransaction indicates an expected call of AddReplaceToTransaction. +func (mr *MockSubscriptionCRUDMockRecorder) AddReplaceToTransaction(ctx, transaction, newObj, opts any) *MockSubscriptionCRUDAddReplaceToTransactionCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddReplaceToTransaction", reflect.TypeOf((*MockSubscriptionCRUD)(nil).AddReplaceToTransaction), ctx, transaction, newObj, opts) + return &MockSubscriptionCRUDAddReplaceToTransactionCall{Call: call} +} + +// MockSubscriptionCRUDAddReplaceToTransactionCall wrap *gomock.Call +type MockSubscriptionCRUDAddReplaceToTransactionCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSubscriptionCRUDAddReplaceToTransactionCall) Return(arg0 string, arg1 error) *MockSubscriptionCRUDAddReplaceToTransactionCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSubscriptionCRUDAddReplaceToTransactionCall) Do(f func(context.Context, database.DBTransaction, *arm.Subscription, *azcosmos.TransactionalBatchItemOptions) (string, error)) *MockSubscriptionCRUDAddReplaceToTransactionCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSubscriptionCRUDAddReplaceToTransactionCall) DoAndReturn(f func(context.Context, database.DBTransaction, *arm.Subscription, *azcosmos.TransactionalBatchItemOptions) (string, error)) *MockSubscriptionCRUDAddReplaceToTransactionCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Create mocks base method. +func (m *MockSubscriptionCRUD) Create(ctx context.Context, newObj *arm.Subscription, options *azcosmos.ItemOptions) (*arm.Subscription, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Create", ctx, newObj, options) + ret0, _ := ret[0].(*arm.Subscription) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Create indicates an expected call of Create. +func (mr *MockSubscriptionCRUDMockRecorder) Create(ctx, newObj, options any) *MockSubscriptionCRUDCreateCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockSubscriptionCRUD)(nil).Create), ctx, newObj, options) + return &MockSubscriptionCRUDCreateCall{Call: call} +} + +// MockSubscriptionCRUDCreateCall wrap *gomock.Call +type MockSubscriptionCRUDCreateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSubscriptionCRUDCreateCall) Return(arg0 *arm.Subscription, arg1 error) *MockSubscriptionCRUDCreateCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSubscriptionCRUDCreateCall) Do(f func(context.Context, *arm.Subscription, *azcosmos.ItemOptions) (*arm.Subscription, error)) *MockSubscriptionCRUDCreateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSubscriptionCRUDCreateCall) DoAndReturn(f func(context.Context, *arm.Subscription, *azcosmos.ItemOptions) (*arm.Subscription, error)) *MockSubscriptionCRUDCreateCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Delete mocks base method. +func (m *MockSubscriptionCRUD) Delete(ctx context.Context, resourceID string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", ctx, resourceID) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockSubscriptionCRUDMockRecorder) Delete(ctx, resourceID any) *MockSubscriptionCRUDDeleteCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockSubscriptionCRUD)(nil).Delete), ctx, resourceID) + return &MockSubscriptionCRUDDeleteCall{Call: call} +} + +// MockSubscriptionCRUDDeleteCall wrap *gomock.Call +type MockSubscriptionCRUDDeleteCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSubscriptionCRUDDeleteCall) Return(arg0 error) *MockSubscriptionCRUDDeleteCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSubscriptionCRUDDeleteCall) Do(f func(context.Context, string) error) *MockSubscriptionCRUDDeleteCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSubscriptionCRUDDeleteCall) DoAndReturn(f func(context.Context, string) error) *MockSubscriptionCRUDDeleteCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Get mocks base method. +func (m *MockSubscriptionCRUD) Get(ctx context.Context, resourceID string) (*arm.Subscription, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, resourceID) + ret0, _ := ret[0].(*arm.Subscription) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockSubscriptionCRUDMockRecorder) Get(ctx, resourceID any) *MockSubscriptionCRUDGetCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSubscriptionCRUD)(nil).Get), ctx, resourceID) + return &MockSubscriptionCRUDGetCall{Call: call} +} + +// MockSubscriptionCRUDGetCall wrap *gomock.Call +type MockSubscriptionCRUDGetCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSubscriptionCRUDGetCall) Return(arg0 *arm.Subscription, arg1 error) *MockSubscriptionCRUDGetCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSubscriptionCRUDGetCall) Do(f func(context.Context, string) (*arm.Subscription, error)) *MockSubscriptionCRUDGetCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSubscriptionCRUDGetCall) DoAndReturn(f func(context.Context, string) (*arm.Subscription, error)) *MockSubscriptionCRUDGetCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// GetByID mocks base method. +func (m *MockSubscriptionCRUD) GetByID(ctx context.Context, cosmosID string) (*arm.Subscription, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetByID", ctx, cosmosID) + ret0, _ := ret[0].(*arm.Subscription) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetByID indicates an expected call of GetByID. +func (mr *MockSubscriptionCRUDMockRecorder) GetByID(ctx, cosmosID any) *MockSubscriptionCRUDGetByIDCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetByID", reflect.TypeOf((*MockSubscriptionCRUD)(nil).GetByID), ctx, cosmosID) + return &MockSubscriptionCRUDGetByIDCall{Call: call} +} + +// MockSubscriptionCRUDGetByIDCall wrap *gomock.Call +type MockSubscriptionCRUDGetByIDCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSubscriptionCRUDGetByIDCall) Return(arg0 *arm.Subscription, arg1 error) *MockSubscriptionCRUDGetByIDCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSubscriptionCRUDGetByIDCall) Do(f func(context.Context, string) (*arm.Subscription, error)) *MockSubscriptionCRUDGetByIDCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSubscriptionCRUDGetByIDCall) DoAndReturn(f func(context.Context, string) (*arm.Subscription, error)) *MockSubscriptionCRUDGetByIDCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// List mocks base method. +func (m *MockSubscriptionCRUD) List(ctx context.Context, opts *database.DBClientListResourceDocsOptions) (database.DBClientIterator[arm.Subscription], error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", ctx, opts) + ret0, _ := ret[0].(database.DBClientIterator[arm.Subscription]) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// List indicates an expected call of List. +func (mr *MockSubscriptionCRUDMockRecorder) List(ctx, opts any) *MockSubscriptionCRUDListCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockSubscriptionCRUD)(nil).List), ctx, opts) + return &MockSubscriptionCRUDListCall{Call: call} +} + +// MockSubscriptionCRUDListCall wrap *gomock.Call +type MockSubscriptionCRUDListCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSubscriptionCRUDListCall) Return(arg0 database.DBClientIterator[arm.Subscription], arg1 error) *MockSubscriptionCRUDListCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSubscriptionCRUDListCall) Do(f func(context.Context, *database.DBClientListResourceDocsOptions) (database.DBClientIterator[arm.Subscription], error)) *MockSubscriptionCRUDListCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSubscriptionCRUDListCall) DoAndReturn(f func(context.Context, *database.DBClientListResourceDocsOptions) (database.DBClientIterator[arm.Subscription], error)) *MockSubscriptionCRUDListCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Replace mocks base method. +func (m *MockSubscriptionCRUD) Replace(ctx context.Context, newObj *arm.Subscription, options *azcosmos.ItemOptions) (*arm.Subscription, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Replace", ctx, newObj, options) + ret0, _ := ret[0].(*arm.Subscription) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Replace indicates an expected call of Replace. +func (mr *MockSubscriptionCRUDMockRecorder) Replace(ctx, newObj, options any) *MockSubscriptionCRUDReplaceCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Replace", reflect.TypeOf((*MockSubscriptionCRUD)(nil).Replace), ctx, newObj, options) + return &MockSubscriptionCRUDReplaceCall{Call: call} +} + +// MockSubscriptionCRUDReplaceCall wrap *gomock.Call +type MockSubscriptionCRUDReplaceCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockSubscriptionCRUDReplaceCall) Return(arg0 *arm.Subscription, arg1 error) *MockSubscriptionCRUDReplaceCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockSubscriptionCRUDReplaceCall) Do(f func(context.Context, *arm.Subscription, *azcosmos.ItemOptions) (*arm.Subscription, error)) *MockSubscriptionCRUDReplaceCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockSubscriptionCRUDReplaceCall) DoAndReturn(f func(context.Context, *arm.Subscription, *azcosmos.ItemOptions) (*arm.Subscription, error)) *MockSubscriptionCRUDReplaceCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/internal/mocks/dbclient.go b/internal/mocks/dbclient.go index 215e6adb75..28832a3e5b 100644 --- a/internal/mocks/dbclient.go +++ b/internal/mocks/dbclient.go @@ -13,10 +13,9 @@ import ( context "context" reflect "reflect" - arm0 "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" + arm "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" gomock "go.uber.org/mock/gomock" - arm "github.com/Azure/ARO-HCP/internal/api/arm" database "github.com/Azure/ARO-HCP/internal/database" ) @@ -220,44 +219,6 @@ func (c *MockDBClientCreateBillingDocCall) DoAndReturn(f func(context.Context, * return c } -// CreateSubscriptionDoc mocks base method. -func (m *MockDBClient) CreateSubscriptionDoc(ctx context.Context, subscriptionID string, subscription *arm.Subscription) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateSubscriptionDoc", ctx, subscriptionID, subscription) - ret0, _ := ret[0].(error) - return ret0 -} - -// CreateSubscriptionDoc indicates an expected call of CreateSubscriptionDoc. -func (mr *MockDBClientMockRecorder) CreateSubscriptionDoc(ctx, subscriptionID, subscription any) *MockDBClientCreateSubscriptionDocCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateSubscriptionDoc", reflect.TypeOf((*MockDBClient)(nil).CreateSubscriptionDoc), ctx, subscriptionID, subscription) - return &MockDBClientCreateSubscriptionDocCall{Call: call} -} - -// MockDBClientCreateSubscriptionDocCall wrap *gomock.Call -type MockDBClientCreateSubscriptionDocCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockDBClientCreateSubscriptionDocCall) Return(arg0 error) *MockDBClientCreateSubscriptionDocCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockDBClientCreateSubscriptionDocCall) Do(f func(context.Context, string, *arm.Subscription) error) *MockDBClientCreateSubscriptionDocCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDBClientCreateSubscriptionDocCall) DoAndReturn(f func(context.Context, string, *arm.Subscription) error) *MockDBClientCreateSubscriptionDocCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - // GetLockClient mocks base method. func (m *MockDBClient) GetLockClient() database.LockClientInterface { m.ctrl.T.Helper() @@ -296,45 +257,6 @@ func (c *MockDBClientGetLockClientCall) DoAndReturn(f func() database.LockClient return c } -// GetSubscriptionDoc mocks base method. -func (m *MockDBClient) GetSubscriptionDoc(ctx context.Context, subscriptionID string) (*arm.Subscription, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSubscriptionDoc", ctx, subscriptionID) - ret0, _ := ret[0].(*arm.Subscription) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSubscriptionDoc indicates an expected call of GetSubscriptionDoc. -func (mr *MockDBClientMockRecorder) GetSubscriptionDoc(ctx, subscriptionID any) *MockDBClientGetSubscriptionDocCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubscriptionDoc", reflect.TypeOf((*MockDBClient)(nil).GetSubscriptionDoc), ctx, subscriptionID) - return &MockDBClientGetSubscriptionDocCall{Call: call} -} - -// MockDBClientGetSubscriptionDocCall wrap *gomock.Call -type MockDBClientGetSubscriptionDocCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockDBClientGetSubscriptionDocCall) Return(arg0 *arm.Subscription, arg1 error) *MockDBClientGetSubscriptionDocCall { - c.Call = c.Call.Return(arg0, arg1) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockDBClientGetSubscriptionDocCall) Do(f func(context.Context, string) (*arm.Subscription, error)) *MockDBClientGetSubscriptionDocCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDBClientGetSubscriptionDocCall) DoAndReturn(f func(context.Context, string) (*arm.Subscription, error)) *MockDBClientGetSubscriptionDocCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - // HCPClusters mocks base method. func (m *MockDBClient) HCPClusters(subscriptionID, resourceGroupName string) database.HCPClusterCRUD { m.ctrl.T.Helper() @@ -373,44 +295,6 @@ func (c *MockDBClientHCPClustersCall) DoAndReturn(f func(string, string) databas return c } -// ListAllSubscriptionDocs mocks base method. -func (m *MockDBClient) ListAllSubscriptionDocs() database.DBClientIterator[arm.Subscription] { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListAllSubscriptionDocs") - ret0, _ := ret[0].(database.DBClientIterator[arm.Subscription]) - return ret0 -} - -// ListAllSubscriptionDocs indicates an expected call of ListAllSubscriptionDocs. -func (mr *MockDBClientMockRecorder) ListAllSubscriptionDocs() *MockDBClientListAllSubscriptionDocsCall { - mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAllSubscriptionDocs", reflect.TypeOf((*MockDBClient)(nil).ListAllSubscriptionDocs)) - return &MockDBClientListAllSubscriptionDocsCall{Call: call} -} - -// MockDBClientListAllSubscriptionDocsCall wrap *gomock.Call -type MockDBClientListAllSubscriptionDocsCall struct { - *gomock.Call -} - -// Return rewrite *gomock.Call.Return -func (c *MockDBClientListAllSubscriptionDocsCall) Return(arg0 database.DBClientIterator[arm.Subscription]) *MockDBClientListAllSubscriptionDocsCall { - c.Call = c.Call.Return(arg0) - return c -} - -// Do rewrite *gomock.Call.Do -func (c *MockDBClientListAllSubscriptionDocsCall) Do(f func() database.DBClientIterator[arm.Subscription]) *MockDBClientListAllSubscriptionDocsCall { - c.Call = c.Call.Do(f) - return c -} - -// DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDBClientListAllSubscriptionDocsCall) DoAndReturn(f func() database.DBClientIterator[arm.Subscription]) *MockDBClientListAllSubscriptionDocsCall { - c.Call = c.Call.DoAndReturn(f) - return c -} - // NewTransaction mocks base method. func (m *MockDBClient) NewTransaction(pk string) database.DBTransaction { m.ctrl.T.Helper() @@ -488,7 +372,7 @@ func (c *MockDBClientOperationsCall) DoAndReturn(f func(string) database.Operati } // PatchBillingDoc mocks base method. -func (m *MockDBClient) PatchBillingDoc(ctx context.Context, resourceID *arm0.ResourceID, ops database.BillingDocumentPatchOperations) error { +func (m *MockDBClient) PatchBillingDoc(ctx context.Context, resourceID *arm.ResourceID, ops database.BillingDocumentPatchOperations) error { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "PatchBillingDoc", ctx, resourceID, ops) ret0, _ := ret[0].(error) @@ -514,91 +398,90 @@ func (c *MockDBClientPatchBillingDocCall) Return(arg0 error) *MockDBClientPatchB } // Do rewrite *gomock.Call.Do -func (c *MockDBClientPatchBillingDocCall) Do(f func(context.Context, *arm0.ResourceID, database.BillingDocumentPatchOperations) error) *MockDBClientPatchBillingDocCall { +func (c *MockDBClientPatchBillingDocCall) Do(f func(context.Context, *arm.ResourceID, database.BillingDocumentPatchOperations) error) *MockDBClientPatchBillingDocCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDBClientPatchBillingDocCall) DoAndReturn(f func(context.Context, *arm0.ResourceID, database.BillingDocumentPatchOperations) error) *MockDBClientPatchBillingDocCall { +func (c *MockDBClientPatchBillingDocCall) DoAndReturn(f func(context.Context, *arm.ResourceID, database.BillingDocumentPatchOperations) error) *MockDBClientPatchBillingDocCall { c.Call = c.Call.DoAndReturn(f) return c } -// UntypedCRUD mocks base method. -func (m *MockDBClient) UntypedCRUD(parentResourceID arm0.ResourceID) (database.UntypedResourceCRUD, error) { +// Subscriptions mocks base method. +func (m *MockDBClient) Subscriptions() database.SubscriptionCRUD { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UntypedCRUD", parentResourceID) - ret0, _ := ret[0].(database.UntypedResourceCRUD) - ret1, _ := ret[1].(error) - return ret0, ret1 + ret := m.ctrl.Call(m, "Subscriptions") + ret0, _ := ret[0].(database.SubscriptionCRUD) + return ret0 } -// UntypedCRUD indicates an expected call of UntypedCRUD. -func (mr *MockDBClientMockRecorder) UntypedCRUD(parentResourceID any) *MockDBClientUntypedCRUDCall { +// Subscriptions indicates an expected call of Subscriptions. +func (mr *MockDBClientMockRecorder) Subscriptions() *MockDBClientSubscriptionsCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UntypedCRUD", reflect.TypeOf((*MockDBClient)(nil).UntypedCRUD), parentResourceID) - return &MockDBClientUntypedCRUDCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subscriptions", reflect.TypeOf((*MockDBClient)(nil).Subscriptions)) + return &MockDBClientSubscriptionsCall{Call: call} } -// MockDBClientUntypedCRUDCall wrap *gomock.Call -type MockDBClientUntypedCRUDCall struct { +// MockDBClientSubscriptionsCall wrap *gomock.Call +type MockDBClientSubscriptionsCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockDBClientUntypedCRUDCall) Return(arg0 database.UntypedResourceCRUD, arg1 error) *MockDBClientUntypedCRUDCall { - c.Call = c.Call.Return(arg0, arg1) +func (c *MockDBClientSubscriptionsCall) Return(arg0 database.SubscriptionCRUD) *MockDBClientSubscriptionsCall { + c.Call = c.Call.Return(arg0) return c } // Do rewrite *gomock.Call.Do -func (c *MockDBClientUntypedCRUDCall) Do(f func(arm0.ResourceID) (database.UntypedResourceCRUD, error)) *MockDBClientUntypedCRUDCall { +func (c *MockDBClientSubscriptionsCall) Do(f func() database.SubscriptionCRUD) *MockDBClientSubscriptionsCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDBClientUntypedCRUDCall) DoAndReturn(f func(arm0.ResourceID) (database.UntypedResourceCRUD, error)) *MockDBClientUntypedCRUDCall { +func (c *MockDBClientSubscriptionsCall) DoAndReturn(f func() database.SubscriptionCRUD) *MockDBClientSubscriptionsCall { c.Call = c.Call.DoAndReturn(f) return c } -// UpdateSubscriptionDoc mocks base method. -func (m *MockDBClient) UpdateSubscriptionDoc(ctx context.Context, subscriptionID string, callback func(*arm.Subscription) bool) (bool, error) { +// UntypedCRUD mocks base method. +func (m *MockDBClient) UntypedCRUD(parentResourceID arm.ResourceID) (database.UntypedResourceCRUD, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateSubscriptionDoc", ctx, subscriptionID, callback) - ret0, _ := ret[0].(bool) + ret := m.ctrl.Call(m, "UntypedCRUD", parentResourceID) + ret0, _ := ret[0].(database.UntypedResourceCRUD) ret1, _ := ret[1].(error) return ret0, ret1 } -// UpdateSubscriptionDoc indicates an expected call of UpdateSubscriptionDoc. -func (mr *MockDBClientMockRecorder) UpdateSubscriptionDoc(ctx, subscriptionID, callback any) *MockDBClientUpdateSubscriptionDocCall { +// UntypedCRUD indicates an expected call of UntypedCRUD. +func (mr *MockDBClientMockRecorder) UntypedCRUD(parentResourceID any) *MockDBClientUntypedCRUDCall { mr.mock.ctrl.T.Helper() - call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateSubscriptionDoc", reflect.TypeOf((*MockDBClient)(nil).UpdateSubscriptionDoc), ctx, subscriptionID, callback) - return &MockDBClientUpdateSubscriptionDocCall{Call: call} + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UntypedCRUD", reflect.TypeOf((*MockDBClient)(nil).UntypedCRUD), parentResourceID) + return &MockDBClientUntypedCRUDCall{Call: call} } -// MockDBClientUpdateSubscriptionDocCall wrap *gomock.Call -type MockDBClientUpdateSubscriptionDocCall struct { +// MockDBClientUntypedCRUDCall wrap *gomock.Call +type MockDBClientUntypedCRUDCall struct { *gomock.Call } // Return rewrite *gomock.Call.Return -func (c *MockDBClientUpdateSubscriptionDocCall) Return(arg0 bool, arg1 error) *MockDBClientUpdateSubscriptionDocCall { +func (c *MockDBClientUntypedCRUDCall) Return(arg0 database.UntypedResourceCRUD, arg1 error) *MockDBClientUntypedCRUDCall { c.Call = c.Call.Return(arg0, arg1) return c } // Do rewrite *gomock.Call.Do -func (c *MockDBClientUpdateSubscriptionDocCall) Do(f func(context.Context, string, func(*arm.Subscription) bool) (bool, error)) *MockDBClientUpdateSubscriptionDocCall { +func (c *MockDBClientUntypedCRUDCall) Do(f func(arm.ResourceID) (database.UntypedResourceCRUD, error)) *MockDBClientUntypedCRUDCall { c.Call = c.Call.Do(f) return c } // DoAndReturn rewrite *gomock.Call.DoAndReturn -func (c *MockDBClientUpdateSubscriptionDocCall) DoAndReturn(f func(context.Context, string, func(*arm.Subscription) bool) (bool, error)) *MockDBClientUpdateSubscriptionDocCall { +func (c *MockDBClientUntypedCRUDCall) DoAndReturn(f func(arm.ResourceID) (database.UntypedResourceCRUD, error)) *MockDBClientUntypedCRUDCall { c.Call = c.Call.DoAndReturn(f) return c } diff --git a/internal/mocks/generate.go b/internal/mocks/generate.go index 24de1e7b40..00cd3acea1 100644 --- a/internal/mocks/generate.go +++ b/internal/mocks/generate.go @@ -17,6 +17,7 @@ package mocks //go:generate $MOCKGEN -typed -source=../database/database.go -destination=dbclient.go -package mocks github.com/Azure/ARO-HCP/internal/database DBClient //go:generate $MOCKGEN -typed -source=../database/crud_hcpcluster.go -destination=crud_hcpcluster.go -package mocks github.com/Azure/ARO-HCP/internal/database OperationCRUD //go:generate $MOCKGEN -typed -source=../database/crud_untyped_resource.go -destination=crud_untyped_resource.go -package mocks github.com/Azure/ARO-HCP/internal/database UntypedResourceCRUD +//go:generate $MOCKGEN -typed -source=../database/crud_subscription.go -destination=crud_subscription.go -package mocks github.com/Azure/ARO-HCP/internal/database SubscriptionCRUD //go:generate $MOCKGEN -typed -source=../database/lock.go -destination=lock.go -package mocks github.com/Azure/ARO-HCP/internal/database LockClientInterface //go:generate $MOCKGEN -typed -source=../database/transaction.go -destination=dbtransaction.go -package mocks github.com/Azure/ARO-HCP/internal/database DBTransaction DBTransactionResult //go:generate $MOCKGEN -typed -source=../ocm/client.go -destination=ocm.go -package mocks github.com/Azure/ARO-HCP/internal/ocm ClusterServiceClientSpec diff --git a/test-integration/frontend/database_crud_test.go b/test-integration/frontend/database_crud_test.go index fe9522211d..1afc824d33 100644 --- a/test-integration/frontend/database_crud_test.go +++ b/test-integration/frontend/database_crud_test.go @@ -21,8 +21,6 @@ import ( "github.com/stretchr/testify/require" - "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" - "github.com/Azure/ARO-HCP/internal/api" "github.com/Azure/ARO-HCP/test-integration/utils/databasemutationhelpers" "github.com/Azure/ARO-HCP/test-integration/utils/integrationutils" @@ -35,10 +33,6 @@ func TestDatabaseCRUD(t *testing.T) { ctx, cancel := context.WithCancel(ctx) defer cancel() - _, testInfo, err := integrationutils.NewFrontendFromTestingEnv(ctx, t) - require.NoError(t, err) - defer testInfo.Cleanup(context.Background()) - allCRUDDirFS, err := fs.Sub(artifacts, "artifacts/DatabaseCRUD") require.NoError(t, err) @@ -52,7 +46,6 @@ func TestDatabaseCRUD(t *testing.T) { ctx, t, databasemutationhelpers.ControllerCRUDSpecializer{}, - testInfo.CosmosResourcesContainer(), crudSuiteDir) }) @@ -62,7 +55,15 @@ func TestDatabaseCRUD(t *testing.T) { ctx, t, databasemutationhelpers.OperationCRUDSpecializer{}, - testInfo.CosmosResourcesContainer(), + crudSuiteDir) + }) + + case "SubscriptionCRUD": + t.Run(crudSuiteDirEntry.Name(), func(t *testing.T) { + testCRUDSuite( + ctx, + t, + databasemutationhelpers.SubscriptionCRUDSpecializer{}, crudSuiteDir) }) @@ -72,7 +73,6 @@ func TestDatabaseCRUD(t *testing.T) { ctx, t, databasemutationhelpers.OperationCRUDSpecializer{}, - testInfo.CosmosResourcesContainer(), crudSuiteDir) }) @@ -82,7 +82,7 @@ func TestDatabaseCRUD(t *testing.T) { } } -func testCRUDSuite[InternalAPIType any](ctx context.Context, t *testing.T, specializer databasemutationhelpers.ResourceCRUDTestSpecializer[InternalAPIType], cosmosContainer *azcosmos.ContainerClient, crudSuiteDir fs.FS) { +func testCRUDSuite[InternalAPIType any](ctx context.Context, t *testing.T, specializer databasemutationhelpers.ResourceCRUDTestSpecializer[InternalAPIType], crudSuiteDir fs.FS) { testDirs := api.Must(fs.ReadDir(crudSuiteDir, ".")) for _, testDirEntry := range testDirs { testDir := api.Must(fs.Sub(crudSuiteDir, testDirEntry.Name())) @@ -90,7 +90,6 @@ func testCRUDSuite[InternalAPIType any](ctx context.Context, t *testing.T, speci currTest, err := databasemutationhelpers.NewResourceMutationTest( ctx, specializer, - cosmosContainer, testDirEntry.Name(), testDir, ) diff --git a/test-integration/utils/controllertesthelpers/basic_controller.go b/test-integration/utils/controllertesthelpers/basic_controller.go index 3102f907b9..6cea4f739a 100644 --- a/test-integration/utils/controllertesthelpers/basic_controller.go +++ b/test-integration/utils/controllertesthelpers/basic_controller.go @@ -70,11 +70,10 @@ func (tc *BasicControllerTest) RunTest(t *testing.T) { if fsMightContainFiles(initialState) { loadInitialStateStep, err := databasemutationhelpers.NewLoadStep( databasemutationhelpers.NewStepID(00, "load", "initial-state"), - cosmosTestInfo.CosmosResourcesContainer(), initialState, ) require.NoError(t, err) - loadInitialStateStep.RunTest(ctx, t) + loadInitialStateStep.RunTest(ctx, t, cosmosTestInfo.CosmosResourcesContainer()) } controllerInstance, testMemory := tc.ControllerInitializerFn(ctx, t, cosmosTestInfo.DBClient) @@ -86,11 +85,10 @@ func (tc *BasicControllerTest) RunTest(t *testing.T) { if fsMightContainFiles(endState) { verifyEndStateStep, err := databasemutationhelpers.NewCosmosCompareStep( databasemutationhelpers.NewStepID(99, "cosmosCompare", "end-state"), - cosmosTestInfo.CosmosResourcesContainer(), endState, ) require.NoError(t, err) - verifyEndStateStep.RunTest(ctx, t) + verifyEndStateStep.RunTest(ctx, t, cosmosTestInfo.CosmosResourcesContainer()) } tc.ControllerVerifierFn(ctx, t, controllerInstance, testMemory) diff --git a/test-integration/utils/databasemutationhelpers/per_resource_crud.go b/test-integration/utils/databasemutationhelpers/per_resource_crud.go index 3d73e9b009..7927a4f39c 100644 --- a/test-integration/utils/databasemutationhelpers/per_resource_crud.go +++ b/test-integration/utils/databasemutationhelpers/per_resource_crud.go @@ -29,6 +29,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/Azure/ARO-HCP/internal/api" + "github.com/Azure/ARO-HCP/internal/api/arm" "github.com/Azure/ARO-HCP/internal/database" ) @@ -175,3 +176,29 @@ func (UntypedCRUDSpecializer) NameFromInstance(obj *database.TypedDocument) stri func (UntypedCRUDSpecializer) WriteCosmosID(newObj, oldObj *database.TypedDocument) { newObj.ID = oldObj.ID } + +type SubscriptionCRUDSpecializer struct { +} + +var _ ResourceCRUDTestSpecializer[arm.Subscription] = &SubscriptionCRUDSpecializer{} + +func (SubscriptionCRUDSpecializer) ResourceCRUDFromKey(t *testing.T, cosmosContainer *azcosmos.ContainerClient, key CosmosCRUDKey) database.ResourceCRUD[arm.Subscription] { + return database.NewSubscriptionCRUD(cosmosContainer) +} + +func (SubscriptionCRUDSpecializer) InstanceEquals(expected, actual *arm.Subscription) bool { + // clear the fields that don't compare + shallowExpected := *expected + shallowActual := *actual + shallowExpected.LastUpdated = 0 + shallowActual.LastUpdated = 0 + return equality.Semantic.DeepEqual(shallowExpected, shallowActual) +} + +func (SubscriptionCRUDSpecializer) NameFromInstance(obj *arm.Subscription) string { + return obj.ResourceID.Name +} + +func (SubscriptionCRUDSpecializer) WriteCosmosID(newObj, oldObj *arm.Subscription) { + newObj.ResourceID = oldObj.ResourceID +} diff --git a/test-integration/utils/databasemutationhelpers/resource_crud_test_util.go b/test-integration/utils/databasemutationhelpers/resource_crud_test_util.go index c462a68d3e..f7d9c57ec2 100644 --- a/test-integration/utils/databasemutationhelpers/resource_crud_test_util.go +++ b/test-integration/utils/databasemutationhelpers/resource_crud_test_util.go @@ -24,37 +24,38 @@ import ( "strings" "testing" + "github.com/stretchr/testify/require" + "github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos" "github.com/Azure/ARO-HCP/internal/api" "github.com/Azure/ARO-HCP/internal/utils" + "github.com/Azure/ARO-HCP/test-integration/utils/integrationutils" ) type ResourceMutationTest struct { - testDir fs.FS - cosmosContainer *azcosmos.ContainerClient + testDir fs.FS steps []IntegrationTestStep } type IntegrationTestStep interface { StepID() StepID - RunTest(ctx context.Context, t *testing.T) + RunTest(ctx context.Context, t *testing.T, cosmosContainer *azcosmos.ContainerClient) } -func NewResourceMutationTest[InternalAPIType any](ctx context.Context, specializer ResourceCRUDTestSpecializer[InternalAPIType], cosmosContainer *azcosmos.ContainerClient, testName string, testDir fs.FS) (*ResourceMutationTest, error) { - steps, err := readSteps(ctx, testDir, specializer, cosmosContainer) +func NewResourceMutationTest[InternalAPIType any](ctx context.Context, specializer ResourceCRUDTestSpecializer[InternalAPIType], testName string, testDir fs.FS) (*ResourceMutationTest, error) { + steps, err := readSteps(ctx, testDir, specializer) if err != nil { return nil, fmt.Errorf("failed to read steps for test %q: %w", testName, err) } return &ResourceMutationTest{ - testDir: testDir, - cosmosContainer: cosmosContainer, - steps: steps, + testDir: testDir, + steps: steps, }, nil } -func readSteps[InternalAPIType any](ctx context.Context, testDir fs.FS, specializer ResourceCRUDTestSpecializer[InternalAPIType], cosmosContainer *azcosmos.ContainerClient) ([]IntegrationTestStep, error) { +func readSteps[InternalAPIType any](ctx context.Context, testDir fs.FS, specializer ResourceCRUDTestSpecializer[InternalAPIType]) ([]IntegrationTestStep, error) { steps := []IntegrationTestStep{} testContent := api.Must(fs.ReadDir(testDir, ".")) @@ -72,7 +73,7 @@ func readSteps[InternalAPIType any](ctx context.Context, testDir fs.FS, speciali stepType := filenameParts[1] stepName, _ := strings.CutSuffix(filenameParts[2], ".json") - testStep, err := newStep(index, stepType, stepName, testDir, dirEntry.Name(), specializer, cosmosContainer) + testStep, err := newStep(index, stepType, stepName, testDir, dirEntry.Name(), specializer) if err != nil { return nil, fmt.Errorf("failed to create new step %q: %w", dirEntry.Name(), err) } @@ -84,13 +85,17 @@ func readSteps[InternalAPIType any](ctx context.Context, testDir fs.FS, speciali } func (tt *ResourceMutationTest) RunTest(t *testing.T) { + testInfo, err := integrationutils.NewCosmosFromTestingEnv(t.Context()) + require.NoError(t, err) + defer testInfo.Cleanup(context.Background()) + for _, step := range tt.steps { t.Logf("Running step %s", step.StepID()) - step.RunTest(t.Context(), t) + step.RunTest(t.Context(), t, testInfo.CosmosResourcesContainer()) } } -func newStep[InternalAPIType any](indexString, stepType, stepName string, testDir fs.FS, path string, specializer ResourceCRUDTestSpecializer[InternalAPIType], cosmosContainer *azcosmos.ContainerClient) (IntegrationTestStep, error) { +func newStep[InternalAPIType any](indexString, stepType, stepName string, testDir fs.FS, path string, specializer ResourceCRUDTestSpecializer[InternalAPIType]) (IntegrationTestStep, error) { itoInt, err := strconv.Atoi(indexString) if err != nil { return nil, fmt.Errorf("failed to convert %s to int: %w", indexString, err) @@ -103,43 +108,43 @@ func newStep[InternalAPIType any](indexString, stepType, stepName string, testDi switch stepType { case "load": - return NewLoadStep(stepID, cosmosContainer, stepDir) + return NewLoadStep(stepID, stepDir) case "cosmosCompare": - return NewCosmosCompareStep(stepID, cosmosContainer, stepDir) + return NewCosmosCompareStep(stepID, stepDir) case "create": - return newCreateStep(stepID, specializer, cosmosContainer, stepDir) + return newCreateStep(stepID, specializer, stepDir) case "replace": - return newReplaceStep(stepID, specializer, cosmosContainer, stepDir) + return newReplaceStep(stepID, specializer, stepDir) case "get": - return newGetStep(stepID, specializer, cosmosContainer, stepDir) + return newGetStep(stepID, specializer, stepDir) case "getByID": - return newGetByIDStep(stepID, specializer, cosmosContainer, stepDir) + return newGetByIDStep(stepID, specializer, stepDir) case "untypedGet": - return newUntypedGetStep(stepID, cosmosContainer, stepDir) + return newUntypedGetStep(stepID, stepDir) case "list": - return newListStep(stepID, specializer, cosmosContainer, stepDir) + return newListStep(stepID, specializer, stepDir) case "listActiveOperations": - return newListActiveOperationsStep(stepID, cosmosContainer, stepDir) + return newListActiveOperationsStep(stepID, stepDir) case "untypedListRecursive": - return newUntypedListRecursiveStep(stepID, cosmosContainer, stepDir) + return newUntypedListRecursiveStep(stepID, stepDir) case "untypedList": - return newUntypedListStep(stepID, cosmosContainer, stepDir) + return newUntypedListStep(stepID, stepDir) case "delete": - return newDeleteStep(stepID, specializer, cosmosContainer, stepDir) + return newDeleteStep(stepID, specializer, stepDir) case "untypedDelete": - return newUntypedDeleteStep(stepID, cosmosContainer, stepDir) + return newUntypedDeleteStep(stepID, stepDir) default: return nil, fmt.Errorf("unknown step type: %s", stepType) diff --git a/test-integration/utils/databasemutationhelpers/step_cosmoscompare.go b/test-integration/utils/databasemutationhelpers/step_cosmoscompare.go index b508e0ce3c..0f17bfb2ac 100644 --- a/test-integration/utils/databasemutationhelpers/step_cosmoscompare.go +++ b/test-integration/utils/databasemutationhelpers/step_cosmoscompare.go @@ -32,11 +32,10 @@ import ( type cosmosCompare struct { stepID StepID - cosmosContainer *azcosmos.ContainerClient expectedContent []*database.TypedDocument } -func NewCosmosCompareStep(stepID StepID, cosmosContainer *azcosmos.ContainerClient, stepDir fs.FS) (*cosmosCompare, error) { +func NewCosmosCompareStep(stepID StepID, stepDir fs.FS) (*cosmosCompare, error) { expectedContent, err := readResourcesInDir[database.TypedDocument](stepDir) if err != nil { return nil, utils.TrackError(err) @@ -44,7 +43,6 @@ func NewCosmosCompareStep(stepID StepID, cosmosContainer *azcosmos.ContainerClie return &cosmosCompare{ stepID: stepID, - cosmosContainer: cosmosContainer, expectedContent: expectedContent, }, nil } @@ -55,14 +53,14 @@ func (l *cosmosCompare) StepID() StepID { return l.stepID } -func (l *cosmosCompare) RunTest(ctx context.Context, t *testing.T) { +func (l *cosmosCompare) RunTest(ctx context.Context, t *testing.T, cosmosContainer *azcosmos.ContainerClient) { // Query all documents in the container querySQL := "SELECT * FROM c" queryOptions := &azcosmos.QueryOptions{ QueryParameters: []azcosmos.QueryParameter{}, } - queryPager := l.cosmosContainer.NewQueryItemsPager(querySQL, azcosmos.PartitionKey{}, queryOptions) + queryPager := cosmosContainer.NewQueryItemsPager(querySQL, azcosmos.PartitionKey{}, queryOptions) allActual := []*database.TypedDocument{} for queryPager.More() { diff --git a/test-integration/utils/databasemutationhelpers/step_create.go b/test-integration/utils/databasemutationhelpers/step_create.go index 43ea0dd761..4b8e2fa592 100644 --- a/test-integration/utils/databasemutationhelpers/step_create.go +++ b/test-integration/utils/databasemutationhelpers/step_create.go @@ -31,11 +31,10 @@ type createStep[InternalAPIType any] struct { key CosmosCRUDKey specializer ResourceCRUDTestSpecializer[InternalAPIType] - cosmosContainer *azcosmos.ContainerClient - resources []*InternalAPIType + resources []*InternalAPIType } -func newCreateStep[InternalAPIType any](stepID StepID, specializer ResourceCRUDTestSpecializer[InternalAPIType], cosmosContainer *azcosmos.ContainerClient, stepDir fs.FS) (*createStep[InternalAPIType], error) { +func newCreateStep[InternalAPIType any](stepID StepID, specializer ResourceCRUDTestSpecializer[InternalAPIType], stepDir fs.FS) (*createStep[InternalAPIType], error) { keyBytes, err := fs.ReadFile(stepDir, "00-key.json") if err != nil { return nil, fmt.Errorf("failed to read key.json: %w", err) @@ -51,11 +50,10 @@ func newCreateStep[InternalAPIType any](stepID StepID, specializer ResourceCRUDT } return &createStep[InternalAPIType]{ - stepID: stepID, - key: key, - specializer: specializer, - cosmosContainer: cosmosContainer, - resources: resources, + stepID: stepID, + key: key, + specializer: specializer, + resources: resources, }, nil } @@ -65,8 +63,8 @@ func (l *createStep[InternalAPIType]) StepID() StepID { return l.stepID } -func (l *createStep[InternalAPIType]) RunTest(ctx context.Context, t *testing.T) { - controllerCRUDClient := l.specializer.ResourceCRUDFromKey(t, l.cosmosContainer, l.key) +func (l *createStep[InternalAPIType]) RunTest(ctx context.Context, t *testing.T, cosmosContainer *azcosmos.ContainerClient) { + controllerCRUDClient := l.specializer.ResourceCRUDFromKey(t, cosmosContainer, l.key) for _, resource := range l.resources { _, err := controllerCRUDClient.Create(ctx, resource, nil) diff --git a/test-integration/utils/databasemutationhelpers/step_delete.go b/test-integration/utils/databasemutationhelpers/step_delete.go index 1a4423a924..e90777a616 100644 --- a/test-integration/utils/databasemutationhelpers/step_delete.go +++ b/test-integration/utils/databasemutationhelpers/step_delete.go @@ -39,11 +39,10 @@ type deleteStep[InternalAPIType any] struct { key CosmosDeleteKey specializer ResourceCRUDTestSpecializer[InternalAPIType] - cosmosContainer *azcosmos.ContainerClient - expectedError string + expectedError string } -func newDeleteStep[InternalAPIType any](stepID StepID, specializer ResourceCRUDTestSpecializer[InternalAPIType], cosmosContainer *azcosmos.ContainerClient, stepDir fs.FS) (*deleteStep[InternalAPIType], error) { +func newDeleteStep[InternalAPIType any](stepID StepID, specializer ResourceCRUDTestSpecializer[InternalAPIType], stepDir fs.FS) (*deleteStep[InternalAPIType], error) { keyBytes, err := fs.ReadFile(stepDir, "00-key.json") if err != nil { return nil, fmt.Errorf("failed to read key.json: %w", err) @@ -60,11 +59,10 @@ func newDeleteStep[InternalAPIType any](stepID StepID, specializer ResourceCRUDT expectedError := strings.TrimSpace(string(expectedErrorBytes)) return &deleteStep[InternalAPIType]{ - stepID: stepID, - key: key, - specializer: specializer, - cosmosContainer: cosmosContainer, - expectedError: expectedError, + stepID: stepID, + key: key, + specializer: specializer, + expectedError: expectedError, }, nil } @@ -74,8 +72,8 @@ func (l *deleteStep[InternalAPIType]) StepID() StepID { return l.stepID } -func (l *deleteStep[InternalAPIType]) RunTest(ctx context.Context, t *testing.T) { - controllerCRUDClient := l.specializer.ResourceCRUDFromKey(t, l.cosmosContainer, l.key.CosmosCRUDKey) +func (l *deleteStep[InternalAPIType]) RunTest(ctx context.Context, t *testing.T, cosmosContainer *azcosmos.ContainerClient) { + controllerCRUDClient := l.specializer.ResourceCRUDFromKey(t, cosmosContainer, l.key.CosmosCRUDKey) err := controllerCRUDClient.Delete(ctx, l.key.DeleteResourceName) switch { case len(l.expectedError) > 0: diff --git a/test-integration/utils/databasemutationhelpers/step_get.go b/test-integration/utils/databasemutationhelpers/step_get.go index 9703c1c74c..c9fadfe715 100644 --- a/test-integration/utils/databasemutationhelpers/step_get.go +++ b/test-integration/utils/databasemutationhelpers/step_get.go @@ -33,12 +33,11 @@ type getStep[InternalAPIType any] struct { key CosmosCRUDKey specializer ResourceCRUDTestSpecializer[InternalAPIType] - cosmosContainer *azcosmos.ContainerClient expectedResource *InternalAPIType expectedError string } -func newGetStep[InternalAPIType any](stepID StepID, specializer ResourceCRUDTestSpecializer[InternalAPIType], cosmosContainer *azcosmos.ContainerClient, stepDir fs.FS) (*getStep[InternalAPIType], error) { +func newGetStep[InternalAPIType any](stepID StepID, specializer ResourceCRUDTestSpecializer[InternalAPIType], stepDir fs.FS) (*getStep[InternalAPIType], error) { keyBytes, err := fs.ReadFile(stepDir, "00-key.json") if err != nil { return nil, fmt.Errorf("failed to read key.json: %w", err) @@ -75,7 +74,6 @@ func newGetStep[InternalAPIType any](stepID StepID, specializer ResourceCRUDTest stepID: stepID, key: key, specializer: specializer, - cosmosContainer: cosmosContainer, expectedResource: expectedResource, expectedError: expectedError, }, nil @@ -87,8 +85,8 @@ func (l *getStep[InternalAPIType]) StepID() StepID { return l.stepID } -func (l *getStep[InternalAPIType]) RunTest(ctx context.Context, t *testing.T) { - controllerCRUDClient := l.specializer.ResourceCRUDFromKey(t, l.cosmosContainer, l.key) +func (l *getStep[InternalAPIType]) RunTest(ctx context.Context, t *testing.T, cosmosContainer *azcosmos.ContainerClient) { + controllerCRUDClient := l.specializer.ResourceCRUDFromKey(t, cosmosContainer, l.key) resourceName := l.specializer.NameFromInstance(l.expectedResource) actualController, err := controllerCRUDClient.Get(ctx, resourceName) switch { diff --git a/test-integration/utils/databasemutationhelpers/step_getbyid.go b/test-integration/utils/databasemutationhelpers/step_getbyid.go index 24d3aac438..ac78603c7e 100644 --- a/test-integration/utils/databasemutationhelpers/step_getbyid.go +++ b/test-integration/utils/databasemutationhelpers/step_getbyid.go @@ -39,12 +39,11 @@ type getByIDStep[InternalAPIType any] struct { key GetByIDCRUDKey specializer ResourceCRUDTestSpecializer[InternalAPIType] - cosmosContainer *azcosmos.ContainerClient expectedResource *InternalAPIType expectedError string } -func newGetByIDStep[InternalAPIType any](stepID StepID, specializer ResourceCRUDTestSpecializer[InternalAPIType], cosmosContainer *azcosmos.ContainerClient, stepDir fs.FS) (*getByIDStep[InternalAPIType], error) { +func newGetByIDStep[InternalAPIType any](stepID StepID, specializer ResourceCRUDTestSpecializer[InternalAPIType], stepDir fs.FS) (*getByIDStep[InternalAPIType], error) { keyBytes, err := fs.ReadFile(stepDir, "00-key.json") if err != nil { return nil, fmt.Errorf("failed to read key.json: %w", err) @@ -81,7 +80,6 @@ func newGetByIDStep[InternalAPIType any](stepID StepID, specializer ResourceCRUD stepID: stepID, key: key, specializer: specializer, - cosmosContainer: cosmosContainer, expectedResource: expectedResource, expectedError: expectedError, }, nil @@ -93,8 +91,8 @@ func (l *getByIDStep[InternalAPIType]) StepID() StepID { return l.stepID } -func (l *getByIDStep[InternalAPIType]) RunTest(ctx context.Context, t *testing.T) { - controllerCRUDClient := l.specializer.ResourceCRUDFromKey(t, l.cosmosContainer, l.key.CosmosCRUDKey) +func (l *getByIDStep[InternalAPIType]) RunTest(ctx context.Context, t *testing.T, cosmosContainer *azcosmos.ContainerClient) { + controllerCRUDClient := l.specializer.ResourceCRUDFromKey(t, cosmosContainer, l.key.CosmosCRUDKey) actualController, err := controllerCRUDClient.GetByID(ctx, l.key.CosmosID) switch { case len(l.expectedError) > 0: diff --git a/test-integration/utils/databasemutationhelpers/step_list.go b/test-integration/utils/databasemutationhelpers/step_list.go index 0e6ba5afd4..c7405cc417 100644 --- a/test-integration/utils/databasemutationhelpers/step_list.go +++ b/test-integration/utils/databasemutationhelpers/step_list.go @@ -31,11 +31,10 @@ type listStep[InternalAPIType any] struct { key CosmosCRUDKey specializer ResourceCRUDTestSpecializer[InternalAPIType] - cosmosContainer *azcosmos.ContainerClient expectedResources []*InternalAPIType } -func newListStep[InternalAPIType any](stepID StepID, specializer ResourceCRUDTestSpecializer[InternalAPIType], cosmosContainer *azcosmos.ContainerClient, stepDir fs.FS) (*listStep[InternalAPIType], error) { +func newListStep[InternalAPIType any](stepID StepID, specializer ResourceCRUDTestSpecializer[InternalAPIType], stepDir fs.FS) (*listStep[InternalAPIType], error) { keyBytes, err := fs.ReadFile(stepDir, "00-key.json") if err != nil { return nil, fmt.Errorf("failed to read key.json: %w", err) @@ -54,7 +53,6 @@ func newListStep[InternalAPIType any](stepID StepID, specializer ResourceCRUDTes stepID: stepID, key: key, specializer: specializer, - cosmosContainer: cosmosContainer, expectedResources: expectedResources, }, nil } @@ -65,8 +63,8 @@ func (l *listStep[InternalAPIType]) StepID() StepID { return l.stepID } -func (l *listStep[InternalAPIType]) RunTest(ctx context.Context, t *testing.T) { - controllerCRUDClient := l.specializer.ResourceCRUDFromKey(t, l.cosmosContainer, l.key) +func (l *listStep[InternalAPIType]) RunTest(ctx context.Context, t *testing.T, cosmosContainer *azcosmos.ContainerClient) { + controllerCRUDClient := l.specializer.ResourceCRUDFromKey(t, cosmosContainer, l.key) actualControllersIterator, err := controllerCRUDClient.List(ctx, nil) require.NoError(t, err) diff --git a/test-integration/utils/databasemutationhelpers/step_list_active_operations.go b/test-integration/utils/databasemutationhelpers/step_list_active_operations.go index 2b56248f9d..4d2ecc2ff5 100644 --- a/test-integration/utils/databasemutationhelpers/step_list_active_operations.go +++ b/test-integration/utils/databasemutationhelpers/step_list_active_operations.go @@ -34,11 +34,10 @@ type listActiveOperationsStep struct { stepID StepID key CosmosCRUDKey - cosmosContainer *azcosmos.ContainerClient expectedOperations []*api.Operation } -func newListActiveOperationsStep(stepID StepID, cosmosContainer *azcosmos.ContainerClient, stepDir fs.FS) (*listActiveOperationsStep, error) { +func newListActiveOperationsStep(stepID StepID, stepDir fs.FS) (*listActiveOperationsStep, error) { keyBytes, err := fs.ReadFile(stepDir, "00-key.json") if err != nil { return nil, fmt.Errorf("failed to read key.json: %w", err) @@ -56,7 +55,6 @@ func newListActiveOperationsStep(stepID StepID, cosmosContainer *azcosmos.Contai return &listActiveOperationsStep{ stepID: stepID, key: key, - cosmosContainer: cosmosContainer, expectedOperations: expectedResources, }, nil } @@ -67,11 +65,11 @@ func (l *listActiveOperationsStep) StepID() StepID { return l.stepID } -func (l *listActiveOperationsStep) RunTest(ctx context.Context, t *testing.T) { +func (l *listActiveOperationsStep) RunTest(ctx context.Context, t *testing.T, cosmosContainer *azcosmos.ContainerClient) { parentResourceID, err := azcorearm.ParseResourceID(l.key.ParentResourceID) require.NoError(t, err) - operationsCRUD := database.NewOperationCRUD(l.cosmosContainer, parentResourceID.SubscriptionID) + operationsCRUD := database.NewOperationCRUD(cosmosContainer, parentResourceID.SubscriptionID) actualControllersIterator := operationsCRUD.ListActiveOperations(nil) require.NoError(t, err) diff --git a/test-integration/utils/databasemutationhelpers/step_load.go b/test-integration/utils/databasemutationhelpers/step_load.go index 3b6d717b33..255e227408 100644 --- a/test-integration/utils/databasemutationhelpers/step_load.go +++ b/test-integration/utils/databasemutationhelpers/step_load.go @@ -32,11 +32,10 @@ import ( type loadStep struct { stepID StepID - cosmosContainer *azcosmos.ContainerClient - contents [][]byte + contents [][]byte } -func NewLoadStep(stepID StepID, cosmosContainer *azcosmos.ContainerClient, stepDir fs.FS) (*loadStep, error) { +func NewLoadStep(stepID StepID, stepDir fs.FS) (*loadStep, error) { contents := [][]byte{} testContent := api.Must(fs.ReadDir(stepDir, ".")) @@ -59,9 +58,8 @@ func NewLoadStep(stepID StepID, cosmosContainer *azcosmos.ContainerClient, stepD } return &loadStep{ - stepID: stepID, - cosmosContainer: cosmosContainer, - contents: contents, + stepID: stepID, + contents: contents, }, nil } @@ -71,9 +69,9 @@ func (l *loadStep) StepID() StepID { return l.stepID } -func (l *loadStep) RunTest(ctx context.Context, t *testing.T) { +func (l *loadStep) RunTest(ctx context.Context, t *testing.T, cosmosContainer *azcosmos.ContainerClient) { for _, content := range l.contents { - err := integrationutils.LoadCosmosContent(ctx, l.cosmosContainer, content) + err := integrationutils.LoadCosmosContent(ctx, cosmosContainer, content) require.NoError(t, err, "failed to load cosmos content: %v", string(content)) } } diff --git a/test-integration/utils/databasemutationhelpers/step_replace.go b/test-integration/utils/databasemutationhelpers/step_replace.go index 8ec5ff0803..a7a227ecad 100644 --- a/test-integration/utils/databasemutationhelpers/step_replace.go +++ b/test-integration/utils/databasemutationhelpers/step_replace.go @@ -31,11 +31,10 @@ type replaceStep[InternalAPIType any] struct { key CosmosCRUDKey specializer ResourceCRUDTestSpecializer[InternalAPIType] - cosmosContainer *azcosmos.ContainerClient - resources []*InternalAPIType + resources []*InternalAPIType } -func newReplaceStep[InternalAPIType any](stepID StepID, specializer ResourceCRUDTestSpecializer[InternalAPIType], cosmosContainer *azcosmos.ContainerClient, stepDir fs.FS) (*replaceStep[InternalAPIType], error) { +func newReplaceStep[InternalAPIType any](stepID StepID, specializer ResourceCRUDTestSpecializer[InternalAPIType], stepDir fs.FS) (*replaceStep[InternalAPIType], error) { keyBytes, err := fs.ReadFile(stepDir, "00-key.json") if err != nil { return nil, fmt.Errorf("failed to read key.json: %w", err) @@ -51,11 +50,10 @@ func newReplaceStep[InternalAPIType any](stepID StepID, specializer ResourceCRUD } return &replaceStep[InternalAPIType]{ - stepID: stepID, - key: key, - specializer: specializer, - cosmosContainer: cosmosContainer, - resources: resources, + stepID: stepID, + key: key, + specializer: specializer, + resources: resources, }, nil } @@ -65,8 +63,8 @@ func (l *replaceStep[InternalAPIType]) StepID() StepID { return l.stepID } -func (l *replaceStep[InternalAPIType]) RunTest(ctx context.Context, t *testing.T) { - resourceCRUDClient := l.specializer.ResourceCRUDFromKey(t, l.cosmosContainer, l.key) +func (l *replaceStep[InternalAPIType]) RunTest(ctx context.Context, t *testing.T, cosmosContainer *azcosmos.ContainerClient) { + resourceCRUDClient := l.specializer.ResourceCRUDFromKey(t, cosmosContainer, l.key) for _, resource := range l.resources { // find the existing to set the UID for an replace to replace instead of creating a new record. diff --git a/test-integration/utils/databasemutationhelpers/step_untypeddelete.go b/test-integration/utils/databasemutationhelpers/step_untypeddelete.go index 7891c49d45..a7b7243b84 100644 --- a/test-integration/utils/databasemutationhelpers/step_untypeddelete.go +++ b/test-integration/utils/databasemutationhelpers/step_untypeddelete.go @@ -43,11 +43,10 @@ type untypedDeleteStep struct { key UntypedDeleteKey specializer ResourceCRUDTestSpecializer[database.TypedDocument] - cosmosContainer *azcosmos.ContainerClient - expectedError string + expectedError string } -func newUntypedDeleteStep(stepID StepID, cosmosContainer *azcosmos.ContainerClient, stepDir fs.FS) (*untypedDeleteStep, error) { +func newUntypedDeleteStep(stepID StepID, stepDir fs.FS) (*untypedDeleteStep, error) { keyBytes, err := fs.ReadFile(stepDir, "00-key.json") if err != nil { return nil, fmt.Errorf("failed to read key.json: %w", err) @@ -64,11 +63,10 @@ func newUntypedDeleteStep(stepID StepID, cosmosContainer *azcosmos.ContainerClie expectedError := strings.TrimSpace(string(expectedErrorBytes)) return &untypedDeleteStep{ - stepID: stepID, - key: key, - specializer: UntypedCRUDSpecializer{}, - cosmosContainer: cosmosContainer, - expectedError: expectedError, + stepID: stepID, + key: key, + specializer: UntypedCRUDSpecializer{}, + expectedError: expectedError, }, nil } @@ -78,11 +76,11 @@ func (l *untypedDeleteStep) StepID() StepID { return l.stepID } -func (l *untypedDeleteStep) RunTest(ctx context.Context, t *testing.T) { +func (l *untypedDeleteStep) RunTest(ctx context.Context, t *testing.T, cosmosContainer *azcosmos.ContainerClient) { parentResourceID, err := azcorearm.ParseResourceID(l.key.ParentResourceID) require.NoError(t, err) - untypedCRUD := database.NewUntypedCRUD(l.cosmosContainer, *parentResourceID) + untypedCRUD := database.NewUntypedCRUD(cosmosContainer, *parentResourceID) for _, childKey := range l.key.Descendents { childResourceType, err := azcorearm.ParseResourceType(childKey.ResourceType) require.NoError(t, err) diff --git a/test-integration/utils/databasemutationhelpers/step_untypedget.go b/test-integration/utils/databasemutationhelpers/step_untypedget.go index 1aa6b0de7b..8c462e58c7 100644 --- a/test-integration/utils/databasemutationhelpers/step_untypedget.go +++ b/test-integration/utils/databasemutationhelpers/step_untypedget.go @@ -36,12 +36,11 @@ type untypedGetStep struct { key UntypedCRUDKey specializer ResourceCRUDTestSpecializer[database.TypedDocument] - cosmosContainer *azcosmos.ContainerClient expectedResource *database.TypedDocument expectedError string } -func newUntypedGetStep(stepID StepID, cosmosContainer *azcosmos.ContainerClient, stepDir fs.FS) (*untypedGetStep, error) { +func newUntypedGetStep(stepID StepID, stepDir fs.FS) (*untypedGetStep, error) { keyBytes, err := fs.ReadFile(stepDir, "00-key.json") if err != nil { return nil, fmt.Errorf("failed to read key.json: %w", err) @@ -78,7 +77,6 @@ func newUntypedGetStep(stepID StepID, cosmosContainer *azcosmos.ContainerClient, stepID: stepID, key: key, specializer: UntypedCRUDSpecializer{}, - cosmosContainer: cosmosContainer, expectedResource: expectedResource, expectedError: expectedError, }, nil @@ -90,11 +88,11 @@ func (l *untypedGetStep) StepID() StepID { return l.stepID } -func (l *untypedGetStep) RunTest(ctx context.Context, t *testing.T) { +func (l *untypedGetStep) RunTest(ctx context.Context, t *testing.T, cosmosContainer *azcosmos.ContainerClient) { parentResourceID, err := azcorearm.ParseResourceID(l.key.ParentResourceID) require.NoError(t, err) - untypedCRUD := database.NewUntypedCRUD(l.cosmosContainer, *parentResourceID) + untypedCRUD := database.NewUntypedCRUD(cosmosContainer, *parentResourceID) for _, childKey := range l.key.Descendents { childResourceType, err := azcorearm.ParseResourceType(childKey.ResourceType) require.NoError(t, err) diff --git a/test-integration/utils/databasemutationhelpers/step_untypedlist.go b/test-integration/utils/databasemutationhelpers/step_untypedlist.go index 86521e86a2..81541ecda8 100644 --- a/test-integration/utils/databasemutationhelpers/step_untypedlist.go +++ b/test-integration/utils/databasemutationhelpers/step_untypedlist.go @@ -34,11 +34,10 @@ type untypedListStep struct { key UntypedCRUDKey specializer ResourceCRUDTestSpecializer[database.TypedDocument] - cosmosContainer *azcosmos.ContainerClient expectedResources []*database.TypedDocument } -func newUntypedListStep(stepID StepID, cosmosContainer *azcosmos.ContainerClient, stepDir fs.FS) (*untypedListStep, error) { +func newUntypedListStep(stepID StepID, stepDir fs.FS) (*untypedListStep, error) { keyBytes, err := fs.ReadFile(stepDir, "00-key.json") if err != nil { return nil, fmt.Errorf("failed to read key.json: %w", err) @@ -57,7 +56,6 @@ func newUntypedListStep(stepID StepID, cosmosContainer *azcosmos.ContainerClient stepID: stepID, key: key, specializer: UntypedCRUDSpecializer{}, - cosmosContainer: cosmosContainer, expectedResources: expectedResources, }, nil } @@ -68,11 +66,11 @@ func (l *untypedListStep) StepID() StepID { return l.stepID } -func (l *untypedListStep) RunTest(ctx context.Context, t *testing.T) { +func (l *untypedListStep) RunTest(ctx context.Context, t *testing.T, cosmosContainer *azcosmos.ContainerClient) { parentResourceID, err := azcorearm.ParseResourceID(l.key.ParentResourceID) require.NoError(t, err) - untypedCRUD := database.NewUntypedCRUD(l.cosmosContainer, *parentResourceID) + untypedCRUD := database.NewUntypedCRUD(cosmosContainer, *parentResourceID) for _, childKey := range l.key.Descendents { childResourceType, err := azcorearm.ParseResourceType(childKey.ResourceType) require.NoError(t, err) diff --git a/test-integration/utils/databasemutationhelpers/step_untypedlistrecursive.go b/test-integration/utils/databasemutationhelpers/step_untypedlistrecursive.go index 61ed0a7150..4a5508155e 100644 --- a/test-integration/utils/databasemutationhelpers/step_untypedlistrecursive.go +++ b/test-integration/utils/databasemutationhelpers/step_untypedlistrecursive.go @@ -45,11 +45,10 @@ type untypedListRecursiveStep struct { key UntypedCRUDKey specializer ResourceCRUDTestSpecializer[database.TypedDocument] - cosmosContainer *azcosmos.ContainerClient expectedResources []*database.TypedDocument } -func newUntypedListRecursiveStep(stepID StepID, cosmosContainer *azcosmos.ContainerClient, stepDir fs.FS) (*untypedListRecursiveStep, error) { +func newUntypedListRecursiveStep(stepID StepID, stepDir fs.FS) (*untypedListRecursiveStep, error) { keyBytes, err := fs.ReadFile(stepDir, "00-key.json") if err != nil { return nil, fmt.Errorf("failed to read key.json: %w", err) @@ -68,7 +67,6 @@ func newUntypedListRecursiveStep(stepID StepID, cosmosContainer *azcosmos.Contai stepID: stepID, key: key, specializer: UntypedCRUDSpecializer{}, - cosmosContainer: cosmosContainer, expectedResources: expectedResources, }, nil } @@ -79,11 +77,11 @@ func (l *untypedListRecursiveStep) StepID() StepID { return l.stepID } -func (l *untypedListRecursiveStep) RunTest(ctx context.Context, t *testing.T) { +func (l *untypedListRecursiveStep) RunTest(ctx context.Context, t *testing.T, cosmosContainer *azcosmos.ContainerClient) { parentResourceID, err := azcorearm.ParseResourceID(l.key.ParentResourceID) require.NoError(t, err) - untypedCRUD := database.NewUntypedCRUD(l.cosmosContainer, *parentResourceID) + untypedCRUD := database.NewUntypedCRUD(cosmosContainer, *parentResourceID) for _, childKey := range l.key.Descendents { childResourceType, err := azcorearm.ParseResourceType(childKey.ResourceType) require.NoError(t, err) diff --git a/test-integration/utils/integrationutils/cosmos_testinfo.go b/test-integration/utils/integrationutils/cosmos_testinfo.go index 1d305e2dfe..05242b2d0b 100644 --- a/test-integration/utils/integrationutils/cosmos_testinfo.go +++ b/test-integration/utils/integrationutils/cosmos_testinfo.go @@ -143,14 +143,14 @@ func (s *CosmosIntegrationTestInfo) CreateNewSubscription(ctx context.Context) ( func (s *CosmosIntegrationTestInfo) CreateSpecificSubscription(ctx context.Context, subscriptionID string) (string, *arm.Subscription, error) { subscription := &arm.Subscription{ - State: arm.SubscriptionStateRegistered, + ResourceID: api.Must(arm.ToSubscriptionResourceID(subscriptionID)), + State: arm.SubscriptionStateRegistered, } - err := s.DBClient.CreateSubscriptionDoc(ctx, subscriptionID, subscription) + ret, err := s.DBClient.Subscriptions().Create(ctx, subscription, nil) if err != nil { return "", nil, err } - ret, err := s.DBClient.GetSubscriptionDoc(ctx, subscriptionID) return subscriptionID, ret, err }