Skip to content

Commit 5a01472

Browse files
author
Evan Sims
committed
feat: further type hinting improvements
1 parent b48718e commit 5a01472

5 files changed

Lines changed: 134 additions & 68 deletions

File tree

openfga_sdk/api/open_fga_api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from openfga_sdk.models.write_authorization_model_request import (
4040
WriteAuthorizationModelRequest,
4141
)
42+
from openfga_sdk.models.write_request import WriteRequest
4243
from openfga_sdk.protocols import (
4344
ApiClientResponseProtocol,
4445
ApiResponseProtocol,
@@ -1371,14 +1372,14 @@ async def streamed_list_objects_with_http_info(self, body, **kwargs):
13711372

13721373
async def write(
13731374
self,
1375+
body: WriteRequest,
13741376
options: WriteRequestOptions | None = None,
13751377
) -> ApiResponseProtocol:
13761378
options: WriteRequestOptions = WriteRequestOptions() | (
13771379
options or WriteRequestOptions()
13781380
)
13791381
response_types = self.build_response_types(tuple([200, "object"]))
13801382

1381-
body = None
13821383
query = RestClientRequestQueryParameters.from_options(options)
13831384
fields = RestClientRequestFieldParameters()
13841385
streaming = False

openfga_sdk/client/client.py

Lines changed: 56 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,8 @@ async def list_stores(
127127
self,
128128
options: ListStoresRequestOptions | None = None,
129129
) -> ListStoresResponse | ApiClientResponseProtocol:
130-
options: ListStoresRequestOptions = (
131-
ListStoresRequestOptions() | (options or ListStoresRequestOptions())
130+
options: ListStoresRequestOptions = ListStoresRequestOptions() | (
131+
options or ListStoresRequestOptions()
132132
)
133133

134134
options.headers.add_header(
@@ -149,8 +149,8 @@ async def create_store(
149149
body: CreateStoreRequest,
150150
options: CreateStoreRequestOptions | None = None,
151151
) -> CreateStoreResponse | ApiClientResponseProtocol:
152-
options: CreateStoreRequestOptions = (
153-
CreateStoreRequestOptions() | (options or CreateStoreRequestOptions())
152+
options: CreateStoreRequestOptions = CreateStoreRequestOptions() | (
153+
options or CreateStoreRequestOptions()
154154
)
155155

156156
options.headers.add_header(
@@ -171,8 +171,8 @@ async def get_store(
171171
store_id: str | None = None,
172172
options: GetStoreRequestOptions | None = None,
173173
) -> GetStoreResponse | ApiClientResponseProtocol:
174-
options: GetStoreRequestOptions = (
175-
GetStoreRequestOptions() | (options or GetStoreRequestOptions())
174+
options: GetStoreRequestOptions = GetStoreRequestOptions() | (
175+
options or GetStoreRequestOptions()
176176
)
177177

178178
options.store_id = store_id or options.store_id or self.configuration.store_id
@@ -195,8 +195,8 @@ async def delete_store(
195195
store_id: str | None = None,
196196
options: DeleteStoreRequestOptions | None = None,
197197
) -> None | ApiClientResponseProtocol:
198-
options: DeleteStoreRequestOptions = (
199-
DeleteStoreRequestOptions() | (options or DeleteStoreRequestOptions())
198+
options: DeleteStoreRequestOptions = DeleteStoreRequestOptions() | (
199+
options or DeleteStoreRequestOptions()
200200
)
201201

202202
options.store_id = store_id or options.store_id or self.configuration.store_id
@@ -224,7 +224,8 @@ async def read_authorization_models(
224224
options: ReadAuthorizationModelsRequestOptions | None = None,
225225
) -> None | ApiClientResponseProtocol:
226226
options: ReadAuthorizationModelsRequestOptions = (
227-
ReadAuthorizationModelsRequestOptions() | (options or ReadAuthorizationModelsRequestOptions())
227+
ReadAuthorizationModelsRequestOptions()
228+
| (options or ReadAuthorizationModelsRequestOptions())
228229
)
229230

230231
options.store_id = store_id or options.store_id or self.configuration.store_id
@@ -248,7 +249,8 @@ async def write_authorization_model(
248249
options: WriteAuthorizationModelRequestOptions | None = None,
249250
) -> WriteAuthorizationModelResponse | ApiClientResponseProtocol:
250251
options: WriteAuthorizationModelRequestOptions = (
251-
WriteAuthorizationModelRequestOptions() | (options or WriteAuthorizationModelRequestOptions())
252+
WriteAuthorizationModelRequestOptions()
253+
| (options or WriteAuthorizationModelRequestOptions())
252254
)
253255

254256
options.headers.add_header(
@@ -270,7 +272,8 @@ async def read_authorization_model(
270272
options: ReadAuthorizationModelRequestOptions | None = None,
271273
) -> ReadAuthorizationModelResponse | ApiClientResponseProtocol:
272274
options: ReadAuthorizationModelRequestOptions = (
273-
ReadAuthorizationModelRequestOptions() | (options or ReadAuthorizationModelRequestOptions())
275+
ReadAuthorizationModelRequestOptions()
276+
| (options or ReadAuthorizationModelRequestOptions())
274277
)
275278

276279
options.authorization_model_id = (
@@ -303,7 +306,8 @@ async def read_latest_authorization_model(
303306
options: ReadLatestAuthorizationModelRequestOptions | None = None,
304307
) -> ReadAuthorizationModelResponse | ApiClientResponseProtocol:
305308
options: ReadAuthorizationModelsRequestOptions = (
306-
ReadAuthorizationModelsRequestOptions() | (options or ReadLatestAuthorizationModelRequestOptions())
309+
ReadAuthorizationModelsRequestOptions()
310+
| (options or ReadLatestAuthorizationModelRequestOptions())
307311
)
308312
options.page_size = 1
309313

@@ -334,8 +338,8 @@ async def read_changes(
334338
body: ClientReadChangesRequest,
335339
options: ReadChangesRequestOptions | None = None,
336340
):
337-
options: ReadChangesRequestOptions = (
338-
ReadChangesRequestOptions() | (options or ReadChangesRequestOptions())
341+
options: ReadChangesRequestOptions = ReadChangesRequestOptions() | (
342+
options or ReadChangesRequestOptions()
339343
)
340344

341345
options.headers.add_header(
@@ -356,8 +360,8 @@ async def read(
356360
tuple_key: ReadRequestTupleKey | None = None,
357361
options: ReadRequestOptions | None = None,
358362
):
359-
options: ReadRequestOptions = (
360-
ReadRequestOptions() | (options or ReadRequestOptions())
363+
options: ReadRequestOptions = ReadRequestOptions() | (
364+
options or ReadRequestOptions()
361365
)
362366

363367
options.headers.add_header(
@@ -385,8 +389,8 @@ async def _write_with_transaction(
385389
body: ClientWriteRequest,
386390
options: WriteRequestOptions | None = None,
387391
):
388-
options: WriteRequestOptions = (
389-
WriteRequestOptions() | (options or WriteRequestOptions())
392+
options: WriteRequestOptions = WriteRequestOptions() | (
393+
options or WriteRequestOptions()
390394
)
391395

392396
authorization_model_id = (
@@ -422,8 +426,8 @@ async def _write_single_batch(
422426
is_write: bool,
423427
options: WriteRequestOptions | None = None,
424428
):
425-
options: WriteRequestOptions = (
426-
WriteRequestOptions() | (options or WriteRequestOptions())
429+
options: WriteRequestOptions = WriteRequestOptions() | (
430+
options or WriteRequestOptions()
427431
)
428432

429433
try:
@@ -453,8 +457,8 @@ async def _write_batches(
453457
is_write: bool,
454458
options: WriteRequestOptions | None = None,
455459
) -> list[ClientWriteSingleResponse] | None:
456-
options: WriteRequestOptions = (
457-
WriteRequestOptions() | (options or WriteRequestOptions())
460+
options: WriteRequestOptions = WriteRequestOptions() | (
461+
options or WriteRequestOptions()
458462
)
459463

460464
if tuple_keys is None:
@@ -498,8 +502,8 @@ async def write(
498502
body: ClientWriteRequest,
499503
options: WriteRequestOptions | None = None,
500504
):
501-
options: WriteRequestOptions = (
502-
WriteRequestOptions() | (options or WriteRequestOptions())
505+
options: WriteRequestOptions = WriteRequestOptions() | (
506+
options or WriteRequestOptions()
503507
)
504508

505509
options.headers.use_bulk_request_id()
@@ -523,8 +527,8 @@ async def write_tuples(
523527
body: list[ClientTuple],
524528
options: WriteRequestOptions | None = None,
525529
):
526-
options: WriteRequestOptions = (
527-
WriteRequestOptions() | (options or WriteRequestOptions())
530+
options: WriteRequestOptions = WriteRequestOptions() | (
531+
options or WriteRequestOptions()
528532
)
529533

530534
options.headers.add_header(
@@ -540,8 +544,8 @@ async def delete_tuples(
540544
body: list[ClientTuple],
541545
options: WriteRequestOptions | None = None,
542546
):
543-
options: WriteRequestOptions = (
544-
WriteRequestOptions() | (options or WriteRequestOptions())
547+
options: WriteRequestOptions = WriteRequestOptions() | (
548+
options or WriteRequestOptions()
545549
)
546550

547551
options.headers.add_header(
@@ -560,8 +564,8 @@ async def check(
560564
body: ClientCheckRequest,
561565
options: CheckRequestOptions | None = None,
562566
) -> CheckResponse | ApiClientResponseProtocol:
563-
options: CheckRequestOptions = (
564-
CheckRequestOptions() | (options or CheckRequestOptions())
567+
options: CheckRequestOptions = CheckRequestOptions() | (
568+
options or CheckRequestOptions()
565569
)
566570

567571
options.headers.add_header(
@@ -603,8 +607,8 @@ async def _single_client_batch_check(
603607
semaphore: asyncio.Semaphore,
604608
options: BatchCheckRequestOptions | None = None,
605609
) -> ClientBatchCheckClientResponse:
606-
options: BatchCheckRequestOptions = (
607-
BatchCheckRequestOptions() | (options or BatchCheckRequestOptions())
610+
options: BatchCheckRequestOptions = BatchCheckRequestOptions() | (
611+
options or BatchCheckRequestOptions()
608612
)
609613

610614
await semaphore.acquire()
@@ -635,8 +639,8 @@ async def client_batch_check(
635639
body: list[ClientCheckRequest],
636640
options: BatchCheckRequestOptions | None = None,
637641
) -> list[ClientBatchCheckClientResponse]:
638-
options: BatchCheckRequestOptions = (
639-
BatchCheckRequestOptions() | (options or BatchCheckRequestOptions())
642+
options: BatchCheckRequestOptions = BatchCheckRequestOptions() | (
643+
options or BatchCheckRequestOptions()
640644
)
641645

642646
options.headers.use_bulk_request_id()
@@ -662,8 +666,8 @@ async def _single_batch_check(
662666
semaphore: asyncio.Semaphore,
663667
options: BatchCheckRequestOptions | None = None,
664668
) -> BatchCheckResponse:
665-
options: BatchCheckRequestOptions = (
666-
BatchCheckRequestOptions() | (options or BatchCheckRequestOptions())
669+
options: BatchCheckRequestOptions = BatchCheckRequestOptions() | (
670+
options or BatchCheckRequestOptions()
667671
)
668672

669673
await semaphore.acquire()
@@ -682,8 +686,8 @@ async def batch_check(
682686
body: ClientBatchCheckRequest,
683687
options: BatchCheckRequestOptions | None = None,
684688
):
685-
options: BatchCheckRequestOptions = (
686-
BatchCheckRequestOptions() | (options or BatchCheckRequestOptions())
689+
options: BatchCheckRequestOptions = BatchCheckRequestOptions() | (
690+
options or BatchCheckRequestOptions()
687691
)
688692

689693
options.headers.use_bulk_request_id()
@@ -764,8 +768,8 @@ async def expand(
764768
body: ClientExpandRequest,
765769
options: ExpandRequestOptions | None = None,
766770
):
767-
options: ExpandRequestOptions = (
768-
ExpandRequestOptions() | (options or ExpandRequestOptions())
771+
options: ExpandRequestOptions = ExpandRequestOptions() | (
772+
options or ExpandRequestOptions()
769773
)
770774

771775
authorization_model_id = (
@@ -798,8 +802,8 @@ async def list_objects(
798802
body: ClientListObjectsRequest,
799803
options: ListObjectsRequestOptions | None = None,
800804
):
801-
options: ListObjectsRequestOptions = (
802-
ListObjectsRequestOptions() | (options or ListObjectsRequestOptions())
805+
options: ListObjectsRequestOptions = ListObjectsRequestOptions() | (
806+
options or ListObjectsRequestOptions()
803807
)
804808

805809
authorization_model_id = (
@@ -832,8 +836,8 @@ async def streamed_list_objects(
832836
body: ClientListObjectsRequest,
833837
options: ListObjectsRequestOptions | None = None,
834838
):
835-
options: ListObjectsRequestOptions = (
836-
ListObjectsRequestOptions() | (options or ListObjectsRequestOptions())
839+
options: ListObjectsRequestOptions = ListObjectsRequestOptions() | (
840+
options or ListObjectsRequestOptions()
837841
)
838842

839843
authorization_model_id = (
@@ -863,8 +867,8 @@ async def list_relations(
863867
body: ClientListRelationsRequest,
864868
options: ListRelationsRequestOptions | None = None,
865869
):
866-
options: ListRelationsRequestOptions = (
867-
ListRelationsRequestOptions() | (options or ListRelationsRequestOptions())
870+
options: ListRelationsRequestOptions = ListRelationsRequestOptions() | (
871+
options or ListRelationsRequestOptions()
868872
)
869873

870874
options.headers.use_bulk_request_id()
@@ -896,8 +900,8 @@ async def list_users(
896900
body: ClientListUsersRequest,
897901
options: ListUsersRequestOptions | None = None,
898902
):
899-
options: ListUsersRequestOptions = (
900-
ListUsersRequestOptions() | (options or ListUsersRequestOptions())
903+
options: ListUsersRequestOptions = ListUsersRequestOptions() | (
904+
options or ListUsersRequestOptions()
901905
)
902906

903907
authorization_model_id = (
@@ -931,8 +935,8 @@ async def read_assertions(
931935
self,
932936
options: ReadAssertionsRequestOptions | None = None,
933937
):
934-
options: ReadAssertionsRequestOptions = (
935-
ReadAssertionsRequestOptions() | (options or ReadAssertionsRequestOptions())
938+
options: ReadAssertionsRequestOptions = ReadAssertionsRequestOptions() | (
939+
options or ReadAssertionsRequestOptions()
936940
)
937941

938942
authorization_model_id = (
@@ -951,8 +955,8 @@ async def write_assertions(
951955
body: list[ClientAssertion],
952956
options: WriteAssertionsRequestOptions | None = None,
953957
):
954-
options: WriteAssertionsRequestOptions = (
955-
WriteAssertionsRequestOptions() | (options or WriteAssertionsRequestOptions())
958+
options: WriteAssertionsRequestOptions = WriteAssertionsRequestOptions() | (
959+
options or WriteAssertionsRequestOptions()
956960
)
957961

958962
authorization_model_id = (

openfga_sdk/common/options.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ class BatchCheckRequestOptions(StoreRequestOptionsBase):
9595
@dataclass
9696
class WriteRequestOptions(StoreRequestOptionsBase):
9797
transaction: WriteTransactionOptions = field(
98-
default_factory=WriteTransactionOptions()
98+
default_factory=WriteTransactionOptions
9999
)
100100

101101

test/endpoints/async/test_write_api.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from openfga_sdk.common.api_client import ApiClientResponse
2121
from openfga_sdk.common.open_fga_api import ApiResponse
2222
from openfga_sdk.common.options import WriteRequestOptions
23+
from openfga_sdk.models.write_request import WriteRequest
2324
from openfga_sdk.protocols import (
2425
ApiClientResponseProtocol,
2526
ApiResponseProtocol,
@@ -30,21 +31,13 @@
3031

3132

3233
@pytest.fixture
33-
def rest_client_response(
34-
mock_api_response_write,
35-
mock_api_continuation_token,
36-
) -> RestClientResponseProtocol:
34+
def rest_client_response() -> RestClientResponseProtocol:
3735
mock_response = MagicMock()
3836
mock_response.status = 200
3937

40-
api_response = {
41-
"tuples": mock_api_response_write,
42-
"continuation_token": mock_api_continuation_token,
43-
}
44-
4538
return RestClientResponse(
4639
response=mock_response,
47-
data=orjson.dumps(api_response),
40+
data="",
4841
status=200,
4942
reason="OK",
5043
)
@@ -60,10 +53,9 @@ def api_client_response(
6053
@pytest.fixture
6154
def expected_response(
6255
api_client_response,
63-
mock_api_response_write_deserialized,
6456
) -> ApiResponseProtocol:
6557
response: ApiResponseProtocol = ApiResponse() | api_client_response
66-
response.deserialized = mock_api_response_write_deserialized
58+
response.deserialized = None
6759
return response
6860

6961

@@ -98,7 +90,7 @@ async def test_write_issues_request(
9890
with patch.object(
9991
api.api_client, **api_client_request_conditions
10092
) as api_request:
101-
await api.write(options)
93+
await api.write(WriteRequest(), options)
10294

10395
api_request.assert_called_once()
10496

@@ -116,6 +108,6 @@ async def test_write_returns_expected_response(
116108
)
117109

118110
with patch.object(api.api_client, **api_client_request_conditions):
119-
response = await api.write(options)
111+
response = await api.write(WriteRequest(), options)
120112

121113
assert response == expected_response

0 commit comments

Comments
 (0)