diff --git a/CHANGELOG.md b/CHANGELOG.md index 91f7e88..03294cc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,12 +2,16 @@ ## [Unreleased](https://github.com/openfga/go-sdk/compare/v0.7.5...HEAD) +- feat!: add `ExecuteStreaming` to `APIExecutor` for streaming any OpenFGA endpoint via the generic executor. Check out the [documentation](./README.md#calling-other-endpoints) - feat(telemetry): add `fga-client.request.count` metric to track total HTTP requests made by the SDK - fix: The `fga-client.http_request.duration` metric is now disabled by default. Users can enable it via telemetry configuration if needed. -[!WARNING] -BREAKING CHANGE: -The default behavior changed, and fga-client.http_request.duration is now disabled unless explicitly enabled. +> [!WARNING] +> BREAKING CHANGES (pre-v1, semver allows): +> - `OpenFgaApi`: `ApiStreamedListObjectsRequest.Execute()` now returns `(*StreamedListObjectsChannel, error)` instead of `(StreamResultOfStreamedListObjectsResponse, *http.Response, error)` +> - `OpenFgaApi`: `ApiStreamedListObjectsRequest.Options()` now takes `StreamingRequestOptions` instead of `RequestOptions` +> - Removed `ExecuteStreamedListObjects`, `ExecuteStreamedListObjectsWithBufferSize`, and `ProcessStreamedListObjectsResponse` from the `openfga` package. Use `fgaClient.StreamedListObjects(ctx).Body(...).Execute()` or `APIExecutor.ExecuteStreaming` instead +> - `fga-client.http_request.duration` metric is now disabled by default; enable it via telemetry configuration if needed ## v0.7.5 diff --git a/README.md b/README.md index 7b14c95..578f668 100644 --- a/README.md +++ b/README.md @@ -1177,6 +1177,70 @@ fmt.Printf("Status Code: %d\n", rawResponse.StatusCode) fmt.Printf("Headers: %+v\n", rawResponse.Headers) ``` +#### Example: Calling streaming endpoints (e.g., StreamedListObjects) + +For streaming API endpoints, use the `ExecuteStreaming` method. This is useful for endpoints like `StreamedListObjects` that stream results as they are computed rather than waiting for all results before responding. + +```go +// Get the generic API executor +executor := fgaClient.GetAPIExecutor() + +// Build a streaming request for StreamedListObjects +request := openfga.NewAPIExecutorRequestBuilder("StreamedListObjects", http.MethodPost, "/stores/{store_id}/streamed-list-objects"). + WithPathParameter("store_id", storeID). + WithBody(openfga.ListObjectsRequest{ + AuthorizationModelId: openfga.PtrString(modelID), + Type: "document", + Relation: "viewer", + User: "user:alice", + }). + Build() + +// Execute the streaming request +// The bufferSize parameter controls how many results can be buffered in the channel +channel, err := executor.ExecuteStreaming(ctx, request, openfga.DefaultStreamBufferSize) +if err != nil { + log.Fatalf("Streaming request failed: %v", err) +} +defer channel.Close() // Always close the channel when done + +// Process results as they stream in +for { + select { + case result, ok := <-channel.Results: + if !ok { + // Results channel closed, stream completed + // Check for any final errors + select { + case err := <-channel.Errors: + if err != nil { + log.Fatalf("Stream error: %v", err) + } + default: + } + fmt.Println("Stream completed successfully") + return + } + // Decode the raw JSON bytes into a typed response + var response openfga.StreamedListObjectsResponse + if err := json.Unmarshal(result, &response); err != nil { + log.Fatalf("Failed to decode stream result: %v", err) + } + fmt.Printf("Received object: %s\n", response.Object) + case err := <-channel.Errors: + if err != nil { + log.Fatalf("Stream error: %v", err) + } + } +} +``` + +The `ExecuteStreaming` method returns an `APIExecutorStreamingChannel` with: +- `Results chan []byte`: Raw JSON bytes for each streamed result +- `Errors chan error`: Any errors that occur during streaming +- `Close()`: Method to cancel streaming and cleanup resources + + ### Retries If a network request fails with a 429 or 5xx error from the server, the SDK will automatically retry the request up to 3 times with a minimum wait time of 100 milliseconds between each attempt. diff --git a/api_executor.go b/api_executor.go index d48ef91..c0a1a84 100644 --- a/api_executor.go +++ b/api_executor.go @@ -1,8 +1,10 @@ package openfga import ( + "bufio" "bytes" "context" + "encoding/json" "errors" "io" "log" @@ -81,8 +83,11 @@ func (b *APIExecutorRequestBuilder) WithPathParameter(key, value string) *APIExe } // WithPathParameters sets all path parameters at once. -// Replaces any previously set path parameters. +// Replaces any previously set path parameters. If params is nil, this is a no-op. func (b *APIExecutorRequestBuilder) WithPathParameters(params map[string]string) *APIExecutorRequestBuilder { + if params == nil { + return b + } b.request.PathParameters = params return b } @@ -98,8 +103,11 @@ func (b *APIExecutorRequestBuilder) WithQueryParameter(key, value string) *APIEx } // WithQueryParameters sets all query parameters at once. -// Replaces any previously set query parameters. +// Replaces any previously set query parameters. If params is nil, this is a no-op. func (b *APIExecutorRequestBuilder) WithQueryParameters(params url.Values) *APIExecutorRequestBuilder { + if params == nil { + return b + } b.request.QueryParameters = params return b } @@ -121,8 +129,11 @@ func (b *APIExecutorRequestBuilder) WithHeader(key, value string) *APIExecutorRe } // WithHeaders sets all custom headers at once. -// Replaces any previously set headers. +// Replaces any previously set headers. If headers is nil, this is a no-op. func (b *APIExecutorRequestBuilder) WithHeaders(headers map[string]string) *APIExecutorRequestBuilder { + if headers == nil { + return b + } b.request.Headers = headers return b } @@ -152,17 +163,7 @@ type APIExecutor interface { // Execute performs an API request with automatic retry logic, telemetry, and error handling. // It returns the raw response that can be decoded manually. // - // Example using struct literal: - // openfga.APIExecutorRequest{ - // OperationName: "Check", - // Method: "POST", - // Path: "/stores/{store_id}/check", - // PathParameters: map[string]string{"store_id": storeID}, - // Body: checkRequest, - // } - // response, err := executor.Execute(ctx, request) - // - // Example using builder pattern: + // Example: // request := openfga.NewAPIExecutorRequestBuilder("Check", "POST", "/stores/{store_id}/check"). // WithPathParameter("store_id", storeID). // WithBody(checkRequest). @@ -173,18 +174,7 @@ type APIExecutor interface { // ExecuteWithDecode performs an API request and decodes the response into the provided result pointer. // The result parameter must be a pointer to the type you want to decode into. // - // Example using struct literal: - // var response openfga.CheckResponse - // openfga.APIExecutorRequest{ - // OperationName: "Check", - // Method: "POST", - // Path: "/stores/{store_id}/check", - // PathParameters: map[string]string{"store_id": storeID}, - // Body: checkRequest, - // } - // _, err := executor.ExecuteWithDecode(ctx, request, &response) - // - // Example using builder pattern: + // Example: // var response openfga.CheckResponse // request := openfga.NewAPIExecutorRequestBuilder("Check", "POST", "/stores/{store_id}/check"). // WithPathParameter("store_id", storeID). @@ -192,6 +182,58 @@ type APIExecutor interface { // Build() // _, err := executor.ExecuteWithDecode(ctx, request, &response) ExecuteWithDecode(ctx context.Context, request APIExecutorRequest, result interface{}) (*APIExecutorResponse, error) + + // ExecuteStreaming performs an API request that returns a streaming response. + // It returns an APIExecutorStreamingChannel that provides results and errors through channels. + // The caller is responsible for closing the channel when done using defer channel.Close(). + // + // This method is useful for streaming API endpoints like StreamedListObjects where + // each line in the response body is a separate JSON object. + // + // Parameters: + // - ctx: Context for cancellation. When cancelled, the streaming will stop. + // - request: The API request configuration. The Accept header is automatically set to + // "application/x-ndjson" unless explicitly overridden. + // - bufferSize: The buffer size for the results channel. Use DefaultStreamBufferSize (10) for most cases. + // + // Example - Calling StreamedListObjects: + // + // executor := openfga.NewAPIExecutor(client) + // + // request := openfga.NewAPIExecutorRequestBuilder("StreamedListObjects", "POST", "/stores/{store_id}/streamed-list-objects"). + // WithPathParameter("store_id", storeID). + // WithBody(openfga.ListObjectsRequest{ + // AuthorizationModelId: openfga.PtrString(modelID), + // Type: "document", + // Relation: "viewer", + // User: "user:alice", + // }). + // Build() + // + // channel, err := executor.ExecuteStreaming(ctx, request, openfga.DefaultStreamBufferSize) + // if err != nil { + // return err + // } + // defer channel.Close() + // + // for { + // select { + // case result, ok := <-channel.Results: + // if !ok { + // return nil // Stream completed + // } + // var response openfga.StreamedListObjectsResponse + // if err := json.Unmarshal(result, &response); err != nil { + // return err + // } + // fmt.Printf("Object: %s\n", response.Object) + // case err := <-channel.Errors: + // if err != nil { + // return err + // } + // } + // } + ExecuteStreaming(ctx context.Context, request APIExecutorRequest, bufferSize int) (*APIExecutorStreamingChannel, error) } // validateRequest checks that required fields are present in the request. @@ -247,6 +289,9 @@ type apiExecutor struct { client *APIClient } +// Compile-time check that apiExecutor implements APIExecutor. +var _ APIExecutor = (*apiExecutor)(nil) + // NewAPIExecutor creates a new APIExecutor instance. // This allows users to call any OpenFGA API endpoint, including those not yet supported by the SDK. func NewAPIExecutor(client *APIClient) APIExecutor { @@ -465,3 +510,482 @@ func (e *apiExecutor) logRetry(request APIExecutorRequest, err error, response * waitDuration, request.OperationName, attemptNum, err, request.Body) } } + +// ============================================================================ +// Streaming API Support +// ============================================================================ + +// DefaultStreamBufferSize is the default buffer size for streaming channels. +const DefaultStreamBufferSize = 10 + +// StreamResult represents a streaming result wrapper with either a result or an error. +// This is the format used by OpenFGA's streaming responses. +type StreamResult[T any] struct { + Result *T `json:"result,omitempty" yaml:"result,omitempty"` + Error *Status `json:"error,omitempty" yaml:"error,omitempty"` +} + +// StreamStatusError is returned when the server sends an error object inside a streaming response. +// Callers can type-assert to access the full Status (code, message, details) for classification. +type StreamStatusError struct { + Status *Status +} + +func (e *StreamStatusError) Error() string { + if e.Status != nil && e.Status.Message != nil { + return *e.Status.Message + } + return "stream error" +} + +// StreamingChannel represents a generic channel for streaming responses. +// It provides typed results directly decoded from the stream. +type StreamingChannel[T any] struct { + Results chan T + Errors chan error + cancel context.CancelFunc +} + +// Close cancels the streaming context and cleans up resources. +func (s *StreamingChannel[T]) Close() { + if s.cancel != nil { + s.cancel() + } +} + +// ProcessStreamingResponse processes an HTTP streaming response +// and returns a StreamingChannel with typed results and errors. +// +// This is a convenience wrapper around processStreamingResponseRaw that adds automatic +// JSON unmarshalling of the raw bytes into the target type T. +// +// Parameters: +// - ctx: The context for cancellation +// - httpResponse: The HTTP response to process +// - bufferSize: The buffer size for the channels (default 10 if <= 0) +// +// Returns: +// - *StreamingChannel[T]: A channel containing streaming results and errors +// - error: An error if the response is invalid +func ProcessStreamingResponse[T any](ctx context.Context, httpResponse *http.Response, bufferSize int) (*StreamingChannel[T], error) { + streamCtx, cancel := context.WithCancel(ctx) + + // Use default buffer size of 10 if not specified or invalid + if bufferSize <= 0 { + bufferSize = DefaultStreamBufferSize + } + + channel := &StreamingChannel[T]{ + Results: make(chan T, bufferSize), + Errors: make(chan error, 1), + cancel: cancel, + } + + if httpResponse == nil || httpResponse.Body == nil { + cancel() + return nil, errors.New("response or response body is nil") + } + + doneCh := make(chan struct{}) + + // Interrupt blocking scanner.Scan() when the context is cancelled. + go func() { + select { + case <-streamCtx.Done(): + _ = httpResponse.Body.Close() + case <-doneCh: + } + }() + + go func() { + defer close(channel.Results) + defer close(channel.Errors) + defer cancel() + defer close(doneCh) + defer func() { _ = httpResponse.Body.Close() }() + + scanner := bufio.NewScanner(httpResponse.Body) + // Allow large NDJSON entries (up to 10MB). Tune as needed. + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 10*1024*1024) + + for scanner.Scan() { + select { + case <-streamCtx.Done(): + channel.Errors <- streamCtx.Err() + return + default: + line := scanner.Bytes() + if len(line) == 0 { + continue + } + + var streamResult StreamResult[T] + if err := json.Unmarshal(line, &streamResult); err != nil { + channel.Errors <- err + return + } + + if streamResult.Error != nil { + channel.Errors <- &StreamStatusError{Status: streamResult.Error} + return + } + + if streamResult.Result != nil { + select { + case <-streamCtx.Done(): + channel.Errors <- streamCtx.Err() + return + case channel.Results <- *streamResult.Result: + } + } + } + } + + if err := scanner.Err(); err != nil { + // Prefer context error if we were canceled to avoid surfacing net/http "use of closed network connection". + if streamCtx.Err() != nil { + channel.Errors <- streamCtx.Err() + return + } + channel.Errors <- err + } + }() + + return channel, nil +} + +// APIExecutorStreamingChannel represents a channel for streaming API responses. +// It provides two channels: Results for successful responses and Errors for any errors encountered. +// +// Usage pattern: +// +// channel, err := executor.ExecuteStreaming(ctx, request, 10) +// if err != nil { +// return err +// } +// defer channel.Close() +// +// for { +// select { +// case result, ok := <-channel.Results: +// if !ok { +// // Channel closed, check for errors +// select { +// case err := <-channel.Errors: +// if err != nil { +// return err +// } +// default: +// } +// return nil +// } +// // Process result (raw JSON bytes) +// var response YourResponseType +// json.Unmarshal(result, &response) +// case err := <-channel.Errors: +// if err != nil { +// return err +// } +// } +// } +type APIExecutorStreamingChannel struct { + // Results channel receives raw JSON bytes for each streamed result. + // The channel is closed when the stream ends or an error occurs. + Results chan []byte + + // Errors channel receives any errors that occur during streaming. + // Only one error will be sent before the channel is closed. + Errors chan error + + // cancel is the function to cancel the streaming context + cancel context.CancelFunc +} + +// Close cancels the streaming context and cleans up resources. +// It is safe to call Close multiple times. +// Always defer Close() after successfully creating a streaming channel. +func (s *APIExecutorStreamingChannel) Close() { + if s.cancel != nil { + s.cancel() + } +} + +// ExecuteStreaming performs an API request that returns a streaming response. +// It returns an APIExecutorStreamingChannel that provides results and errors through channels. +// The caller is responsible for closing the channel when done using defer channel.Close(). +// +// Streaming responses are line-delimited JSON where each line is a JSON object. +// Each line is expected to have either a "result" or "error" field wrapped in a StreamResult structure. +// +// Parameters: +// - ctx: Context for cancellation. When cancelled, the streaming will stop. +// - request: The API request configuration. Headers should include "Accept": "application/x-ndjson" for streaming. +// - bufferSize: The buffer size for the results channel. Use DefaultStreamBufferSize (10) for most cases. +// A larger buffer can improve throughput but uses more memory. +// +// Example - Calling StreamedListObjects: +// +// executor := openfga.NewAPIExecutor(client) +// +// request := openfga.NewAPIExecutorRequestBuilder("StreamedListObjects", "POST", "/stores/{store_id}/streamed-list-objects"). +// WithPathParameter("store_id", storeID). +// WithHeader("Accept", "application/x-ndjson"). +// WithBody(openfga.ListObjectsRequest{ +// AuthorizationModelId: openfga.PtrString(modelID), +// Type: "document", +// Relation: "viewer", +// User: "user:alice", +// }). +// Build() +// +// channel, err := executor.ExecuteStreaming(ctx, request, openfga.DefaultStreamBufferSize) +// if err != nil { +// return err +// } +// defer channel.Close() +// +// for { +// select { +// case result, ok := <-channel.Results: +// if !ok { +// // Stream completed +// return nil +// } +// var response openfga.StreamedListObjectsResponse +// if err := json.Unmarshal(result, &response); err != nil { +// return err +// } +// fmt.Printf("Object: %s\n", response.Object) +// case err := <-channel.Errors: +// if err != nil { +// return err +// } +// } +// } +func (e *apiExecutor) ExecuteStreaming(ctx context.Context, request APIExecutorRequest, bufferSize int) (*APIExecutorStreamingChannel, error) { + // Validate required fields + if err := validateRequest(request); err != nil { + return nil, err + } + + // Build request parameters + path := buildPath(request.Path, request.PathParameters) + + if strings.Contains(path, "{") || strings.Contains(path, "}") { + return nil, reportError("not all path parameters were provided for path: %s", path) + } + + headerParams := prepareHeaders(request.Headers) + // Ensure Accept header is set for streaming unless already overridden + if headerParams["Accept"] == "application/json" { + headerParams["Accept"] = "application/x-ndjson" + } + + queryParams := request.QueryParameters + if queryParams == nil { + queryParams = url.Values{} + } + + storeID := request.PathParameters["store_id"] + + // Get retry configuration — same retry logic as executeInternal for the initial connection phase + retryParams := e.getRetryParams() + + // Track from before the retry loop so telemetry captures total time to connection including any retry waits. + requestStarted := time.Now() + + var lastErr error + + for attemptNum := 0; attemptNum < retryParams.MaxRetry+1; attemptNum++ { + // Prepare HTTP request (must be rebuilt each attempt as the body reader is consumed) + req, err := e.client.prepareRequest(ctx, path, request.Method, request.Body, headerParams, queryParams) + if err != nil { + return nil, err + } + + // Execute HTTP request + httpResponse, err := e.client.callAPI(req) + if err != nil { + lastErr = err + if attemptNum >= retryParams.MaxRetry { + return nil, lastErr + } + if shouldRetry, waitDuration := e.determineRetry(err, nil, attemptNum, retryParams, request.OperationName); shouldRetry { + if e.client.cfg.Debug { + log.Printf("\nWaiting %v to retry streaming %v (attempt %d, error=%v)\n", + waitDuration, request.OperationName, attemptNum, err) + } + select { + case <-time.After(waitDuration): + case <-ctx.Done(): + return nil, ctx.Err() + } + continue + } + return nil, lastErr + } + if httpResponse == nil { + lastErr = reportError("nil HTTP response from API client") + if attemptNum >= retryParams.MaxRetry { + return nil, lastErr + } + if shouldRetry, waitDuration := e.determineRetry(lastErr, nil, attemptNum, retryParams, request.OperationName); shouldRetry { + if e.client.cfg.Debug { + log.Printf("\nWaiting %v to retry streaming %v (attempt %d, error=%v)\n", + waitDuration, request.OperationName, attemptNum, lastErr) + } + select { + case <-time.After(waitDuration): + case <-ctx.Done(): + return nil, ctx.Err() + } + continue + } + return nil, lastErr + } + + // Handle HTTP errors (status >= 300) — these may be retryable (e.g. 429, 500) + if httpResponse.StatusCode >= http.StatusMultipleChoices { + responseBody, readErr := io.ReadAll(httpResponse.Body) + _ = httpResponse.Body.Close() + if readErr != nil { + return nil, readErr + } + apiErr := e.client.handleAPIError(httpResponse, responseBody, request.Body, request.OperationName, storeID) + lastErr = apiErr + + if attemptNum >= retryParams.MaxRetry { + return nil, lastErr + } + + resp := makeAPIExecutorResponse(httpResponse, responseBody) + if shouldRetry, waitDuration := e.determineRetry(apiErr, resp, attemptNum, retryParams, request.OperationName); shouldRetry { + if e.client.cfg.Debug { + log.Printf("\nWaiting %v to retry streaming %v (attempt %d, status=%d, error=%v)\n", + waitDuration, request.OperationName, attemptNum, httpResponse.StatusCode, apiErr) + } + select { + case <-time.After(waitDuration): + case <-ctx.Done(): + return nil, ctx.Err() + } + continue + } + return nil, lastErr + } + + // Record telemetry at connection establishment (not stream completion) so that + // streaming latency remains comparable to non-streaming request durations. + e.recordTelemetry(request.OperationName, storeID, request.Body, req, httpResponse, requestStarted, attemptNum) + + // Success — process streaming response (no retries once the stream is established) + return processStreamingResponseRaw(ctx, httpResponse, bufferSize) + } + + // All retries exhausted + if lastErr != nil { + return nil, lastErr + } + return nil, reportError("request failed without response") +} + +// processStreamingResponseRaw processes an HTTP streaming response. +// It returns an APIExecutorStreamingChannel with raw JSON bytes for each result. +func processStreamingResponseRaw(ctx context.Context, httpResponse *http.Response, bufferSize int) (*APIExecutorStreamingChannel, error) { + streamCtx, cancel := context.WithCancel(ctx) + + // Use default buffer size if not specified or invalid + if bufferSize <= 0 { + bufferSize = DefaultStreamBufferSize + } + + channel := &APIExecutorStreamingChannel{ + Results: make(chan []byte, bufferSize), + Errors: make(chan error, 1), + cancel: cancel, + } + + if httpResponse == nil || httpResponse.Body == nil { + cancel() + return nil, reportError("response or response body is nil") + } + + doneCh := make(chan struct{}) + + // Interrupt blocking scanner.Scan() when the context is cancelled. + // scanner.Scan() blocks on network I/O and only checks context between lines, + // so we must close the body to unblock it immediately. + go func() { + select { + case <-streamCtx.Done(): + _ = httpResponse.Body.Close() + case <-doneCh: + } + }() + + go func() { + defer close(channel.Results) + defer close(channel.Errors) + defer cancel() + defer close(doneCh) + defer func() { _ = httpResponse.Body.Close() }() + + scanner := bufio.NewScanner(httpResponse.Body) + // Allow large NDJSON entries (up to 10MB). Tune as needed. + buf := make([]byte, 0, 64*1024) + scanner.Buffer(buf, 10*1024*1024) + + for scanner.Scan() { + select { + case <-streamCtx.Done(): + channel.Errors <- streamCtx.Err() + return + default: + line := scanner.Bytes() + if len(line) == 0 { + continue + } + + // Parse the StreamResult wrapper to check for errors + var streamResult struct { + Result json.RawMessage `json:"result,omitempty"` + Error *Status `json:"error,omitempty"` + } + if err := json.Unmarshal(line, &streamResult); err != nil { + channel.Errors <- err + return + } + + if streamResult.Error != nil { + channel.Errors <- &StreamStatusError{Status: streamResult.Error} + return + } + + if streamResult.Result != nil { + // Make a copy of the raw JSON to send through the channel + resultCopy := make([]byte, len(streamResult.Result)) + copy(resultCopy, streamResult.Result) + + select { + case <-streamCtx.Done(): + channel.Errors <- streamCtx.Err() + return + case channel.Results <- resultCopy: + } + } + } + } + + if err := scanner.Err(); err != nil { + // Prefer context error if we were canceled to avoid surfacing net/http "use of closed network connection". + if streamCtx.Err() != nil { + channel.Errors <- streamCtx.Err() + return + } + channel.Errors <- err + } + }() + + return channel, nil +} diff --git a/api_executor_test.go b/api_executor_test.go index be33403..6c517b5 100644 --- a/api_executor_test.go +++ b/api_executor_test.go @@ -9,6 +9,7 @@ import ( "net/url" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1309,3 +1310,1263 @@ func TestBuildPath_SpecialCases(t *testing.T) { assert.Contains(t, result, "&") }) } + +// ============================================================================ +// Streaming API Tests +// ============================================================================ + +func TestAPIExecutorStreamingChannel_Close(t *testing.T) { + t.Parallel() + + t.Run("close_with_cancel_function", func(t *testing.T) { + t.Parallel() + + cancelCalled := false + channel := &APIExecutorStreamingChannel{ + Results: make(chan []byte), + Errors: make(chan error), + cancel: func() { + cancelCalled = true + }, + } + + channel.Close() + assert.True(t, cancelCalled, "cancel function should be called") + }) + + t.Run("close_with_nil_cancel", func(t *testing.T) { + t.Parallel() + + channel := &APIExecutorStreamingChannel{ + Results: make(chan []byte), + Errors: make(chan error), + cancel: nil, + } + + // Should not panic + assert.NotPanics(t, func() { + channel.Close() + }) + }) + + t.Run("close_multiple_times", func(t *testing.T) { + t.Parallel() + + callCount := 0 + channel := &APIExecutorStreamingChannel{ + Results: make(chan []byte), + Errors: make(chan error), + cancel: func() { + callCount++ + }, + } + + channel.Close() + channel.Close() + channel.Close() + + assert.Equal(t, 3, callCount, "cancel function should be called each time") + }) +} + +func TestDefaultStreamBufferSize(t *testing.T) { + t.Parallel() + + assert.Equal(t, 10, DefaultStreamBufferSize, "default buffer size should be 10") +} + +func TestProcessStreamingResponseRaw(t *testing.T) { + t.Parallel() + + t.Run("nil_response_returns_error", func(t *testing.T) { + t.Parallel() + + channel, err := processStreamingResponseRaw(context.Background(), nil, 10) + + assert.Error(t, err) + assert.Nil(t, channel) + assert.Contains(t, err.Error(), "response or response body is nil") + }) + + t.Run("nil_body_returns_error", func(t *testing.T) { + t.Parallel() + + resp := &http.Response{ + StatusCode: 200, + Body: nil, + } + + channel, err := processStreamingResponseRaw(context.Background(), resp, 10) + + assert.Error(t, err) + assert.Nil(t, channel) + assert.Contains(t, err.Error(), "response or response body is nil") + }) + + t.Run("uses_default_buffer_size_when_zero", func(t *testing.T) { + t.Parallel() + + body := io.NopCloser(strings.NewReader(`{"result":{"object":"doc:1"}}`)) + resp := &http.Response{ + StatusCode: 200, + Body: body, + } + + channel, err := processStreamingResponseRaw(context.Background(), resp, 0) + + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + // Check buffer size (indirectly through cap) + assert.Equal(t, DefaultStreamBufferSize, cap(channel.Results)) + }) + + t.Run("uses_default_buffer_size_when_negative", func(t *testing.T) { + t.Parallel() + + body := io.NopCloser(strings.NewReader(`{"result":{"object":"doc:1"}}`)) + resp := &http.Response{ + StatusCode: 200, + Body: body, + } + + channel, err := processStreamingResponseRaw(context.Background(), resp, -5) + + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + assert.Equal(t, DefaultStreamBufferSize, cap(channel.Results)) + }) + + t.Run("uses_custom_buffer_size", func(t *testing.T) { + t.Parallel() + + body := io.NopCloser(strings.NewReader(`{"result":{"object":"doc:1"}}`)) + resp := &http.Response{ + StatusCode: 200, + Body: body, + } + + channel, err := processStreamingResponseRaw(context.Background(), resp, 25) + + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + assert.Equal(t, 25, cap(channel.Results)) + }) + + t.Run("processes_single_result", func(t *testing.T) { + t.Parallel() + + streamData := `{"result":{"object":"document:1"}}` + "\n" + body := io.NopCloser(strings.NewReader(streamData)) + resp := &http.Response{ + StatusCode: 200, + Body: body, + } + + channel, err := processStreamingResponseRaw(context.Background(), resp, 10) + + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + // Collect results + var results [][]byte + for result := range channel.Results { + results = append(results, result) + } + + // Check for errors + select { + case err := <-channel.Errors: + assert.NoError(t, err) + default: + } + + assert.Len(t, results, 1) + assert.JSONEq(t, `{"object":"document:1"}`, string(results[0])) + }) + + t.Run("processes_multiple_results", func(t *testing.T) { + t.Parallel() + + streamData := `{"result":{"object":"document:1"}}` + "\n" + + `{"result":{"object":"document:2"}}` + "\n" + + `{"result":{"object":"document:3"}}` + "\n" + body := io.NopCloser(strings.NewReader(streamData)) + resp := &http.Response{ + StatusCode: 200, + Body: body, + } + + channel, err := processStreamingResponseRaw(context.Background(), resp, 10) + + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + var results [][]byte + for result := range channel.Results { + results = append(results, result) + } + + select { + case err := <-channel.Errors: + assert.NoError(t, err) + default: + } + + assert.Len(t, results, 3) + assert.JSONEq(t, `{"object":"document:1"}`, string(results[0])) + assert.JSONEq(t, `{"object":"document:2"}`, string(results[1])) + assert.JSONEq(t, `{"object":"document:3"}`, string(results[2])) + }) + + t.Run("handles_stream_error_response", func(t *testing.T) { + t.Parallel() + + streamData := `{"result":{"object":"document:1"}}` + "\n" + + `{"error":{"message":"Something went wrong"}}` + "\n" + body := io.NopCloser(strings.NewReader(streamData)) + resp := &http.Response{ + StatusCode: 200, + Body: body, + } + + channel, err := processStreamingResponseRaw(context.Background(), resp, 10) + + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + // First result should come through + result := <-channel.Results + assert.JSONEq(t, `{"object":"document:1"}`, string(result)) + + // Then we should get an error + streamErr := <-channel.Errors + assert.Error(t, streamErr) + assert.Contains(t, streamErr.Error(), "Something went wrong") + }) + + t.Run("handles_invalid_json", func(t *testing.T) { + t.Parallel() + + streamData := `{"result":{"object":"document:1"}}` + "\n" + + `invalid json` + "\n" + body := io.NopCloser(strings.NewReader(streamData)) + resp := &http.Response{ + StatusCode: 200, + Body: body, + } + + channel, err := processStreamingResponseRaw(context.Background(), resp, 10) + + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + // First result should come through + result := <-channel.Results + assert.JSONEq(t, `{"object":"document:1"}`, string(result)) + + // Then we should get a JSON parsing error + streamErr := <-channel.Errors + assert.Error(t, streamErr) + }) + + t.Run("skips_empty_lines", func(t *testing.T) { + t.Parallel() + + streamData := `{"result":{"object":"document:1"}}` + "\n" + + `` + "\n" + + `{"result":{"object":"document:2"}}` + "\n" + + `` + "\n" + body := io.NopCloser(strings.NewReader(streamData)) + resp := &http.Response{ + StatusCode: 200, + Body: body, + } + + channel, err := processStreamingResponseRaw(context.Background(), resp, 10) + + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + var results [][]byte + for result := range channel.Results { + results = append(results, result) + } + + assert.Len(t, results, 2) + }) + + t.Run("context_cancellation_stops_streaming", func(t *testing.T) { + t.Parallel() + + // Create a pipe where we control when data is written + pr, pw := io.Pipe() + + resp := &http.Response{ + StatusCode: 200, + Body: pr, + } + + ctx, cancel := context.WithCancel(context.Background()) + channel, err := processStreamingResponseRaw(ctx, resp, 10) + + require.NoError(t, err) + require.NotNil(t, channel) + + // Write one result first + go func() { + _, _ = pw.Write([]byte(`{"result":{"object":"doc:1"}}` + "\n")) + // Give time for the result to be processed, then close the pipe to unblock the scanner + time.Sleep(50 * time.Millisecond) + cancel() + pw.Close() + }() + + // Read the first result + result := <-channel.Results + assert.JSONEq(t, `{"object":"doc:1"}`, string(result)) + + // The channel should close after context is cancelled and pipe is closed + // Wait for channels to close + select { + case <-channel.Results: + // Either got another result or channel closed — both are fine + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for channel to close") + } + }) +} + +func TestAPIExecutor_ExecuteStreaming(t *testing.T) { + t.Parallel() + + t.Run("validates_request", func(t *testing.T) { + t.Parallel() + + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + return makeResp(200, "", nil), nil + }}, nil) + + executor := NewAPIExecutor(client) + + // Missing operation name + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + Method: "POST", + Path: "/test", + }, 10) + + assert.Error(t, err) + assert.Nil(t, channel) + assert.Contains(t, err.Error(), "operationName is required") + }) + + t.Run("validates_path_parameters", func(t *testing.T) { + t.Parallel() + + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + return makeResp(200, "", nil), nil + }}, nil) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/test", + // Missing path parameter + }, 10) + + assert.Error(t, err) + assert.Nil(t, channel) + assert.Contains(t, err.Error(), "not all path parameters were provided") + }) + + t.Run("sets_ndjson_accept_header", func(t *testing.T) { + t.Parallel() + + var capturedReq *http.Request + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"result":{"object":"doc:1"}}` + "\n")), + Header: http.Header{}, + }, nil + }}, nil) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/test", + PathParameters: map[string]string{"store_id": "123"}, + }, 10) + + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + // Drain the channel + for range channel.Results { + } + + assert.Equal(t, "application/x-ndjson", capturedReq.Header.Get("Accept")) + }) + + t.Run("handles_http_error", func(t *testing.T) { + t.Parallel() + + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 400, + Body: io.NopCloser(strings.NewReader(`{"code":"validation_error","message":"Invalid request"}`)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, nil + }}, nil) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/test", + PathParameters: map[string]string{"store_id": "123"}, + }, 10) + + assert.Error(t, err) + assert.Nil(t, channel) + }) + + t.Run("successful_streaming", func(t *testing.T) { + t.Parallel() + + streamData := `{"result":{"object":"doc:1"}}` + "\n" + + `{"result":{"object":"doc:2"}}` + "\n" + + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(streamData)), + Header: http.Header{}, + }, nil + }}, nil) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/streamed-list-objects", + PathParameters: map[string]string{"store_id": "123"}, + Body: map[string]string{"type": "document"}, + }, 10) + + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + var results [][]byte + for result := range channel.Results { + results = append(results, result) + } + + assert.Len(t, results, 2) + assert.JSONEq(t, `{"object":"doc:1"}`, string(results[0])) + assert.JSONEq(t, `{"object":"doc:2"}`, string(results[1])) + }) +} + +// ============================================================================ +// Additional Streaming Tests +// ============================================================================ + +func TestAPIExecutor_ExecuteStreaming_TransportError(t *testing.T) { + t.Parallel() + + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + return nil, errors.New("connection refused") + }}, nil) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/test", + PathParameters: map[string]string{"store_id": "123"}, + }, 10) + + assert.Error(t, err) + assert.Nil(t, channel) + assert.Contains(t, err.Error(), "connection refused") +} + +func TestAPIExecutor_ExecuteStreaming_NilHTTPResponse(t *testing.T) { + t.Parallel() + + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + return nil, nil + }}, nil) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/test", + PathParameters: map[string]string{"store_id": "123"}, + }, 10) + + assert.Nil(t, channel) + if err != nil { + assert.Contains(t, err.Error(), "nil") + } +} + +func TestAPIExecutor_ExecuteStreaming_PreservesCustomAcceptHeader(t *testing.T) { + t.Parallel() + + var capturedReq *http.Request + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"result":{"object":"doc:1"}}` + "\n")), + Header: http.Header{}, + }, nil + }}, nil) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/test", + PathParameters: map[string]string{"store_id": "123"}, + Headers: map[string]string{"Accept": "application/custom+json"}, + }, 10) + + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + for range channel.Results { + } + + assert.Equal(t, "application/custom+json", capturedReq.Header.Get("Accept")) +} + +func TestAPIExecutor_ExecuteStreaming_DefaultsToNdjsonAccept(t *testing.T) { + t.Parallel() + + var capturedReq *http.Request + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + capturedReq = req + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"result":{"object":"doc:1"}}` + "\n")), + Header: http.Header{}, + }, nil + }}, nil) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/test", + PathParameters: map[string]string{"store_id": "123"}, + }, 10) + + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + for range channel.Results { + } + + assert.Equal(t, "application/x-ndjson", capturedReq.Header.Get("Accept")) +} + +func TestAPIExecutor_ExecuteStreaming_ServerError(t *testing.T) { + t.Parallel() + + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 500, + Body: io.NopCloser(strings.NewReader(`{"code":"internal_error","message":"Internal Server Error"}`)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, nil + }}, nil) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/test", + PathParameters: map[string]string{"store_id": "123"}, + }, 10) + + assert.Error(t, err) + assert.Nil(t, channel) +} + +// ============================================================================ +// Streaming Retry Tests +// ============================================================================ + +func TestAPIExecutor_ExecuteStreaming_RetriesOnTransportError(t *testing.T) { + t.Parallel() + + attemptCount := 0 + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + attemptCount++ + if attemptCount <= 2 { + return nil, errors.New("connection refused") + } + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"result":{"object":"doc:1"}}` + "\n")), + Header: http.Header{}, + }, nil + }}, &RetryParams{MaxRetry: 3, MinWaitInMs: 1}) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/streamed-list-objects", + PathParameters: map[string]string{"store_id": "123"}, + Body: map[string]string{"type": "document"}, + }, 10) + + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + var results [][]byte + for result := range channel.Results { + results = append(results, result) + } + + assert.Equal(t, 3, attemptCount, "expected 3 attempts (2 failures + 1 success)") + assert.Len(t, results, 1) + assert.JSONEq(t, `{"object":"doc:1"}`, string(results[0])) +} + +func TestAPIExecutor_ExecuteStreaming_RetriesOnRateLimitError(t *testing.T) { + t.Parallel() + + attemptCount := 0 + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + attemptCount++ + if attemptCount <= 1 { + return &http.Response{ + StatusCode: 429, + Body: io.NopCloser(strings.NewReader(`{"code":"rate_limit_exceeded","message":"Rate limit exceeded"}`)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, nil + } + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"result":{"object":"doc:1"}}` + "\n")), + Header: http.Header{}, + }, nil + }}, &RetryParams{MaxRetry: 3, MinWaitInMs: 1}) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/streamed-list-objects", + PathParameters: map[string]string{"store_id": "123"}, + Body: map[string]string{"type": "document"}, + }, 10) + + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + var results [][]byte + for result := range channel.Results { + results = append(results, result) + } + + assert.Equal(t, 2, attemptCount, "expected 2 attempts (1 rate limit + 1 success)") + assert.Len(t, results, 1) +} + +func TestAPIExecutor_ExecuteStreaming_RetriesOnInternalServerError(t *testing.T) { + t.Parallel() + + attemptCount := 0 + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + attemptCount++ + if attemptCount <= 2 { + return &http.Response{ + StatusCode: 500, + Body: io.NopCloser(strings.NewReader(`{"code":"internal_error","message":"Internal Server Error"}`)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, nil + } + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"result":{"object":"doc:1"}}` + "\n")), + Header: http.Header{}, + }, nil + }}, &RetryParams{MaxRetry: 3, MinWaitInMs: 1}) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/streamed-list-objects", + PathParameters: map[string]string{"store_id": "123"}, + Body: map[string]string{"type": "document"}, + }, 10) + + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + var results [][]byte + for result := range channel.Results { + results = append(results, result) + } + + assert.Equal(t, 3, attemptCount, "expected 3 attempts (2 server errors + 1 success)") + assert.Len(t, results, 1) +} + +func TestAPIExecutor_ExecuteStreaming_ExhaustsRetriesOnPersistentTransportError(t *testing.T) { + t.Parallel() + + attemptCount := 0 + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + attemptCount++ + return nil, errors.New("connection refused") + }}, &RetryParams{MaxRetry: 2, MinWaitInMs: 1}) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/streamed-list-objects", + PathParameters: map[string]string{"store_id": "123"}, + Body: map[string]string{"type": "document"}, + }, 10) + + assert.Error(t, err) + assert.Nil(t, channel) + assert.Contains(t, err.Error(), "connection refused") + assert.Equal(t, 3, attemptCount, "expected 3 attempts (1 initial + 2 retries)") +} + +func TestAPIExecutor_ExecuteStreaming_ExhaustsRetriesOnPersistentServerError(t *testing.T) { + t.Parallel() + + attemptCount := 0 + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + attemptCount++ + return &http.Response{ + StatusCode: 500, + Body: io.NopCloser(strings.NewReader(`{"code":"internal_error","message":"Internal Server Error"}`)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, nil + }}, &RetryParams{MaxRetry: 2, MinWaitInMs: 1}) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/streamed-list-objects", + PathParameters: map[string]string{"store_id": "123"}, + Body: map[string]string{"type": "document"}, + }, 10) + + assert.Error(t, err) + assert.Nil(t, channel) + assert.Equal(t, 3, attemptCount, "expected 3 attempts (1 initial + 2 retries)") +} + +func TestAPIExecutor_ExecuteStreaming_RetriesOnNonSpecificHTTPErrors(t *testing.T) { + // 400 errors fall through to determineRetry's default case, which retries them + // just like the non-streaming executeInternal path. This test verifies that + // streaming matches the same behavior as non-streaming endpoints. + t.Parallel() + + attemptCount := 0 + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + attemptCount++ + if attemptCount <= 2 { + return &http.Response{ + StatusCode: 400, + Body: io.NopCloser(strings.NewReader(`{"code":"validation_error","message":"Invalid request"}`)), + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, nil + } + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"result":{"object":"doc:1"}}` + "\n")), + Header: http.Header{}, + }, nil + }}, &RetryParams{MaxRetry: 3, MinWaitInMs: 1}) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/streamed-list-objects", + PathParameters: map[string]string{"store_id": "123"}, + Body: map[string]string{"type": "document"}, + }, 10) + + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + var results [][]byte + for result := range channel.Results { + results = append(results, result) + } + + assert.Equal(t, 3, attemptCount, "400 errors are retried via default case, matching non-streaming behavior") + assert.Len(t, results, 1) +} + +func TestAPIExecutor_ExecuteStreaming_DoesNotRetryContextCancellation(t *testing.T) { + t.Parallel() + + attemptCount := 0 + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + attemptCount++ + return nil, context.Canceled + }}, &RetryParams{MaxRetry: 3, MinWaitInMs: 1}) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/streamed-list-objects", + PathParameters: map[string]string{"store_id": "123"}, + Body: map[string]string{"type": "document"}, + }, 10) + + assert.Error(t, err) + assert.Nil(t, channel) + assert.Equal(t, 1, attemptCount, "should not retry context cancellation") +} + +func TestAPIExecutor_ExecuteStreaming_DoesNotRetryContextDeadlineExceeded(t *testing.T) { + t.Parallel() + + attemptCount := 0 + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + attemptCount++ + return nil, context.DeadlineExceeded + }}, &RetryParams{MaxRetry: 3, MinWaitInMs: 1}) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/streamed-list-objects", + PathParameters: map[string]string{"store_id": "123"}, + Body: map[string]string{"type": "document"}, + }, 10) + + assert.Error(t, err) + assert.Nil(t, channel) + assert.Equal(t, 1, attemptCount, "should not retry context deadline exceeded") +} + +func TestAPIExecutor_ExecuteStreaming_NoRetryWhenMaxRetryIsZero(t *testing.T) { + t.Parallel() + + attemptCount := 0 + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + attemptCount++ + return nil, errors.New("connection refused") + }}, &RetryParams{MaxRetry: 0, MinWaitInMs: 1}) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/streamed-list-objects", + PathParameters: map[string]string{"store_id": "123"}, + Body: map[string]string{"type": "document"}, + }, 10) + + assert.Error(t, err) + assert.Nil(t, channel) + assert.Contains(t, err.Error(), "connection refused") + assert.Equal(t, 1, attemptCount, "should only attempt once with MaxRetry=0") +} + +func TestAPIExecutor_ExecuteStreaming_RetriesOnNilHTTPResponseThenSucceeds(t *testing.T) { + t.Parallel() + + attemptCount := 0 + client := newTestClient(t, &testRoundTripper{fn: func(req *http.Request) (*http.Response, error) { + attemptCount++ + if attemptCount <= 1 { + return nil, nil + } + return &http.Response{ + StatusCode: 200, + Body: io.NopCloser(strings.NewReader(`{"result":{"object":"doc:1"}}` + "\n")), + Header: http.Header{}, + }, nil + }}, &RetryParams{MaxRetry: 2, MinWaitInMs: 1}) + + executor := NewAPIExecutor(client) + + channel, err := executor.ExecuteStreaming(context.Background(), APIExecutorRequest{ + OperationName: "StreamedListObjects", + Method: "POST", + Path: "/stores/{store_id}/streamed-list-objects", + PathParameters: map[string]string{"store_id": "123"}, + Body: map[string]string{"type": "document"}, + }, 10) + + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + var results [][]byte + for result := range channel.Results { + results = append(results, result) + } + + assert.Equal(t, 2, attemptCount, "expected 2 attempts (1 nil response + 1 success)") + assert.Len(t, results, 1) +} + +func TestProcessStreamingResponseRaw_StreamErrorWithNilMessage(t *testing.T) { + t.Parallel() + + streamData := `{"error":{"code":500}}` + "\n" + body := io.NopCloser(strings.NewReader(streamData)) + resp := &http.Response{StatusCode: 200, Body: body} + + channel, err := processStreamingResponseRaw(context.Background(), resp, 10) + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + streamErr := <-channel.Errors + assert.Error(t, streamErr) + assert.Equal(t, "stream error", streamErr.Error()) +} + +func TestProcessStreamingResponseRaw_OnlyNullResults(t *testing.T) { + t.Parallel() + + // A null result is valid JSON and is passed through by the raw processor. + // Both results come through: "null" (raw JSON bytes) and the document object. + streamData := `{"result":null}` + "\n" + `{"result":{"object":"document:1"}}` + "\n" + body := io.NopCloser(strings.NewReader(streamData)) + resp := &http.Response{StatusCode: 200, Body: body} + + channel, err := processStreamingResponseRaw(context.Background(), resp, 10) + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + var results [][]byte + for result := range channel.Results { + results = append(results, result) + } + + assert.Len(t, results, 2) + assert.Equal(t, "null", string(results[0])) + assert.JSONEq(t, `{"object":"document:1"}`, string(results[1])) +} + +func TestProcessStreamingResponseRaw_EmptyStream(t *testing.T) { + t.Parallel() + + body := io.NopCloser(strings.NewReader("")) + resp := &http.Response{StatusCode: 200, Body: body} + + channel, err := processStreamingResponseRaw(context.Background(), resp, 10) + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + var results [][]byte + for result := range channel.Results { + results = append(results, result) + } + + select { + case err := <-channel.Errors: + assert.NoError(t, err) + default: + } + + assert.Empty(t, results) +} + +func TestProcessStreamingResponse_Generic_StreamError(t *testing.T) { + t.Parallel() + + streamData := `{"result":{"object":"document:1"}}` + "\n" + + `{"error":{"message":"Server exploded"}}` + "\n" + body := io.NopCloser(strings.NewReader(streamData)) + resp := &http.Response{StatusCode: 200, Body: body} + + channel, err := ProcessStreamingResponse[StreamedListObjectsResponse](context.Background(), resp, 10) + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + result := <-channel.Results + assert.Equal(t, "document:1", result.Object) + + streamErr := <-channel.Errors + assert.Error(t, streamErr) + assert.Contains(t, streamErr.Error(), "Server exploded") +} + +func TestProcessStreamingResponse_Generic_InvalidInnerJSON(t *testing.T) { + t.Parallel() + + streamData := `{"result":{"not_a_valid_field": 123}}` + "\n" + body := io.NopCloser(strings.NewReader(streamData)) + resp := &http.Response{StatusCode: 200, Body: body} + + channel, err := ProcessStreamingResponse[StreamedListObjectsResponse](context.Background(), resp, 10) + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + result := <-channel.Results + assert.Equal(t, "", result.Object) +} + +func TestProcessStreamingResponse_Generic_MultipleResults(t *testing.T) { + t.Parallel() + + streamData := `{"result":{"object":"document:1"}}` + "\n" + + `{"result":{"object":"document:2"}}` + "\n" + + `{"result":{"object":"document:3"}}` + "\n" + body := io.NopCloser(strings.NewReader(streamData)) + resp := &http.Response{StatusCode: 200, Body: body} + + channel, err := ProcessStreamingResponse[StreamedListObjectsResponse](context.Background(), resp, 10) + require.NoError(t, err) + require.NotNil(t, channel) + defer channel.Close() + + var results []string + for result := range channel.Results { + results = append(results, result.Object) + } + + assert.Equal(t, []string{"document:1", "document:2", "document:3"}, results) +} + +func TestStreamedListObjectsChannel_CloseWithNilCancel(t *testing.T) { + t.Parallel() + + channel := &StreamedListObjectsChannel{ + Objects: make(chan StreamedListObjectsResponse), + Errors: make(chan error), + cancel: nil, + } + + assert.NotPanics(t, func() { + channel.Close() + }) +} + +func TestStreamedListObjectsChannel_CloseMultipleTimes(t *testing.T) { + t.Parallel() + + callCount := 0 + channel := &StreamedListObjectsChannel{ + Objects: make(chan StreamedListObjectsResponse), + Errors: make(chan error), + cancel: func() { + callCount++ + }, + } + + channel.Close() + channel.Close() + channel.Close() + + assert.Equal(t, 3, callCount) +} + +func TestAPIExecutorRequestBuilder_WithHeaders_NilIsNoOp(t *testing.T) { + t.Parallel() + + builder := NewAPIExecutorRequestBuilder("Test", "GET", "/test") + builder.WithHeader("X-Existing", "value") + builder.WithHeaders(nil) + + req := builder.Build() + + assert.NotNil(t, req.Headers) + assert.Equal(t, "value", req.Headers["X-Existing"]) +} + +func TestAPIExecutorRequestBuilder_WithPathParameters_NilIsNoOp(t *testing.T) { + t.Parallel() + + builder := NewAPIExecutorRequestBuilder("Test", "GET", "/stores/{store_id}") + builder.WithPathParameter("store_id", "123") + builder.WithPathParameters(nil) + + req := builder.Build() + + assert.NotNil(t, req.PathParameters) + assert.Equal(t, "123", req.PathParameters["store_id"]) +} + +func TestAPIExecutorRequestBuilder_WithQueryParameters_NilIsNoOp(t *testing.T) { + t.Parallel() + + builder := NewAPIExecutorRequestBuilder("Test", "GET", "/test") + builder.WithQueryParameter("page", "1") + builder.WithQueryParameters(nil) + + req := builder.Build() + + assert.NotNil(t, req.QueryParameters) + assert.Equal(t, "1", req.QueryParameters.Get("page")) +} + +func TestConvertToStreamedListObjectsChannel_InvalidJSON(t *testing.T) { + t.Parallel() + + rawChannel := &APIExecutorStreamingChannel{ + Results: make(chan []byte, 1), + Errors: make(chan error, 1), + cancel: func() {}, + } + + rawChannel.Results <- []byte(`{invalid json}`) + close(rawChannel.Results) + close(rawChannel.Errors) + + typedChannel := convertToStreamedListObjectsChannel(context.Background(), rawChannel) + defer typedChannel.Close() + + select { + case _, ok := <-typedChannel.Objects: + if ok { + t.Fatal("Expected no objects for invalid JSON") + } + case err := <-typedChannel.Errors: + assert.Error(t, err) + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for error") + } +} + +func TestConvertToStreamedListObjectsChannel_ForwardsRawErrors(t *testing.T) { + t.Parallel() + + rawChannel := &APIExecutorStreamingChannel{ + Results: make(chan []byte, 1), + Errors: make(chan error, 1), + cancel: func() {}, + } + + close(rawChannel.Results) + rawChannel.Errors <- errors.New("upstream error") + close(rawChannel.Errors) + + typedChannel := convertToStreamedListObjectsChannel(context.Background(), rawChannel) + defer typedChannel.Close() + + select { + case err := <-typedChannel.Errors: + assert.Error(t, err) + assert.Contains(t, err.Error(), "upstream error") + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for error") + } +} + +func TestConvertToStreamedListObjectsChannel_Success(t *testing.T) { + t.Parallel() + + rawChannel := &APIExecutorStreamingChannel{ + Results: make(chan []byte, 3), + Errors: make(chan error, 1), + cancel: func() {}, + } + + rawChannel.Results <- []byte(`{"object":"document:1"}`) + rawChannel.Results <- []byte(`{"object":"document:2"}`) + rawChannel.Results <- []byte(`{"object":"document:3"}`) + close(rawChannel.Results) + close(rawChannel.Errors) + + typedChannel := convertToStreamedListObjectsChannel(context.Background(), rawChannel) + defer typedChannel.Close() + + var objects []string + for obj := range typedChannel.Objects { + objects = append(objects, obj.Object) + } + + assert.Equal(t, []string{"document:1", "document:2", "document:3"}, objects) + + select { + case err := <-typedChannel.Errors: + assert.NoError(t, err) + default: + } +} + +func TestConvertToStreamedListObjectsChannel_ContextCancellation(t *testing.T) { + t.Parallel() + + rawChannel := &APIExecutorStreamingChannel{ + Results: make(chan []byte), + Errors: make(chan error, 1), + cancel: func() {}, + } + + ctx, cancel := context.WithCancel(context.Background()) + + typedChannel := convertToStreamedListObjectsChannel(ctx, rawChannel) + defer typedChannel.Close() + + cancel() + + select { + case err := <-typedChannel.Errors: + assert.Error(t, err) + assert.Contains(t, err.Error(), "context canceled") + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for cancellation error") + } +} diff --git a/api_open_fga.go b/api_open_fga.go index 3821a75..d7c2fe8 100644 --- a/api_open_fga.go +++ b/api_open_fga.go @@ -14,6 +14,7 @@ package openfga import ( "context" + "encoding/json" "net/http" "net/url" "time" @@ -28,6 +29,11 @@ type RequestOptions struct { Headers map[string]string `json:"headers,omitempty"` } +type StreamingRequestOptions struct { + RequestOptions + BufferSize int `json:"buffer_size,omitempty"` +} + type OpenFgaApi interface { /* @@ -777,9 +783,9 @@ type OpenFgaApi interface { /* * StreamedListObjectsExecute executes the request - * @return StreamResultOfStreamedListObjectsResponse + * @return *StreamedListObjectsChannel */ - StreamedListObjectsExecute(r ApiStreamedListObjectsRequest) (StreamResultOfStreamedListObjectsResponse, *http.Response, error) + StreamedListObjectsExecute(r ApiStreamedListObjectsRequest) (*StreamedListObjectsChannel, error) /* * Write Add or delete tuples from the store @@ -2448,7 +2454,7 @@ type ApiStreamedListObjectsRequest struct { ApiService OpenFgaApi storeId string body *ListObjectsRequest - options RequestOptions + options StreamingRequestOptions } func (r ApiStreamedListObjectsRequest) Body(body ListObjectsRequest) ApiStreamedListObjectsRequest { @@ -2456,15 +2462,70 @@ func (r ApiStreamedListObjectsRequest) Body(body ListObjectsRequest) ApiStreamed return r } -func (r ApiStreamedListObjectsRequest) Options(options RequestOptions) ApiStreamedListObjectsRequest { +func (r ApiStreamedListObjectsRequest) Options(options StreamingRequestOptions) ApiStreamedListObjectsRequest { r.options = options return r } -func (r ApiStreamedListObjectsRequest) Execute() (StreamResultOfStreamedListObjectsResponse, *http.Response, error) { +// Execute executes the StreamedListObjects request and returns a streaming channel. +// The returned StreamedListObjectsChannel provides Objects and Errors channels for consuming +// the streamed results. The caller must call Close() on the channel when done. +// +// Example usage: +// +// channel, err := client.OpenFgaApi.StreamedListObjects(ctx, storeId). +// Body(body). +// Execute() +// if err != nil { +// return err +// } +// defer channel.Close() +// +// for { +// select { +// case obj, ok := <-channel.Objects: +// if !ok { +// // Objects channel closed; drain any terminal error before returning. +// select { +// case err := <-channel.Errors: +// return err // nil on clean EOF, non-nil on stream error +// default: +// return nil +// } +// } +// fmt.Printf("Object: %s\n", obj.Object) +// case err := <-channel.Errors: +// if err != nil { +// return err +// } +// } +// } +func (r ApiStreamedListObjectsRequest) Execute() (*StreamedListObjectsChannel, error) { return r.ApiService.StreamedListObjectsExecute(r) } +// StreamedListObjectsChannel provides channels for consuming StreamedListObjects results. +// It maintains backward compatibility with the streaming response structure. +type StreamedListObjectsChannel struct { + // Objects channel receives StreamedListObjectsResponse for each streamed object. + // The channel is closed when the stream ends or an error occurs. + Objects chan StreamedListObjectsResponse + // Errors channel receives any errors that occur during streaming. + // Only one error will be sent before the channel is closed. + Errors chan error + // cancel is the function to cancel the streaming context + cancel context.CancelFunc +} + +// Close cancels the streaming context and cleans up resources. +// It is safe to call Close multiple times. +// Always defer Close() after successfully creating a streaming channel. +func (s *StreamedListObjectsChannel) Close() { + if s.cancel != nil { + s.cancel() + } +} + /* - StreamedListObjects Stream all objects of the given type that the user has a relation with - The Streamed ListObjects API is very similar to the the ListObjects API, with two differences: @@ -2486,17 +2547,17 @@ func (a *OpenFgaApiService) StreamedListObjects(ctx context.Context, storeId str /* * Execute executes the request - * @return StreamResultOfStreamedListObjectsResponse + * @return *StreamedListObjectsChannel */ -func (a *OpenFgaApiService) StreamedListObjectsExecute(r ApiStreamedListObjectsRequest) (StreamResultOfStreamedListObjectsResponse, *http.Response, error) { - var returnValue StreamResultOfStreamedListObjectsResponse +func (a *OpenFgaApiService) StreamedListObjectsExecute(r ApiStreamedListObjectsRequest) (*StreamedListObjectsChannel, error) { if err := validatePathParameter("storeId", r.storeId); err != nil { - return returnValue, nil, err + return nil, err } if err := validateParameter("body", r.body); err != nil { - return returnValue, nil, err + return nil, err } + // Use the APIExecutor to execute the streaming request executor := a.client.GetAPIExecutor() request := NewAPIExecutorRequestBuilder("StreamedListObjects", http.MethodPost, "/stores/{store_id}/streamed-list-objects"). @@ -2504,13 +2565,76 @@ func (a *OpenFgaApiService) StreamedListObjectsExecute(r ApiStreamedListObjectsR WithBody(r.body). WithHeaders(r.options.Headers). Build() - response, err := executor.ExecuteWithDecode(r.ctx, request, &returnValue) - if response != nil { - return returnValue, response.HTTPResponse, err + rawChannel, err := executor.ExecuteStreaming(r.ctx, request, r.options.BufferSize) + if err != nil { + return nil, err } - return returnValue, nil, err + // Convert raw JSON bytes to typed StreamedListObjectsResponse + return convertToStreamedListObjectsChannel(r.ctx, rawChannel), nil +} + +// convertToStreamedListObjectsChannel converts an APIExecutorStreamingChannel (raw bytes) to +// a typed StreamedListObjectsChannel. +func convertToStreamedListObjectsChannel(ctx context.Context, rawChannel *APIExecutorStreamingChannel) *StreamedListObjectsChannel { + streamCtx, cancel := context.WithCancel(ctx) + + typedChannel := &StreamedListObjectsChannel{ + Objects: make(chan StreamedListObjectsResponse, cap(rawChannel.Results)), + Errors: make(chan error, 1), + cancel: cancel, + } + + go func() { + defer close(typedChannel.Objects) + defer close(typedChannel.Errors) + defer cancel() + defer rawChannel.Close() + + for { + select { + case <-streamCtx.Done(): + typedChannel.Errors <- streamCtx.Err() + return + case rawResult, ok := <-rawChannel.Results: + if !ok { + // Raw channel closed, check for errors + select { + case err := <-rawChannel.Errors: + if err != nil { + typedChannel.Errors <- err + } + default: + } + return + } + + var response StreamedListObjectsResponse + if err := json.Unmarshal(rawResult, &response); err != nil { + typedChannel.Errors <- err + return + } + + select { + case <-streamCtx.Done(): + typedChannel.Errors <- streamCtx.Err() + return + case typedChannel.Objects <- response: + } + case err, ok := <-rawChannel.Errors: + if !ok { + // Errors channel closed; nil it out to prevent spinning on a closed channel. + rawChannel.Errors = nil + } else if err != nil { + typedChannel.Errors <- err + return + } + } + } + }() + + return typedChannel } type ApiWriteRequest struct { diff --git a/api_open_fga_test.go b/api_open_fga_test.go index 8a67e2c..2fff670 100644 --- a/api_open_fga_test.go +++ b/api_open_fga_test.go @@ -1939,3 +1939,634 @@ func TestOpenFgaApi(t *testing.T) { } }) } + +func TestStreamedListObjectsExecute(t *testing.T) { + t.Parallel() + + configuration, err := NewConfiguration(Configuration{ + ApiUrl: constants.TestApiUrl, + }) + if err != nil { + t.Fatalf("failed to create configuration: %v", err) + } + apiClient := NewAPIClient(configuration) + storeID := "01GXSB9YR785C4FYS3C0RTG7B2" + + t.Run("successful streaming with multiple objects", func(t *testing.T) { + // Streaming response with multiple streamed objects + responseBody := `{"result":{"object":"document:doc1"}}` + "\n" + + `{"result":{"object":"document:doc2"}}` + "\n" + + `{"result":{"object":"document:doc3"}}` + "\n" + + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpmock.RegisterResponder("POST", fmt.Sprintf("%s/stores/%s/streamed-list-objects", configuration.ApiUrl, storeID), + func(req *http.Request) (*http.Response, error) { + resp := httpmock.NewStringResponse(200, responseBody) + resp.Header.Set("Content-Type", "application/x-ndjson") + return resp, nil + }, + ) + + requestBody := ListObjectsRequest{ + AuthorizationModelId: PtrString("01GAHCE4YVKPQEKZQHT2R89MQV"), + User: "user:anne", + Relation: "viewer", + Type: "document", + } + + channel, err := apiClient.OpenFgaApi.StreamedListObjects(context.Background(), storeID). + Body(requestBody). + Execute() + if err != nil { + t.Fatalf("StreamedListObjects failed: %v", err) + } + defer channel.Close() + + var objects []string + for obj := range channel.Objects { + objects = append(objects, obj.Object) + } + + // Check for errors + select { + case err := <-channel.Errors: + if err != nil { + t.Fatalf("unexpected stream error: %v", err) + } + default: + } + + if len(objects) != 3 { + t.Fatalf("expected 3 objects, got %d", len(objects)) + } + if objects[0] != "document:doc1" || objects[1] != "document:doc2" || objects[2] != "document:doc3" { + t.Fatalf("unexpected objects: %v", objects) + } + }) + + t.Run("streaming with single object", func(t *testing.T) { + responseBody := `{"result":{"object":"document:single"}}` + "\n" + + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpmock.RegisterResponder("POST", fmt.Sprintf("%s/stores/%s/streamed-list-objects", configuration.ApiUrl, storeID), + func(req *http.Request) (*http.Response, error) { + resp := httpmock.NewStringResponse(200, responseBody) + resp.Header.Set("Content-Type", "application/x-ndjson") + return resp, nil + }, + ) + + requestBody := ListObjectsRequest{ + User: "user:bob", + Relation: "editor", + Type: "document", + } + + channel, err := apiClient.OpenFgaApi.StreamedListObjects(context.Background(), storeID). + Body(requestBody). + Execute() + if err != nil { + t.Fatalf("StreamedListObjects failed: %v", err) + } + defer channel.Close() + + var objects []string + for obj := range channel.Objects { + objects = append(objects, obj.Object) + } + + if len(objects) != 1 || objects[0] != "document:single" { + t.Fatalf("expected [document:single], got %v", objects) + } + }) + + t.Run("streaming with empty result", func(t *testing.T) { + // Empty streaming response (no objects match) + responseBody := "" + + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpmock.RegisterResponder("POST", fmt.Sprintf("%s/stores/%s/streamed-list-objects", configuration.ApiUrl, storeID), + func(req *http.Request) (*http.Response, error) { + resp := httpmock.NewStringResponse(200, responseBody) + resp.Header.Set("Content-Type", "application/x-ndjson") + return resp, nil + }, + ) + + requestBody := ListObjectsRequest{ + User: "user:nobody", + Relation: "viewer", + Type: "document", + } + + channel, err := apiClient.OpenFgaApi.StreamedListObjects(context.Background(), storeID). + Body(requestBody). + Execute() + if err != nil { + t.Fatalf("StreamedListObjects failed: %v", err) + } + defer channel.Close() + + var objects []string + for obj := range channel.Objects { + objects = append(objects, obj.Object) + } + + if len(objects) != 0 { + t.Fatalf("expected 0 objects, got %d", len(objects)) + } + }) + + t.Run("streaming with stream error", func(t *testing.T) { + // Streaming response with an error in the stream + responseBody := `{"result":{"object":"document:doc1"}}` + "\n" + + `{"error":{"message":"Internal stream error occurred"}}` + "\n" + + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpmock.RegisterResponder("POST", fmt.Sprintf("%s/stores/%s/streamed-list-objects", configuration.ApiUrl, storeID), + func(req *http.Request) (*http.Response, error) { + resp := httpmock.NewStringResponse(200, responseBody) + resp.Header.Set("Content-Type", "application/x-ndjson") + return resp, nil + }, + ) + + requestBody := ListObjectsRequest{ + User: "user:anne", + Relation: "viewer", + Type: "document", + } + + channel, err := apiClient.OpenFgaApi.StreamedListObjects(context.Background(), storeID). + Body(requestBody). + Execute() + if err != nil { + t.Fatalf("StreamedListObjects failed: %v", err) + } + defer channel.Close() + + // First object should come through + obj := <-channel.Objects + if obj.Object != "document:doc1" { + t.Fatalf("expected first object document:doc1, got %s", obj.Object) + } + + // Then we should get an error + streamErr := <-channel.Errors + if streamErr == nil { + t.Fatal("expected stream error, got nil") + } + if streamErr.Error() != "Internal stream error occurred" { + t.Fatalf("expected 'Internal stream error occurred', got %v", streamErr) + } + }) + + t.Run("HTTP error response", func(t *testing.T) { + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpmock.RegisterResponder("POST", fmt.Sprintf("%s/stores/%s/streamed-list-objects", configuration.ApiUrl, storeID), + func(req *http.Request) (*http.Response, error) { + return httpmock.NewJsonResponse(400, map[string]interface{}{ + "code": "validation_error", + "message": "Invalid request body", + }) + }, + ) + + requestBody := ListObjectsRequest{ + User: "invalid", + Relation: "viewer", + Type: "document", + } + + channel, err := apiClient.OpenFgaApi.StreamedListObjects(context.Background(), storeID). + Body(requestBody). + Execute() + + if err == nil { + if channel != nil { + channel.Close() + } + t.Fatal("expected error for 400 response, got nil") + } + if channel != nil { + t.Fatal("expected nil channel on error") + } + }) + + t.Run("missing store ID validation", func(t *testing.T) { + requestBody := ListObjectsRequest{ + User: "user:anne", + Relation: "viewer", + Type: "document", + } + + channel, err := apiClient.OpenFgaApi.StreamedListObjects(context.Background(), ""). + Body(requestBody). + Execute() + + if err == nil { + if channel != nil { + channel.Close() + } + t.Fatal("expected error for missing store ID, got nil") + } + }) + + t.Run("missing body validation", func(t *testing.T) { + channel, err := apiClient.OpenFgaApi.StreamedListObjects(context.Background(), storeID). + Execute() + + if err == nil { + if channel != nil { + channel.Close() + } + t.Fatal("expected error for missing body, got nil") + } + }) + + t.Run("context cancellation", func(t *testing.T) { + // Create a slow response that will be cancelled + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpmock.RegisterResponder("POST", fmt.Sprintf("%s/stores/%s/streamed-list-objects", configuration.ApiUrl, storeID), + func(req *http.Request) (*http.Response, error) { + // Return first result, then the response will be closed when context is cancelled + responseBody := `{"result":{"object":"document:doc1"}}` + "\n" + resp := httpmock.NewStringResponse(200, responseBody) + resp.Header.Set("Content-Type", "application/x-ndjson") + return resp, nil + }, + ) + + ctx, cancel := context.WithCancel(context.Background()) + + requestBody := ListObjectsRequest{ + User: "user:anne", + Relation: "viewer", + Type: "document", + } + + channel, err := apiClient.OpenFgaApi.StreamedListObjects(ctx, storeID). + Body(requestBody). + Execute() + if err != nil { + t.Fatalf("StreamedListObjects failed: %v", err) + } + + // Read first object + obj := <-channel.Objects + if obj.Object != "document:doc1" { + t.Fatalf("expected document:doc1, got %s", obj.Object) + } + + // Cancel the context + cancel() + + // Channel should close + channel.Close() + }) + + t.Run("with custom headers", func(t *testing.T) { + var capturedHeaders http.Header + + httpmock.Activate() + defer httpmock.DeactivateAndReset() + + httpmock.RegisterResponder("POST", fmt.Sprintf("%s/stores/%s/streamed-list-objects", configuration.ApiUrl, storeID), + func(req *http.Request) (*http.Response, error) { + capturedHeaders = req.Header.Clone() + responseBody := `{"result":{"object":"document:doc1"}}` + "\n" + resp := httpmock.NewStringResponse(200, responseBody) + resp.Header.Set("Content-Type", "application/x-ndjson") + return resp, nil + }, + ) + + requestBody := ListObjectsRequest{ + User: "user:anne", + Relation: "viewer", + Type: "document", + } + + channel, err := apiClient.OpenFgaApi.StreamedListObjects(context.Background(), storeID). + Body(requestBody). + Options(StreamingRequestOptions{ + RequestOptions: RequestOptions{ + Headers: map[string]string{ + "X-Custom-Header": "custom-value", + "X-Request-ID": "req-123", + }, + }, + }). + Execute() + if err != nil { + t.Fatalf("StreamedListObjects failed: %v", err) + } + defer channel.Close() + + // Drain the channel + for range channel.Objects { + } + + if capturedHeaders.Get("X-Custom-Header") != "custom-value" { + t.Fatalf("expected X-Custom-Header to be 'custom-value', got '%s'", capturedHeaders.Get("X-Custom-Header")) + } + if capturedHeaders.Get("X-Request-ID") != "req-123" { + t.Fatalf("expected X-Request-ID to be 'req-123', got '%s'", capturedHeaders.Get("X-Request-ID")) + } + }) + + t.Run("retry on 500 error", func(t *testing.T) { + var attempts int32 + + // First two attempts return 500, third succeeds with streaming response. + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cur := atomic.AddInt32(&attempts, 1) + if cur < 3 { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"code":"internal_error","message":"transient"}`)) + return + } + w.Header().Set("Content-Type", "application/x-ndjson") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"result":{"object":"document:retried"}}` + "\n")) + })) + defer server.Close() + + cfg, err := NewConfiguration(Configuration{ + ApiUrl: server.URL, + RetryParams: &RetryParams{MaxRetry: 4, MinWaitInMs: 1}, + HTTPClient: &http.Client{}, + }) + if err != nil { + t.Fatalf("failed to create configuration: %v", err) + } + apiClient := NewAPIClient(cfg) + + requestBody := ListObjectsRequest{ + User: "user:anne", + Relation: "viewer", + Type: "document", + } + + channel, err := apiClient.OpenFgaApi.StreamedListObjects(context.Background(), storeID). + Body(requestBody). + Execute() + if err != nil { + t.Fatalf("expected eventual success after retries, got: %v", err) + } + defer channel.Close() + + var objects []string + for obj := range channel.Objects { + objects = append(objects, obj.Object) + } + + select { + case err := <-channel.Errors: + if err != nil { + t.Fatalf("unexpected stream error: %v", err) + } + default: + } + + gotAttempts := int(atomic.LoadInt32(&attempts)) + if gotAttempts != 3 { + t.Fatalf("expected 3 attempts (2 x 500 + 1 success), got %d", gotAttempts) + } + if len(objects) != 1 || objects[0] != "document:retried" { + t.Fatalf("expected [document:retried], got %v", objects) + } + }) + + t.Run("retry on 429 rate limit error", func(t *testing.T) { + var attempts int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cur := atomic.AddInt32(&attempts, 1) + if cur < 2 { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"code":"rate_limit_exceeded","message":"Rate limit exceeded"}`)) + return + } + w.Header().Set("Content-Type", "application/x-ndjson") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"result":{"object":"document:after-rate-limit"}}` + "\n")) + })) + defer server.Close() + + cfg, err := NewConfiguration(Configuration{ + ApiUrl: server.URL, + RetryParams: &RetryParams{MaxRetry: 3, MinWaitInMs: 1}, + HTTPClient: &http.Client{}, + }) + if err != nil { + t.Fatalf("failed to create configuration: %v", err) + } + apiClient := NewAPIClient(cfg) + + requestBody := ListObjectsRequest{ + User: "user:anne", + Relation: "viewer", + Type: "document", + } + + channel, err := apiClient.OpenFgaApi.StreamedListObjects(context.Background(), storeID). + Body(requestBody). + Execute() + if err != nil { + t.Fatalf("expected eventual success after rate limit retry, got: %v", err) + } + defer channel.Close() + + var objects []string + for obj := range channel.Objects { + objects = append(objects, obj.Object) + } + + gotAttempts := int(atomic.LoadInt32(&attempts)) + if gotAttempts != 2 { + t.Fatalf("expected 2 attempts (1 x 429 + 1 success), got %d", gotAttempts) + } + if len(objects) != 1 || objects[0] != "document:after-rate-limit" { + t.Fatalf("expected [document:after-rate-limit], got %v", objects) + } + }) + + t.Run("retry on 400 validation error via default case", func(t *testing.T) { + // Note: 400 errors fall through to determineRetry's default case, which retries + // them just like the non-streaming Check endpoint does. This is consistent behavior. + var attempts int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"code":"validation_error","message":"Invalid request"}`)) + })) + defer server.Close() + + cfg, err := NewConfiguration(Configuration{ + ApiUrl: server.URL, + RetryParams: &RetryParams{MaxRetry: 2, MinWaitInMs: 1}, + HTTPClient: &http.Client{}, + }) + if err != nil { + t.Fatalf("failed to create configuration: %v", err) + } + apiClient := NewAPIClient(cfg) + + requestBody := ListObjectsRequest{ + User: "user:anne", + Relation: "viewer", + Type: "document", + } + + channel, err := apiClient.OpenFgaApi.StreamedListObjects(context.Background(), storeID). + Body(requestBody). + Execute() + + if err == nil { + if channel != nil { + channel.Close() + } + t.Fatal("expected error for persistent 400 responses, got nil") + } + + gotAttempts := int(atomic.LoadInt32(&attempts)) + // 400 errors are retried via the default case in determineRetry (matching non-streaming behavior) + if gotAttempts != 3 { + t.Fatalf("expected 3 attempts (1 initial + 2 retries, matching non-streaming behavior), got %d", gotAttempts) + } + }) + + t.Run("retry on transport error then succeed", func(t *testing.T) { + var attempts int32 + + // Start a real server for the successful attempt + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/x-ndjson") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"result":{"object":"document:recovered"}}` + "\n")) + })) + defer server.Close() + + // Use a custom round tripper that fails once then delegates to real transport + transport := &http.Transport{} + cfg, err := NewConfiguration(Configuration{ + ApiUrl: server.URL, + RetryParams: &RetryParams{MaxRetry: 3, MinWaitInMs: 1}, + HTTPClient: &http.Client{ + Transport: &testStreamRetryTransport{ + attempts: &attempts, + failUntil: 2, + realTransport: transport, + }, + }, + }) + if err != nil { + t.Fatalf("failed to create configuration: %v", err) + } + apiClient := NewAPIClient(cfg) + + requestBody := ListObjectsRequest{ + User: "user:anne", + Relation: "viewer", + Type: "document", + } + + channel, err := apiClient.OpenFgaApi.StreamedListObjects(context.Background(), storeID). + Body(requestBody). + Execute() + if err != nil { + t.Fatalf("expected eventual success after transport retry, got: %v", err) + } + defer channel.Close() + + var objects []string + for obj := range channel.Objects { + objects = append(objects, obj.Object) + } + + gotAttempts := int(atomic.LoadInt32(&attempts)) + if gotAttempts != 2 { + t.Fatalf("expected 2 attempts (1 transport error + 1 success), got %d", gotAttempts) + } + if len(objects) != 1 || objects[0] != "document:recovered" { + t.Fatalf("expected [document:recovered], got %v", objects) + } + }) + + t.Run("exhausts retries on persistent 500", func(t *testing.T) { + var attempts int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&attempts, 1) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"code":"internal_error","message":"persistent failure"}`)) + })) + defer server.Close() + + cfg, err := NewConfiguration(Configuration{ + ApiUrl: server.URL, + RetryParams: &RetryParams{MaxRetry: 2, MinWaitInMs: 1}, + HTTPClient: &http.Client{}, + }) + if err != nil { + t.Fatalf("failed to create configuration: %v", err) + } + apiClient := NewAPIClient(cfg) + + requestBody := ListObjectsRequest{ + User: "user:anne", + Relation: "viewer", + Type: "document", + } + + channel, err := apiClient.OpenFgaApi.StreamedListObjects(context.Background(), storeID). + Body(requestBody). + Execute() + + if err == nil { + if channel != nil { + channel.Close() + } + t.Fatal("expected error after exhausting retries, got nil") + } + + gotAttempts := int(atomic.LoadInt32(&attempts)) + if gotAttempts != 3 { + t.Fatalf("expected 3 attempts (1 initial + 2 retries), got %d", gotAttempts) + } + }) +} + +// testStreamRetryTransport is a test helper that fails the first N requests with a transport error +// and then delegates to a real transport for subsequent requests. +type testStreamRetryTransport struct { + attempts *int32 + failUntil int32 + realTransport http.RoundTripper +} + +func (t *testStreamRetryTransport) RoundTrip(req *http.Request) (*http.Response, error) { + cur := atomic.AddInt32(t.attempts, 1) + if cur < t.failUntil { + return nil, fmt.Errorf("connection refused") + } + return t.realTransport.RoundTrip(req) +} diff --git a/client/client.go b/client/client.go index 5220d67..beb6cfa 100644 --- a/client/client.go +++ b/client/client.go @@ -3457,24 +3457,19 @@ func (client *OpenFgaClient) StreamedListObjectsExecute(request SdkClientStreame Context: request.GetBody().Context, AuthorizationModelId: authorizationModelId, } - requestOptions := RequestOptions{} - bufferSize := 0 + requestOptions := fgaSdk.StreamingRequestOptions{} if request.GetOptions() != nil { - requestOptions = request.GetOptions().RequestOptions + requestOptions.RequestOptions = request.GetOptions().RequestOptions body.Consistency = request.GetOptions().Consistency if request.GetOptions().StreamBufferSize != nil { - bufferSize = *request.GetOptions().StreamBufferSize + requestOptions.BufferSize = *request.GetOptions().StreamBufferSize } } - channel, err := fgaSdk.ExecuteStreamedListObjectsWithBufferSize( - &client.APIClient, - request.GetContext(), - *storeId, - body, - requestOptions, - bufferSize, - ) + channel, err := client.OpenFgaApi.StreamedListObjects(request.GetContext(), *storeId). + Body(body). + Options(requestOptions). + Execute() if err != nil { return nil, err diff --git a/example/Makefile b/example/Makefile index b14e64e..738cdb8 100644 --- a/example/Makefile +++ b/example/Makefile @@ -9,6 +9,12 @@ restore: run: restore go run ${project_name}/${project_name}.go +run-streamed-list-objects: + cd streamed_list_objects && go run . + +run-api-executor: + cd api_executor && go run . + run-openfga: docker pull docker.io/openfga/openfga:${openfga_version} && \ docker run -p 8080:8080 docker.io/openfga/openfga:${openfga_version} run diff --git a/example/README.md b/example/README.md index 97386b0..8a25db0 100644 --- a/example/README.md +++ b/example/README.md @@ -7,9 +7,16 @@ Example 1: A bare bones example. It creates a store, and runs a set of calls against it including creating a model, writing tuples and checking for access. **StreamedListObjects Example:** -Demonstrates how to use the `StreamedListObjects` API with both synchronous and asynchronous consumption patterns. +Demonstrates how to use the concrete `StreamedListObjects` client method with typed responses. +This is the recommended approach for calling the StreamedListObjects API. Includes support for configurable buffer sizes to optimize throughput vs memory usage. +**API Executor Example:** +Demonstrates how to use the low-level `APIExecutor` to call all major OpenFGA endpoints +(ListStores, CreateStore, GetStore, WriteAuthorizationModel, Write, Read, Check, ListObjects, +StreamedListObjects, DeleteStore) using `Execute`, `ExecuteWithDecode`, and `ExecuteStreaming`. +This is useful for custom or unsupported endpoints where you need full control over the request and response. + ### Running the Examples Prerequisites: diff --git a/example/api_executor/README.md b/example/api_executor/README.md new file mode 100644 index 0000000..3ec433f --- /dev/null +++ b/example/api_executor/README.md @@ -0,0 +1,127 @@ +# API Executor Example + +Demonstrates using the **low-level `APIExecutor`** to call real OpenFGA API +endpoints — both standard request/response and streaming. + +This approach is useful when: +- You want to call **any endpoint** (including new or custom ones) not yet + supported by the SDK's typed client methods +- You are using an **earlier version of the SDK** that doesn't have a typed + method for a particular endpoint +- You have a **custom endpoint** deployed that extends the OpenFGA API +- You need **full control over the raw request/response** (headers, body bytes, status codes) + +> For the **recommended high-level approach**, see the [`example1`](../example1/) +> and [`streamed_list_objects`](../streamed_list_objects/) examples. + +## What it covers + +The example exercises all three `APIExecutor` methods against a live server: + +| # | Operation | HTTP | Method Used | +|---|-----------|------|-------------| +| 1 | **ListStores** | `GET /stores` | `Execute` (raw bytes) | +| 2 | **CreateStore** | `POST /stores` | `ExecuteWithDecode` | +| 3 | **GetStore** | `GET /stores/{store_id}` | `ExecuteWithDecode` | +| 4 | **WriteAuthorizationModel** | `POST /stores/{store_id}/authorization-models` | `ExecuteWithDecode` | +| 5 | **WriteTuples** | `POST /stores/{store_id}/write` | `Execute` | +| 6 | **ReadTuples** | `POST /stores/{store_id}/read` | `ExecuteWithDecode` | +| 7 | **Check** | `POST /stores/{store_id}/check` | `ExecuteWithDecode` (+ custom header) | +| 8 | **ListObjects** | `POST /stores/{store_id}/list-objects` | `ExecuteWithDecode` | +| 9 | **StreamedListObjects** | `POST /stores/{store_id}/streamed-list-objects` | `ExecuteStreaming` | +| 10 | **DeleteStore** | `DELETE /stores/{store_id}` | `Execute` | + +## How it works + +The `APIExecutor` provides three methods: +1. **`Execute`** — returns raw response bytes (status code, headers, body) +2. **`ExecuteWithDecode`** — returns decoded typed response +3. **`ExecuteStreaming`** — returns a channel of raw JSON bytes (for NDJSON streaming endpoints) + +All requests are built using `NewAPIExecutorRequestBuilder` with a fluent API +for setting the operation name, HTTP method, path, path parameters, query +parameters, headers, and body. + +## Prerequisites + +- OpenFGA server running on `http://localhost:8080` (or set `FGA_API_URL`) + +## Running + +```bash +cd example/api_executor +go run . +``` + +## Key Code Patterns + +### Standard request with decoded response + +```go +executor := fgaClient.GetAPIExecutor() + +var checkResp openfga.CheckResponse +_, err := executor.ExecuteWithDecode(ctx, + openfga.NewAPIExecutorRequestBuilder("Check", "POST", "/stores/{store_id}/check"). + WithPathParameter("store_id", storeID). + WithHeader("X-Request-ID", "my-request-123"). + WithBody(openfga.CheckRequest{ + TupleKey: openfga.CheckRequestTupleKey{ + User: "user:alice", + Relation: "writer", + Object: "document:roadmap", + }, + }). + Build(), + &checkResp, +) +fmt.Println(*checkResp.Allowed) +``` + +### Streaming request + +```go +channel, err := executor.ExecuteStreaming(ctx, + openfga.NewAPIExecutorRequestBuilder("StreamedListObjects", "POST", "/stores/{store_id}/streamed-list-objects"). + WithPathParameter("store_id", storeID). + WithBody(openfga.ListObjectsRequest{ + User: "user:alice", + Relation: "reader", + Type: "document", + }). + Build(), + openfga.DefaultStreamBufferSize, +) +if err != nil { + log.Fatal(err) +} +defer channel.Close() + +for { + select { + case result, ok := <-channel.Results: + if !ok { + return // Stream completed + } + var response openfga.StreamedListObjectsResponse + json.Unmarshal(result, &response) + fmt.Println(response.Object) + case err := <-channel.Errors: + if err != nil { + log.Fatal(err) + } + } +} +``` + +## Comparison: Client Method vs APIExecutor + +| Feature | Client Methods | APIExecutor | +|---|---|---| +| **Typed responses** | Yes, built-in | Manual decode or `ExecuteWithDecode` | +| **Endpoint hardcoded** | Yes, one method per endpoint | No, you specify path, method, params | +| **Custom endpoints** | No, only known endpoints | Yes, any endpoint | +| **Custom headers** | Via `Options()` | Via `WithHeader()` | +| **Retry logic** | Yes | Yes (same retry config) | +| **Streaming** | `StreamedListObjects()` | `ExecuteStreaming()` | +| **Recommended for** | Production use of known endpoints | Custom/new/experimental endpoints | diff --git a/example/api_executor/go.mod b/example/api_executor/go.mod new file mode 100644 index 0000000..cd54dc0 --- /dev/null +++ b/example/api_executor/go.mod @@ -0,0 +1,22 @@ +module api_executor + +go 1.25.0 + +toolchain go1.25.4 + +require github.com/openfga/go-sdk v0.7.5 + +require ( + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/otel v1.42.0 // indirect + go.opentelemetry.io/otel/metric v1.42.0 // indirect + go.opentelemetry.io/otel/trace v1.42.0 // indirect + go.uber.org/atomic v1.7.0 // indirect + go.uber.org/multierr v1.9.0 // indirect +) + +replace github.com/openfga/go-sdk => ../.. // added this to point to local module diff --git a/example/api_executor/main.go b/example/api_executor/main.go new file mode 100644 index 0000000..fb8c7d8 --- /dev/null +++ b/example/api_executor/main.go @@ -0,0 +1,362 @@ +// This example demonstrates how to use the low-level APIExecutor to call real +// OpenFGA API endpoints — both standard request/response and streaming. +// +// It exercises Execute, ExecuteWithDecode, and ExecuteStreaming against a live +// OpenFGA server, covering the most common operations: +// +// 1. ListStores — GET /stores +// 2. CreateStore — POST /stores +// 3. GetStore — GET /stores/{store_id} +// 4. WriteAuthModel — POST /stores/{store_id}/authorization-models +// 5. WriteTuples — POST /stores/{store_id}/write +// 6. ReadTuples — POST /stores/{store_id}/read +// 7. Check — POST /stores/{store_id}/check +// 8. ListObjects — POST /stores/{store_id}/list-objects +// 9. StreamedListObj — POST /stores/{store_id}/streamed-list-objects (streaming) +// 10. DeleteStore — DELETE /stores/{store_id} +// +// For the recommended high-level typed approach, see the example1 and +// streamed_list_objects examples. + +package main + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "os" + + openfga "github.com/openfga/go-sdk" + "github.com/openfga/go-sdk/client" +) + +func main() { + ctx := context.Background() + + apiUrl := os.Getenv("FGA_API_URL") + if apiUrl == "" { + apiUrl = "http://localhost:8080" + } + + // We create a thin SDK client only to obtain an APIExecutor. + // All actual API calls below go through the executor. + fgaClient, err := client.NewSdkClient(&client.ClientConfiguration{ + ApiUrl: apiUrl, + }) + if err != nil { + handleError("NewSdkClient", err) + } + + executor := fgaClient.GetAPIExecutor() + + fmt.Println("=== OpenFGA APIExecutor Example ===\n") + + // ----------------------------------------------------------------- + // 1. ListStores (GET /stores) — raw Execute + // ----------------------------------------------------------------- + fmt.Println("1. ListStores (raw Execute)") + listStoresResp, err := executor.Execute(ctx, + openfga.NewAPIExecutorRequestBuilder("ListStores", http.MethodGet, "/stores").Build(), + ) + if err != nil { + handleError("ListStores", err) + } + var listStores openfga.ListStoresResponse + if err := json.Unmarshal(listStoresResp.Body, &listStores); err != nil { + handleError("ListStores decode", err) + } + fmt.Printf(" Status: %d | Stores count: %d\n\n", listStoresResp.StatusCode, len(listStores.Stores)) + + // ----------------------------------------------------------------- + // 2. CreateStore (POST /stores) — ExecuteWithDecode + // ----------------------------------------------------------------- + fmt.Println("2. CreateStore (ExecuteWithDecode)") + var createStore openfga.CreateStoreResponse + createStoreResp, err := executor.ExecuteWithDecode(ctx, + openfga.NewAPIExecutorRequestBuilder("CreateStore", http.MethodPost, "/stores"). + WithBody(openfga.CreateStoreRequest{Name: "api-executor-example"}). + Build(), + &createStore, + ) + if err != nil { + handleError("CreateStore", err) + } + storeID := createStore.Id + fmt.Printf(" Status: %d | Store ID: %s | Name: %s\n\n", createStoreResp.StatusCode, storeID, createStore.Name) + + // ----------------------------------------------------------------- + // 3. GetStore (GET /stores/{store_id}) — path parameters + // ----------------------------------------------------------------- + fmt.Println("3. GetStore (path parameters)") + var getStore openfga.GetStoreResponse + getStoreResp, err := executor.ExecuteWithDecode(ctx, + openfga.NewAPIExecutorRequestBuilder("GetStore", http.MethodGet, "/stores/{store_id}"). + WithPathParameter("store_id", storeID). + Build(), + &getStore, + ) + if err != nil { + handleError("GetStore", err) + } + fmt.Printf(" Status: %d | Name: %s | Created: %s\n\n", getStoreResp.StatusCode, getStore.Name, getStore.CreatedAt.Format("2006-01-02T15:04:05Z")) + + // ----------------------------------------------------------------- + // 4. WriteAuthorizationModel (POST /stores/{store_id}/authorization-models) + // ----------------------------------------------------------------- + fmt.Println("4. WriteAuthorizationModel") + var writeModelResp openfga.WriteAuthorizationModelResponse + _, err = executor.ExecuteWithDecode(ctx, + openfga.NewAPIExecutorRequestBuilder("WriteAuthorizationModel", http.MethodPost, "/stores/{store_id}/authorization-models"). + WithPathParameter("store_id", storeID). + WithBody(openfga.WriteAuthorizationModelRequest{ + SchemaVersion: "1.1", + TypeDefinitions: []openfga.TypeDefinition{ + { + Type: "user", + Relations: &map[string]openfga.Userset{}, + }, + { + Type: "document", + Relations: &map[string]openfga.Userset{ + "reader": {This: &map[string]interface{}{}}, + "writer": {This: &map[string]interface{}{}}, + }, + Metadata: &openfga.Metadata{ + Relations: &map[string]openfga.RelationMetadata{ + "reader": { + DirectlyRelatedUserTypes: &[]openfga.RelationReference{{Type: "user"}}, + }, + "writer": { + DirectlyRelatedUserTypes: &[]openfga.RelationReference{{Type: "user"}}, + }, + }, + }, + }, + }, + }). + Build(), + &writeModelResp, + ) + if err != nil { + handleError("WriteAuthorizationModel", err) + } + modelID := writeModelResp.AuthorizationModelId + fmt.Printf(" Model ID: %s\n\n", modelID) + + // ----------------------------------------------------------------- + // 5. Write tuples (POST /stores/{store_id}/write) + // ----------------------------------------------------------------- + fmt.Println("5. WriteTuples") + _, err = executor.Execute(ctx, + openfga.NewAPIExecutorRequestBuilder("Write", http.MethodPost, "/stores/{store_id}/write"). + WithPathParameter("store_id", storeID). + WithBody(openfga.WriteRequest{ + Writes: &openfga.WriteRequestWrites{ + TupleKeys: []openfga.TupleKey{ + {User: "user:alice", Relation: "writer", Object: "document:roadmap"}, + {User: "user:bob", Relation: "reader", Object: "document:roadmap"}, + }, + }, + AuthorizationModelId: &modelID, + }). + Build(), + ) + if err != nil { + handleError("Write", err) + } + fmt.Println(" Tuples written: user:alice→writer, user:bob→reader on document:roadmap\n") + + // ----------------------------------------------------------------- + // 6. Read tuples (POST /stores/{store_id}/read) + // ----------------------------------------------------------------- + fmt.Println("6. ReadTuples") + var readResp openfga.ReadResponse + _, err = executor.ExecuteWithDecode(ctx, + openfga.NewAPIExecutorRequestBuilder("Read", http.MethodPost, "/stores/{store_id}/read"). + WithPathParameter("store_id", storeID). + WithBody(openfga.ReadRequest{ + TupleKey: &openfga.ReadRequestTupleKey{ + Object: openfga.PtrString("document:roadmap"), + }, + }). + Build(), + &readResp, + ) + if err != nil { + handleError("Read", err) + } + fmt.Printf(" Found %d tuple(s):\n", len(readResp.Tuples)) + for _, t := range readResp.Tuples { + fmt.Printf(" - %s is %s of %s\n", t.Key.User, t.Key.Relation, t.Key.Object) + } + fmt.Println() + + // ----------------------------------------------------------------- + // 7. Check (POST /stores/{store_id}/check) — with custom header + // ----------------------------------------------------------------- + fmt.Println("7. Check (with custom header)") + var checkResp openfga.CheckResponse + _, err = executor.ExecuteWithDecode(ctx, + openfga.NewAPIExecutorRequestBuilder("Check", http.MethodPost, "/stores/{store_id}/check"). + WithPathParameter("store_id", storeID). + WithHeader("X-Request-ID", "example-check-123"). + WithBody(openfga.CheckRequest{ + TupleKey: openfga.CheckRequestTupleKey{ + User: "user:alice", + Relation: "writer", + Object: "document:roadmap", + }, + AuthorizationModelId: &modelID, + }). + Build(), + &checkResp, + ) + if err != nil { + handleError("Check", err) + } + fmt.Printf(" user:alice writer document:roadmap → Allowed: %v\n", *checkResp.Allowed) + + // Also check a user who should NOT have access + var checkResp2 openfga.CheckResponse + _, err = executor.ExecuteWithDecode(ctx, + openfga.NewAPIExecutorRequestBuilder("Check", http.MethodPost, "/stores/{store_id}/check"). + WithPathParameter("store_id", storeID). + WithBody(openfga.CheckRequest{ + TupleKey: openfga.CheckRequestTupleKey{ + User: "user:bob", + Relation: "writer", + Object: "document:roadmap", + }, + AuthorizationModelId: &modelID, + }). + Build(), + &checkResp2, + ) + if err != nil { + handleError("Check (bob)", err) + } + fmt.Printf(" user:bob writer document:roadmap → Allowed: %v\n\n", *checkResp2.Allowed) + + // ----------------------------------------------------------------- + // 8. ListObjects (POST /stores/{store_id}/list-objects) + // ----------------------------------------------------------------- + fmt.Println("8. ListObjects") + var listObjectsResp openfga.ListObjectsResponse + _, err = executor.ExecuteWithDecode(ctx, + openfga.NewAPIExecutorRequestBuilder("ListObjects", http.MethodPost, "/stores/{store_id}/list-objects"). + WithPathParameter("store_id", storeID). + WithBody(openfga.ListObjectsRequest{ + AuthorizationModelId: &modelID, + User: "user:alice", + Relation: "writer", + Type: "document", + }). + Build(), + &listObjectsResp, + ) + if err != nil { + handleError("ListObjects", err) + } + fmt.Printf(" Objects user:alice can write: %v\n\n", listObjectsResp.Objects) + + // ----------------------------------------------------------------- + // 9. StreamedListObjects (POST /stores/{store_id}/streamed-list-objects) + // Write more tuples first so we have something meaningful to stream. + // ----------------------------------------------------------------- + fmt.Println("9. StreamedListObjects (ExecuteStreaming)") + fmt.Println(" Writing 200 additional tuples for streaming demo...") + for batch := 0; batch < 2; batch++ { + tuples := make([]openfga.TupleKey, 0, 100) + for i := 1; i <= 100; i++ { + tuples = append(tuples, openfga.TupleKey{ + User: "user:alice", + Relation: "reader", + Object: fmt.Sprintf("document:doc-%d", batch*100+i), + }) + } + _, err = executor.Execute(ctx, + openfga.NewAPIExecutorRequestBuilder("Write", http.MethodPost, "/stores/{store_id}/write"). + WithPathParameter("store_id", storeID). + WithBody(openfga.WriteRequest{ + Writes: &openfga.WriteRequestWrites{TupleKeys: tuples}, + AuthorizationModelId: &modelID, + }). + Build(), + ) + if err != nil { + handleError("Write (batch)", err) + } + } + + channel, err := executor.ExecuteStreaming(ctx, + openfga.NewAPIExecutorRequestBuilder("StreamedListObjects", http.MethodPost, "/stores/{store_id}/streamed-list-objects"). + WithPathParameter("store_id", storeID). + WithBody(openfga.ListObjectsRequest{ + AuthorizationModelId: &modelID, + User: "user:alice", + Relation: "reader", + Type: "document", + }). + Build(), + openfga.DefaultStreamBufferSize, + ) + if err != nil { + handleError("ExecuteStreaming", err) + } + defer channel.Close() + + count := 0 + for { + select { + case result, ok := <-channel.Results: + if !ok { + select { + case err := <-channel.Errors: + if err != nil { + handleError("StreamedListObjects stream", err) + } + default: + } + fmt.Printf(" ✓ Streamed %d objects\n\n", count) + goto streamDone + } + var obj openfga.StreamedListObjectsResponse + if err := json.Unmarshal(result, &obj); err != nil { + handleError("decode stream result", err) + } + count++ + if count <= 3 || count%50 == 0 { + fmt.Printf(" Object: %s\n", obj.Object) + } + case err := <-channel.Errors: + if err != nil { + handleError("StreamedListObjects error", err) + } + } + } +streamDone: + + // ----------------------------------------------------------------- + // 10. DeleteStore (DELETE /stores/{store_id}) + // ----------------------------------------------------------------- + fmt.Println("10. DeleteStore (cleanup)") + deleteResp, err := executor.Execute(ctx, + openfga.NewAPIExecutorRequestBuilder("DeleteStore", http.MethodDelete, "/stores/{store_id}"). + WithPathParameter("store_id", storeID). + Build(), + ) + if err != nil { + handleError("DeleteStore", err) + } + fmt.Printf(" Status: %d | Store deleted\n\n", deleteResp.StatusCode) + + fmt.Println("=== All examples completed successfully! ===") +} + +func handleError(context string, err error) { + fmt.Fprintf(os.Stderr, "\nError in %s: %v\n", context, err) + fmt.Fprintln(os.Stderr, "\nMake sure OpenFGA is running on localhost:8080 (or set FGA_API_URL)") + fmt.Fprintln(os.Stderr, "Run: docker run -p 8080:8080 openfga/openfga:latest run") + os.Exit(1) +} diff --git a/example/example1/go.mod b/example/example1/go.mod index 6869c08..31f6cff 100644 --- a/example/example1/go.mod +++ b/example/example1/go.mod @@ -18,5 +18,6 @@ require ( go.opentelemetry.io/otel v1.42.0 // indirect go.opentelemetry.io/otel/metric v1.42.0 // indirect go.opentelemetry.io/otel/trace v1.42.0 // indirect - go.uber.org/multierr v1.11.0 // indirect + go.uber.org/atomic v1.7.0 // indirect + go.uber.org/multierr v1.9.0 // indirect ) diff --git a/example/opentelemetry/go.mod b/example/opentelemetry/go.mod index d98e776..40da982 100644 --- a/example/opentelemetry/go.mod +++ b/example/opentelemetry/go.mod @@ -29,7 +29,8 @@ require ( go.opentelemetry.io/otel/metric v1.42.0 // indirect go.opentelemetry.io/otel/trace v1.42.0 // indirect go.opentelemetry.io/proto/otlp v1.9.0 // indirect - go.uber.org/multierr v1.11.0 // indirect + go.uber.org/atomic v1.7.0 // indirect + go.uber.org/multierr v1.9.0 // indirect golang.org/x/net v0.48.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/text v0.32.0 // indirect diff --git a/example/streamed_list_objects/README.md b/example/streamed_list_objects/README.md index 050013f..a06e725 100644 --- a/example/streamed_list_objects/README.md +++ b/example/streamed_list_objects/README.md @@ -1,6 +1,10 @@ # Streamed List Objects Example -Demonstrates using `StreamedListObjects` to retrieve objects via the streaming API in the Go SDK. +Demonstrates using the **concrete `StreamedListObjects` client method** to retrieve objects via the typed high-level streaming API. + +This is the **recommended approach** for calling `StreamedListObjects` — it provides typed `StreamedListObjectsResponse` objects directly, without requiring manual JSON unmarshalling. + +> For an example using the low-level `APIExecutor` with raw NDJSON streaming, see the [`api_executor`](../api_executor/) example. ## What is StreamedListObjects? @@ -64,7 +68,11 @@ type document The `StreamedListObjects` method returns a response with channels, which is the idiomatic Go way to handle streaming data: ```go -response, err := fgaClient.StreamedListObjects(ctx).Body(request).Execute() +response, err := fgaClient.StreamedListObjects(ctx).Body(client.ClientStreamedListObjectsRequest{ + User: "user:anne", + Relation: "can_read", + Type: "document", +}).Execute() if err != nil { log.Fatal(err) } @@ -74,7 +82,7 @@ for obj := range response.Objects { fmt.Printf("Received: %s\n", obj.Object) } -// Check for errors +// Check for errors after the stream completes if err := <-response.Errors; err != nil { log.Fatal(err) } @@ -137,4 +145,3 @@ The example includes robust error handling that: - Detects connection issues - Avoids logging sensitive data - Provides helpful messages for common issues - diff --git a/example/streamed_list_objects/go.mod b/example/streamed_list_objects/go.mod index 5913d60..35bd1f4 100644 --- a/example/streamed_list_objects/go.mod +++ b/example/streamed_list_objects/go.mod @@ -12,27 +12,28 @@ require ( require ( github.com/antlr4-go/antlr/v4 v4.13.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect + github.com/envoyproxy/protoc-gen-validate v1.1.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect - github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect - github.com/openfga/api/proto v0.0.0-20251105142303-feed3db3d69d // indirect + github.com/openfga/api/proto v0.0.0-20240905181937-3583905f61a6 // indirect github.com/sourcegraph/conc v0.3.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel v1.42.0 // indirect go.opentelemetry.io/otel/metric v1.42.0 // indirect go.opentelemetry.io/otel/trace v1.42.0 // indirect - go.uber.org/multierr v1.11.0 // indirect - golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect - golang.org/x/net v0.48.0 // indirect - golang.org/x/sys v0.39.0 // indirect - golang.org/x/text v0.32.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect - google.golang.org/grpc v1.79.3 // indirect - google.golang.org/protobuf v1.36.10 // indirect + go.uber.org/atomic v1.7.0 // indirect + go.uber.org/multierr v1.9.0 // indirect + golang.org/x/exp v0.0.0-20240904232852-e7e105dedf7e // indirect + golang.org/x/net v0.29.0 // indirect + golang.org/x/sys v0.25.0 // indirect + golang.org/x/text v0.18.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240903143218-8af14fe29dc1 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect + google.golang.org/grpc v1.66.0 // indirect + google.golang.org/protobuf v1.34.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/example/streamed_list_objects/main.go b/example/streamed_list_objects/main.go index fe19dba..06495de 100644 --- a/example/streamed_list_objects/main.go +++ b/example/streamed_list_objects/main.go @@ -1,3 +1,13 @@ +// This example demonstrates how to use the concrete StreamedListObjects client method +// to stream objects via the typed high-level API. +// +// This is the recommended approach for calling StreamedListObjects - it provides +// typed responses (StreamedListObjectsResponse) directly, without requiring +// manual JSON unmarshalling. +// +// For an example using the low-level APIExecutor with raw NDJSON streaming, +// see the api_executor example. + package main import ( @@ -6,9 +16,10 @@ import ( "fmt" "os" + "github.com/openfga/language/pkg/go/transformer" + openfga "github.com/openfga/go-sdk" "github.com/openfga/go-sdk/client" - "github.com/openfga/language/pkg/go/transformer" ) func main() { @@ -21,10 +32,9 @@ func main() { } // Create initial client for store creation - config := client.ClientConfiguration{ + fgaClient, err := client.NewSdkClient(&client.ClientConfiguration{ ApiUrl: apiUrl, - } - fgaClient, err := client.NewSdkClient(&config) + }) if err != nil { handleError(err) return @@ -73,13 +83,16 @@ func main() { return } - fmt.Println("Streaming objects via computed 'can_read' relation...") - if err := streamObjects(ctx, fga); err != nil { + // ========================================================================= + // Using the concrete StreamedListObjects client method (recommended) + // ========================================================================= + fmt.Println("\nStreaming objects via computed 'can_read' relation...") + if err := streamObjectsViaClient(ctx, fga); err != nil { handleError(err) return } - fmt.Println("Cleaning up...") + fmt.Println("\nCleaning up...") if _, err := fga.DeleteStore(ctx).Execute(); err != nil { fmt.Printf("Failed to delete store: %v\n", err) } @@ -87,6 +100,42 @@ func main() { fmt.Println("Done") } +// streamObjectsViaClient demonstrates using the typed StreamedListObjects client method. +// This is the recommended approach - it provides typed responses directly. +func streamObjectsViaClient(ctx context.Context, fga *client.OpenFgaClient) error { + consistency := openfga.CONSISTENCYPREFERENCE_HIGHER_CONSISTENCY + + // Call StreamedListObjects using the fluent client API + response, err := fga.StreamedListObjects(ctx).Body(client.ClientStreamedListObjectsRequest{ + User: "user:anne", + Relation: "can_read", // Computed: owner OR viewer + Type: "document", + }).Options(client.ClientStreamedListObjectsOptions{ + Consistency: &consistency, + }).Execute() + if err != nil { + return fmt.Errorf("StreamedListObjects failed: %w", err) + } + defer response.Close() + + // Process typed responses directly - no manual JSON unmarshalling needed! + count := 0 + for obj := range response.Objects { + count++ + if count <= 3 || count%500 == 0 { + fmt.Printf(" Object: %s\n", obj.Object) + } + } + + // Check for errors after the stream completes + if err := <-response.Errors; err != nil { + return fmt.Errorf("error during streaming: %w", err) + } + + fmt.Printf("✓ Streamed %d objects via client method\n", count) + return nil +} + func writeAuthorizationModel(ctx context.Context, fgaClient *client.OpenFgaClient) (*client.ClientWriteAuthorizationModelResponse, error) { // Define the authorization model using OpenFGA DSL dslString := `model @@ -158,38 +207,6 @@ func writeTuples(ctx context.Context, fga *client.OpenFgaClient) error { return nil } -func streamObjects(ctx context.Context, fga *client.OpenFgaClient) error { - consistencyPreference := openfga.CONSISTENCYPREFERENCE_HIGHER_CONSISTENCY - - response, err := fga.StreamedListObjects(ctx).Body(client.ClientStreamedListObjectsRequest{ - User: "user:anne", - Relation: "can_read", // Computed: owner OR viewer - Type: "document", - }).Options(client.ClientStreamedListObjectsOptions{ - Consistency: &consistencyPreference, - }).Execute() - if err != nil { - return fmt.Errorf("StreamedListObjects failed: %w", err) - } - defer response.Close() - - count := 0 - for obj := range response.Objects { - count++ - if count <= 3 || count%500 == 0 { - fmt.Printf("- %s\n", obj.Object) - } - } - - // Check for streaming errors - if err := <-response.Errors; err != nil { - return fmt.Errorf("error during streaming: %w", err) - } - - fmt.Printf("✓ Streamed %d objects\n", count) - return nil -} - func handleError(err error) { // Avoid logging sensitive data; only display generic info if err.Error() == "connection refused" { diff --git a/streaming.go b/streaming.go deleted file mode 100644 index 6736e7f..0000000 --- a/streaming.go +++ /dev/null @@ -1,251 +0,0 @@ -/** - * Go SDK for OpenFGA - * - * API version: 1.x - * Website: https://openfga.dev - * Documentation: https://openfga.dev/docs - * Support: https://openfga.dev/community - * License: [Apache-2.0](https://github.com/openfga/go-sdk/blob/main/LICENSE) - * - * NOTE: This file was auto generated by OpenAPI Generator (https://openapi-generator.tech). DO NOT EDIT. - */ - -package openfga - -import ( - "bufio" - "context" - "encoding/json" - "errors" - "io" - "net/http" - "net/url" - "strings" -) - -// StreamResult represents a generic streaming result wrapper with either a result or an error -type StreamResult[T any] struct { - Result *T `json:"result,omitempty" yaml:"result,omitempty"` - Error *Status `json:"error,omitempty" yaml:"error,omitempty"` -} - -// StreamingChannel represents a generic channel for streaming responses -type StreamingChannel[T any] struct { - Results chan T - Errors chan error - cancel context.CancelFunc -} - -// Close cancels the streaming context and cleans up resources -func (s *StreamingChannel[T]) Close() { - if s.cancel != nil { - s.cancel() - } -} - -// ProcessStreamingResponse processes an HTTP response as a streaming NDJSON response -// and returns a StreamingChannel with results and errors -// -// Parameters: -// - ctx: The context for cancellation -// - httpResponse: The HTTP response to process -// - bufferSize: The buffer size for the channels (default 10 if <= 0) -// -// Returns: -// - *StreamingChannel[T]: A channel containing streaming results and errors -// - error: An error if the response is invalid -func ProcessStreamingResponse[T any](ctx context.Context, httpResponse *http.Response, bufferSize int) (*StreamingChannel[T], error) { - streamCtx, cancel := context.WithCancel(ctx) - - // Use default buffer size of 10 if not specified or invalid - if bufferSize <= 0 { - bufferSize = 10 - } - - channel := &StreamingChannel[T]{ - Results: make(chan T, bufferSize), - Errors: make(chan error, 1), - cancel: cancel, - } - - if httpResponse == nil || httpResponse.Body == nil { - cancel() - return nil, errors.New("response or response body is nil") - } - - go func() { - defer close(channel.Results) - defer close(channel.Errors) - defer cancel() - defer func() { _ = httpResponse.Body.Close() }() - - scanner := bufio.NewScanner(httpResponse.Body) - // Allow large NDJSON entries (up to 10MB). Tune as needed. - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 10*1024*1024) - - for scanner.Scan() { - select { - case <-streamCtx.Done(): - channel.Errors <- streamCtx.Err() - return - default: - line := scanner.Bytes() - if len(line) == 0 { - continue - } - - var streamResult StreamResult[T] - if err := json.Unmarshal(line, &streamResult); err != nil { - channel.Errors <- err - return - } - - if streamResult.Error != nil { - msg := "stream error" - if streamResult.Error.Message != nil { - msg = *streamResult.Error.Message - } - channel.Errors <- errors.New(msg) - return - } - - if streamResult.Result != nil { - select { - case <-streamCtx.Done(): - channel.Errors <- streamCtx.Err() - return - case channel.Results <- *streamResult.Result: - } - } - } - } - - if err := scanner.Err(); err != nil { - // Prefer context error if we were canceled to avoid surfacing net/http "use of closed network connection". - if streamCtx.Err() != nil { - channel.Errors <- streamCtx.Err() - return - } - channel.Errors <- err - } - }() - - return channel, nil -} - -// StreamedListObjectsChannel maintains backward compatibility with the old channel structure -type StreamedListObjectsChannel struct { - Objects chan StreamedListObjectsResponse - Errors chan error - cancel context.CancelFunc -} - -// Close cancels the streaming context and cleans up resources -func (s *StreamedListObjectsChannel) Close() { - if s.cancel != nil { - s.cancel() - } -} - -// ProcessStreamedListObjectsResponse processes a StreamedListObjects response -// This is a backward compatibility wrapper around ProcessStreamingResponse -func ProcessStreamedListObjectsResponse(ctx context.Context, httpResponse *http.Response, bufferSize int) (*StreamedListObjectsChannel, error) { - channel, err := ProcessStreamingResponse[StreamedListObjectsResponse](ctx, httpResponse, bufferSize) - if err != nil { - return nil, err - } - - // Create a new channel with the old field name for backward compatibility - compatChannel := &StreamedListObjectsChannel{ - Objects: channel.Results, - Errors: channel.Errors, - cancel: channel.cancel, - } - - return compatChannel, nil -} - -// ExecuteStreamedListObjects executes a StreamedListObjects request -func ExecuteStreamedListObjects(client *APIClient, ctx context.Context, storeId string, body ListObjectsRequest, options RequestOptions) (*StreamedListObjectsChannel, error) { - return ExecuteStreamedListObjectsWithBufferSize(client, ctx, storeId, body, options, 0) -} - -// ExecuteStreamedListObjectsWithBufferSize executes a StreamedListObjects request with a custom buffer size -func ExecuteStreamedListObjectsWithBufferSize(client *APIClient, ctx context.Context, storeId string, body ListObjectsRequest, options RequestOptions, bufferSize int) (*StreamedListObjectsChannel, error) { - channel, err := executeStreamingRequest[ListObjectsRequest, StreamedListObjectsResponse]( - client, - ctx, - "/stores/{store_id}/streamed-list-objects", - storeId, - body, - options, - bufferSize, - "StreamedListObjects", - ) - if err != nil { - return nil, err - } - - // Convert to backward-compatible channel structure - return &StreamedListObjectsChannel{ - Objects: channel.Results, - Errors: channel.Errors, - cancel: channel.cancel, - }, nil -} - -// executeStreamingRequest is a generic function to execute streaming requests -func executeStreamingRequest[TReq any, TRes any]( - client *APIClient, - ctx context.Context, - pathTemplate string, - storeId string, - body TReq, - options RequestOptions, - bufferSize int, - operationName string, -) (*StreamingChannel[TRes], error) { - if storeId == "" { - return nil, reportError("storeId is required and must be specified") - } - - path := pathTemplate - path = strings.ReplaceAll(path, "{"+"store_id"+"}", url.PathEscape(parameterToString(storeId, ""))) - - localVarHeaderParams := make(map[string]string) - localVarQueryParams := url.Values{} - - localVarHTTPContentType := "application/json" - localVarHeaderParams["Content-Type"] = localVarHTTPContentType - localVarHeaderParams["Accept"] = "application/x-ndjson" - - for header, val := range options.Headers { - localVarHeaderParams[header] = val - } - - req, err := client.prepareRequest(ctx, path, http.MethodPost, body, localVarHeaderParams, localVarQueryParams) - if err != nil { - return nil, err - } - - httpResponse, err := client.callAPI(req) - if err != nil { - return nil, err - } - if httpResponse == nil { - return nil, errors.New("nil HTTP response from API client") - } - - if httpResponse.StatusCode >= http.StatusMultipleChoices { - responseBody, readErr := io.ReadAll(httpResponse.Body) - _ = httpResponse.Body.Close() - if readErr != nil { - return nil, readErr - } - err = client.handleAPIError(httpResponse, responseBody, body, operationName, storeId) - return nil, err - } - - return ProcessStreamingResponse[TRes](ctx, httpResponse, bufferSize) -} diff --git a/streaming_test.go b/streaming_test.go index 70b0a3b..33c10a7 100644 --- a/streaming_test.go +++ b/streaming_test.go @@ -56,7 +56,7 @@ func TestStreamedListObjectsChannel_Close(t *testing.T) { } } -func TestStreamedListObjectsWithChannel_Success(t *testing.T) { +func TestStreamedListObjectsWithAPI_Success(t *testing.T) { objects := []string{"document:1", "document:2", "document:3"} expectedResults := []string{} for _, obj := range objects { @@ -94,10 +94,12 @@ func TestStreamedListObjectsWithChannel_Success(t *testing.T) { User: "user:anne", } - channel, err := ExecuteStreamedListObjects(client, ctx, "test-store", request, RequestOptions{}) + channel, err := client.OpenFgaApi.StreamedListObjects(ctx, "test-store"). + Body(request). + Execute() if err != nil { - t.Fatalf("ExecuteStreamedListObjects failed: %v", err) + t.Fatalf("StreamedListObjects failed: %v", err) } defer channel.Close() @@ -122,7 +124,7 @@ func TestStreamedListObjectsWithChannel_Success(t *testing.T) { } } -func TestStreamedListObjectsWithChannel_EmptyLines(t *testing.T) { +func TestStreamedListObjectsWithAPI_EmptyLines(t *testing.T) { responseBody := `{"result":{"object":"document:1"}} {"result":{"object":"document:2"}} @@ -152,10 +154,12 @@ func TestStreamedListObjectsWithChannel_EmptyLines(t *testing.T) { User: "user:anne", } - channel, err := ExecuteStreamedListObjects(client, ctx, "test-store", request, RequestOptions{}) + channel, err := client.OpenFgaApi.StreamedListObjects(ctx, "test-store"). + Body(request). + Execute() if err != nil { - t.Fatalf("ExecuteStreamedListObjects failed: %v", err) + t.Fatalf("StreamedListObjects failed: %v", err) } defer channel.Close() @@ -170,7 +174,7 @@ func TestStreamedListObjectsWithChannel_EmptyLines(t *testing.T) { } } -func TestStreamedListObjectsWithChannel_ErrorInStream(t *testing.T) { +func TestStreamedListObjectsWithAPI_ErrorInStream(t *testing.T) { responseBody := `{"result":{"object":"document:1"}} {"error":{"code":500,"message":"Internal error"}}` @@ -197,10 +201,12 @@ func TestStreamedListObjectsWithChannel_ErrorInStream(t *testing.T) { User: "user:anne", } - channel, err := ExecuteStreamedListObjects(client, ctx, "test-store", request, RequestOptions{}) + channel, err := client.OpenFgaApi.StreamedListObjects(ctx, "test-store"). + Body(request). + Execute() if err != nil { - t.Fatalf("ExecuteStreamedListObjects failed: %v", err) + t.Fatalf("StreamedListObjects failed: %v", err) } defer channel.Close() @@ -224,7 +230,7 @@ func TestStreamedListObjectsWithChannel_ErrorInStream(t *testing.T) { } } -func TestStreamedListObjectsWithChannel_InvalidJSON(t *testing.T) { +func TestStreamedListObjectsWithAPI_InvalidJSON(t *testing.T) { responseBody := `{"result":{"object":"document:1"}} invalid json` @@ -251,10 +257,12 @@ invalid json` User: "user:anne", } - channel, err := ExecuteStreamedListObjects(client, ctx, "test-store", request, RequestOptions{}) + channel, err := client.OpenFgaApi.StreamedListObjects(ctx, "test-store"). + Body(request). + Execute() if err != nil { - t.Fatalf("ExecuteStreamedListObjects failed: %v", err) + t.Fatalf("StreamedListObjects failed: %v", err) } defer channel.Close() @@ -274,7 +282,7 @@ invalid json` } } -func TestStreamedListObjectsWithChannel_ContextCancellation(t *testing.T) { +func TestStreamedListObjectsWithAPI_ContextCancellation(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/x-ndjson") w.WriteHeader(http.StatusOK) @@ -304,10 +312,12 @@ func TestStreamedListObjectsWithChannel_ContextCancellation(t *testing.T) { User: "user:anne", } - channel, err := ExecuteStreamedListObjects(client, ctx, "test-store", request, RequestOptions{}) + channel, err := client.OpenFgaApi.StreamedListObjects(ctx, "test-store"). + Body(request). + Execute() if err != nil { - t.Fatalf("ExecuteStreamedListObjects failed: %v", err) + t.Fatalf("StreamedListObjects failed: %v", err) } defer channel.Close() @@ -332,7 +342,7 @@ func TestStreamedListObjectsWithChannel_ContextCancellation(t *testing.T) { } } -func TestStreamedListObjectsWithChannel_CustomBufferSize(t *testing.T) { +func TestStreamedListObjectsWithAPI_ManyObjects(t *testing.T) { numObjects := 100 expectedResults := []string{} for i := 0; i < numObjects; i++ { @@ -363,11 +373,12 @@ func TestStreamedListObjectsWithChannel_CustomBufferSize(t *testing.T) { User: "user:anne", } - // Use custom buffer size of 50 - channel, err := ExecuteStreamedListObjectsWithBufferSize(client, ctx, "test-store", request, RequestOptions{}, 50) + channel, err := client.OpenFgaApi.StreamedListObjects(ctx, "test-store"). + Body(request). + Execute() if err != nil { - t.Fatalf("ExecuteStreamedListObjectsWithBufferSize failed: %v", err) + t.Fatalf("StreamedListObjects failed: %v", err) } defer channel.Close() @@ -403,13 +414,10 @@ func TestProcessStreamingResponse_Generic(t *testing.T) { t.Fatalf("Failed to make request: %v", err) } - ctx := context.Background() - channel, err := ProcessStreamingResponse[StreamedListObjectsResponse](ctx, resp, 10) - + channel, err := ProcessStreamingResponse[StreamedListObjectsResponse](context.Background(), resp, 10) if err != nil { t.Fatalf("ProcessStreamingResponse failed: %v", err) } - defer channel.Close() receivedObjects := []string{} @@ -417,18 +425,31 @@ func TestProcessStreamingResponse_Generic(t *testing.T) { receivedObjects = append(receivedObjects, obj.Object) } - if err := <-channel.Errors; err != nil { - t.Fatalf("Received error from channel: %v", err) - } - if len(receivedObjects) != 2 { t.Fatalf("Expected 2 objects, got %d", len(receivedObjects)) } +} - if receivedObjects[0] != "document:1" { - t.Errorf("Expected document:1, got %s", receivedObjects[0]) +func TestProcessStreamingResponse_NilResponse(t *testing.T) { + channel, err := ProcessStreamingResponse[StreamedListObjectsResponse](context.Background(), nil, 10) + if err == nil { + t.Fatal("Expected error for nil response, got nil") + } + if channel != nil { + t.Fatal("Expected nil channel for nil response") + } +} + +func TestProcessStreamingResponse_NilBody(t *testing.T) { + resp := &http.Response{ + StatusCode: 200, + Body: nil, + } + channel, err := ProcessStreamingResponse[StreamedListObjectsResponse](context.Background(), resp, 10) + if err == nil { + t.Fatal("Expected error for nil body, got nil") } - if receivedObjects[1] != "document:2" { - t.Errorf("Expected document:2, got %s", receivedObjects[1]) + if channel != nil { + t.Fatal("Expected nil channel for nil body") } } diff --git a/utils.go b/utils.go index 06bb218..0d4b6cc 100644 --- a/utils.go +++ b/utils.go @@ -348,7 +348,7 @@ func validatePathParameter(name string, value string) error { return nil } -func validateParameter(name string, value interface{}) error { +func validateParameter[T any](name string, value *T) error { if value == nil { return reportError("%s is required and must be specified", name) } diff --git a/utils_test.go b/utils_test.go index 7240910..7a47cdc 100644 --- a/utils_test.go +++ b/utils_test.go @@ -296,76 +296,24 @@ func TestValidateParameter(t *testing.T) { testCases := []struct { name string - paramName string - paramValue interface{} + call func() error expectError bool errorMsg string }{ { - name: "valid_string_value", - paramName: "body", - paramValue: "test", - expectError: false, - }, - { - name: "valid_int_value", - paramName: "pageSize", - paramValue: 10, - expectError: false, - }, - { - name: "valid_struct_value", - paramName: "request", - paramValue: struct{ Name string }{Name: "test"}, - expectError: false, - }, - { - name: "valid_pointer_value", - paramName: "options", - paramValue: ToPtr("value"), - expectError: false, - }, - { - name: "valid_slice_value", - paramName: "items", - paramValue: []string{"a", "b", "c"}, - expectError: false, - }, - { - name: "valid_map_value", - paramName: "metadata", - paramValue: map[string]string{"key": "value"}, - expectError: false, - }, - { - name: "valid_empty_string_is_not_nil", - paramName: "body", - paramValue: "", - expectError: false, - }, - { - name: "valid_zero_int_is_not_nil", - paramName: "count", - paramValue: 0, - expectError: false, - }, - { - name: "valid_false_bool_is_not_nil", - paramName: "enabled", - paramValue: false, + name: "non_nil_pointer_is_valid", + call: func() error { return validateParameter("body", &CheckRequest{}) }, expectError: false, }, { - name: "nil_value_returns_error", - paramName: "body", - paramValue: nil, + name: "nil_pointer_returns_error", + call: func() error { return validateParameter("body", (*CheckRequest)(nil)) }, expectError: true, errorMsg: "body is required and must be specified", }, { - name: "nil_value_different_parameter_name", - paramName: "request", - paramValue: nil, + name: "nil_pointer_includes_parameter_name_in_error", + call: func() error { return validateParameter("request", (*WriteRequest)(nil)) }, expectError: true, errorMsg: "request is required and must be specified", }, @@ -375,7 +323,7 @@ func TestValidateParameter(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - err := validateParameter(tc.paramName, tc.paramValue) + err := tc.call() if tc.expectError { require.Error(t, err)