Skip to content

Commit 1b9ed94

Browse files
author
Ronen Hilewicz
committed
Expose InvalidArgmentError
1 parent 1b5fa08 commit 1b9ed94

File tree

10 files changed

+526
-416
lines changed

10 files changed

+526
-416
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@ dist
22
__pycache__
33
.DS_Store
44
.env
5+
.envrc
56
.mypy_cache
67
.pytest_cache
78
.vscode
89
.coverage
910
.python-version
1011
.ext
12+
.dmypy.json

poetry.lock

Lines changed: 274 additions & 246 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,16 @@ packages = [
3131
[tool.poetry.dependencies]
3232
python = "^3.9"
3333
aiohttp = "^3.10.2"
34-
grpcio = "^1.64.1"
35-
protobuf = "^5.27.2"
36-
aserto-authorizer = "^0.20.3"
37-
aserto-directory = "^0.33.5"
38-
certifi = ">=2024.8.30"
34+
aserto-directory = "^0.33.6"
35+
aserto-authorizer = "^0.20.5"
3936

4037
[tool.poetry.group.dev.dependencies]
41-
black = "^24.0"
42-
isort= "^5.9.0"
38+
black = "^25.1.0"
39+
isort = "^6.0.1"
4340
pytest-asyncio = "^0.23"
4441
pyright = "^1.1.0"
4542
requests = "^2.31.0"
43+
grpc-stubs = ">=1.53.0.5"
4644

4745
[tool.black]
4846
line-length = 100

src/aserto/client/directory/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Literal, Tuple
22

3+
from grpc import RpcError, StatusCode
4+
35
from aserto.client.directory.channels import Channels
46

57
__all__ = ["Channels"]
@@ -9,6 +11,10 @@ class NotFoundError(Exception):
911
pass
1012

1113

14+
class InvalidArgumentError(Exception):
15+
pass
16+
17+
1218
class ConfigError(Exception):
1319
pass
1420

@@ -20,3 +26,10 @@ def get_metadata(api_key, tenant_id) -> Tuple[Tuple[str, str], ...]:
2026
if tenant_id:
2127
md += (("aserto-tenant-id", tenant_id),)
2228
return md
29+
30+
31+
def translate_rpc_error(err: RpcError) -> None:
32+
if err.code() == StatusCode.NOT_FOUND:
33+
raise NotFoundError from err
34+
if err.code() == StatusCode.INVALID_ARGUMENT:
35+
raise InvalidArgumentError from err

src/aserto/client/directory/aio/__init__.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,41 @@
66
def build_grpc_channel(address: str, ca_cert_path: str) -> Optional[grpc_aio.Channel]:
77
if address == "":
88
return None
9-
9+
1010
return grpc_aio.secure_channel(
11-
target=address,
11+
target=address,
1212
credentials=channel_credentials(cert=ca_cert_path),
1313
)
1414

15+
1516
class Channels:
1617
def __init__(
17-
self,
18-
ca_cert_path: str,
19-
default_address: str = "",
20-
reader_address: str = "",
21-
writer_address: str = "",
22-
importer_address: str = "",
23-
exporter_address: str = "",
24-
model_address: str = "",
25-
) -> None:
26-
validate_addresses(address=default_address, reader_address=reader_address, writer_address=writer_address,
27-
importer_address=importer_address, exporter_address=exporter_address, model_address=model_address)
28-
29-
self._addresses = [default_address, reader_address, writer_address, importer_address, exporter_address, model_address]
18+
self,
19+
ca_cert_path: str,
20+
default_address: str = "",
21+
reader_address: str = "",
22+
writer_address: str = "",
23+
importer_address: str = "",
24+
exporter_address: str = "",
25+
model_address: str = "",
26+
) -> None:
27+
validate_addresses(
28+
address=default_address,
29+
reader_address=reader_address,
30+
writer_address=writer_address,
31+
importer_address=importer_address,
32+
exporter_address=exporter_address,
33+
model_address=model_address,
34+
)
35+
36+
self._addresses = [
37+
default_address,
38+
reader_address,
39+
writer_address,
40+
importer_address,
41+
exporter_address,
42+
model_address,
43+
]
3044
self._channels = dict()
3145
for x in self._addresses:
3246
if x and x not in self._channels:
@@ -37,13 +51,13 @@ def get(self, address: str, default_address: str) -> Optional[grpc_aio.Channel]:
3751
return self._channels[address]
3852
if default_address != "":
3953
return self._channels[default_address]
40-
41-
return None
4254

55+
return None
4356

4457
async def close(self) -> None:
4558
for x in self._addresses:
4659
if x != "" and self._channels[x] is not None:
47-
await self._channels[x].close()
60+
await self._channels[x].close()
61+
4862

49-
__all__ = ["Channels"]
63+
__all__ = ["Channels"]

src/aserto/client/directory/channels.py

Lines changed: 44 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,41 +8,65 @@ def validate_addresses(
88
writer_address: str,
99
importer_address: str,
1010
exporter_address: str,
11-
model_address: str) -> None:
12-
if address == "" and reader_address == "" and writer_address == "" and importer_address == "" and exporter_address == "" and model_address == "":
11+
model_address: str,
12+
) -> None:
13+
if (
14+
address == ""
15+
and reader_address == ""
16+
and writer_address == ""
17+
and importer_address == ""
18+
and exporter_address == ""
19+
and model_address == ""
20+
):
1321
raise ValueError("at least one directory service address must be specified")
1422

23+
1524
def channel_credentials(cert) -> ChannelCredentials:
1625
if cert:
1726
with open(cert, "rb") as f:
1827
return ssl_channel_credentials(f.read())
1928
else:
2029
return ssl_channel_credentials()
21-
30+
31+
2232
def build_grpc_channel(address: str, ca_cert_path: str) -> Optional[Channel]:
2333
if address == "":
2434
return None
25-
35+
2636
return secure_channel(
27-
target=address,
37+
target=address,
2838
credentials=channel_credentials(cert=ca_cert_path),
2939
)
3040

41+
3142
class Channels:
3243
def __init__(
33-
self,
34-
ca_cert_path: str,
35-
default_address: str = "",
36-
reader_address: str = "",
37-
writer_address: str = "",
38-
importer_address: str = "",
39-
exporter_address: str = "",
40-
model_address: str = "",
41-
) -> None:
42-
validate_addresses(address=default_address, reader_address=reader_address, writer_address=writer_address,
43-
importer_address=importer_address, exporter_address=exporter_address, model_address=model_address)
44-
45-
self._addresses = [default_address, reader_address, writer_address, importer_address, exporter_address, model_address]
44+
self,
45+
ca_cert_path: str,
46+
default_address: str = "",
47+
reader_address: str = "",
48+
writer_address: str = "",
49+
importer_address: str = "",
50+
exporter_address: str = "",
51+
model_address: str = "",
52+
) -> None:
53+
validate_addresses(
54+
address=default_address,
55+
reader_address=reader_address,
56+
writer_address=writer_address,
57+
importer_address=importer_address,
58+
exporter_address=exporter_address,
59+
model_address=model_address,
60+
)
61+
62+
self._addresses = [
63+
default_address,
64+
reader_address,
65+
writer_address,
66+
importer_address,
67+
exporter_address,
68+
model_address,
69+
]
4670
self._channels = dict()
4771
for x in self._addresses:
4872
if x and x not in self._channels:
@@ -53,11 +77,10 @@ def get(self, address: str, default_address: str) -> Optional[Channel]:
5377
return self._channels[address]
5478
if default_address != "":
5579
return self._channels[default_address]
56-
57-
return None
5880

81+
return None
5982

6083
def close(self) -> None:
6184
for x in self._addresses:
6285
if x != "" and self._channels[x] is not None:
63-
self._channels[x].close()
86+
self._channels[x].close()

0 commit comments

Comments
 (0)