diff --git a/deta/__init__.py b/deta/__init__.py index 9ca878e..8a25320 100644 --- a/deta/__init__.py +++ b/deta/__init__.py @@ -2,6 +2,7 @@ import urllib.error import urllib.request import json +from typing import Union from .base import _Base from .drive import _Drive @@ -9,20 +10,19 @@ try: - from detalib.app import App + from detalib.app import App # pyright: ignore app = App() except Exception: pass try: - from ._async.client import AsyncBase + from ._async.client import AsyncBase # pyright: ignore except ImportError: pass __version__ = "1.2.0" - def Base(name: str): project_key, project_id = _get_project_key_id() return _Base(name, project_key, project_id) @@ -34,20 +34,20 @@ def Drive(name: str): class Deta: - def __init__(self, project_key: str = None, *, project_id: str = None): + def __init__(self, project_key: Union[str, None] = None, *, project_id: Union[str, None] = None): project_key, project_id = _get_project_key_id(project_key, project_id) self.project_key = project_key self.project_id = project_id - def Base(self, name: str, host: str = None): + def Base(self, name: str, host: Union[str, None] = None): return _Base(name, self.project_key, self.project_id, host) - def AsyncBase(self, name: str, host: str = None): + def AsyncBase(self, name: str, host: Union[str, None] = None): from ._async.client import _AsyncBase return _AsyncBase(name, self.project_key, self.project_id, host) - def Drive(self, name: str, host: str = None): + def Drive(self, name: str, host: Union[str, None] = None): return _Drive( name=name, project_key=self.project_key, @@ -73,9 +73,12 @@ def send_email(to, subject, message, charset="UTF-8"): "charset": charset, } + assert api_key + headers = {"X-API-Key": api_key} - req = urllib.request.Request(endpoint, json.dumps(data).encode("utf-8"), headers) + req = urllib.request.Request( + endpoint, json.dumps(data).encode("utf-8"), headers) try: resp = urllib.request.urlopen(req) diff --git a/deta/_async/client.py b/deta/_async/client.py index af2c8f9..cfeb466 100644 --- a/deta/_async/client.py +++ b/deta/_async/client.py @@ -1,10 +1,10 @@ -import typing - +from typing import Union, List import datetime import os -import aiohttp from urllib.parse import quote +import aiohttp + from deta.utils import _get_project_key_id from deta.base import FetchResponse, Util, insert_ttl, BASE_TTL_ATTTRIBUTE @@ -15,7 +15,7 @@ def AsyncBase(name: str): class _AsyncBase: - def __init__(self, name: str, project_key: str, project_id: str, host: str = None): + def __init__(self, name: str, project_key: str, project_id: str, host: Union[str, None] = None): if not project_key: raise AssertionError("No Base name provided") @@ -56,11 +56,11 @@ async def delete(self, key: str): async def insert( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Union[str, None] = None, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_in: Union[int, None] = None, + expire_at: Union[int, float, datetime.datetime, None] = None, ): if not isinstance(data, dict): data = {"value": data} @@ -70,7 +70,8 @@ async def insert( if key: data["key"] = key - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self.__ttl_attribute, + expire_in=expire_in, expire_at=expire_at) async with self._session.post( f"{self._base_url}/items", json={"item": data} ) as resp: @@ -78,11 +79,11 @@ async def insert( async def put( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Union[str, None] = None, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_in: Union[int, None] = None, + expire_at: Union[int, float, datetime.datetime, None] = None, ): if not isinstance(data, dict): data = {"value": data} @@ -92,8 +93,12 @@ async def put( if key: data["key"] = key - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) - async with self._session.put(f"{self._base_url}/items", json={"items": [data]}) as resp: + + insert_ttl(data, self.__ttl_attribute, + expire_in=expire_in, expire_at=expire_at) + async with self._session.put( + f"{self._base_url}/items", json={"items": [data]} + ) as resp: if resp.status == 207: resp_json = await resp.json() if "processed" in resp_json: @@ -102,10 +107,10 @@ async def put( async def put_many( self, - items: typing.List[typing.Union[dict, list, str, int, bool]], + items: List[Union[dict, list, str, int, bool]], *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_in: Union[int, None] = None, + expire_at: Union[int, float, datetime.datetime, None] = None, ): if len(items) > 25: raise AssertionError("We can't put more than 25 items at a time.") @@ -126,10 +131,10 @@ async def put_many( async def fetch( self, - query: typing.Union[dict, list] = None, + query: Union[dict, list, None] = None, *, limit: int = 1000, - last: str = None, + last: Union[str, None] = None, desc: bool = False, ): payload = {} @@ -154,8 +159,8 @@ async def update( updates: dict, key: str, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_in: Union[int, None] = None, + expire_at: Union[int, float, datetime.datetime, None] = None, ): if key == "": raise ValueError("Key is empty") diff --git a/deta/base.py b/deta/base.py index 6ae2813..7028d4e 100644 --- a/deta/base.py +++ b/deta/base.py @@ -1,7 +1,6 @@ import os import datetime -from re import I -import typing +from typing import Union, List from urllib.parse import quote from .service import _Service, JSON_MIME @@ -62,18 +61,18 @@ def __init__(self, value): def trim(self): return self.Trim() - def increment(self, value: typing.Union[int, float] = None): + def increment(self, value: Union[int, float, None] = None): return self.Increment(value) - def append(self, value: typing.Union[dict, list, str, int, float, bool]): + def append(self, value: Union[dict, list, str, int, float, bool]): return self.Append(value) - def prepend(self, value: typing.Union[dict, list, str, int, float, bool]): + def prepend(self, value: Union[dict, list, str, int, float, bool]): return self.Prepend(value) class _Base(_Service): - def __init__(self, name: str, project_key: str, project_id: str, host: str = None): + def __init__(self, name: str, project_key: str, project_id: str, host: Union[str, None] = None): assert name, "No Base name provided" host = host or os.getenv("DETA_BASE_HOST") or "database.deta.sh" @@ -110,11 +109,11 @@ def delete(self, key: str): def insert( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Union[str, None] = None, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_in: Union[int, None] = None, + expire_at: Union[int, float, datetime.datetime, None] = None, ): if not isinstance(data, dict): data = {"value": data} @@ -124,7 +123,8 @@ def insert( if key: data["key"] = key - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self.__ttl_attribute, + expire_in=expire_in, expire_at=expire_at) code, res = self._request( "/items", "POST", {"item": data}, content_type=JSON_MIME ) @@ -135,11 +135,11 @@ def insert( def put( self, - data: typing.Union[dict, list, str, int, bool], - key: str = None, + data: Union[dict, list, str, int, bool], + key: Union[str, None] = None, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_in: Union[int, None] = None, + expire_at: Union[int, float, datetime.datetime, None] = None, ): """store (put) an item in the database. Overrides an item if key already exists. `key` could be provided as function argument or a field in the data dict. @@ -154,11 +154,13 @@ def put( if key: data["key"] = key - insert_ttl(data, self.__ttl_attribute, expire_in=expire_in, expire_at=expire_at) + insert_ttl(data, self.__ttl_attribute, + expire_in=expire_in, expire_at=expire_at) code, res = self._request( "/items", "PUT", {"items": [data]}, content_type=JSON_MIME ) + if code == 207 and "processed" in res: return res["processed"]["items"][0] else: @@ -166,10 +168,10 @@ def put( def put_many( self, - items: typing.List[typing.Union[dict, list, str, int, bool]], + items: List[Union[dict, list, str, int, bool]], *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_in: Union[int, None] = None, + expire_at: Union[int, float, datetime.datetime, None] = None, ): assert len(items) <= 25, "We can't put more than 25 items at a time." _items = [] @@ -189,9 +191,9 @@ def put_many( def _fetch( self, - query: typing.Union[dict, list] = None, - buffer: int = None, - last: str = None, + query: Union[dict, list, None] = None, + buffer: Union[int, None] = None, + last: Union[str, None] = None, desc: bool = False, ) -> typing.Optional[typing.Tuple[int, list]]: """This is where actual fetch happens.""" @@ -204,34 +206,43 @@ def _fetch( if query: payload["query"] = query if isinstance(query, list) else [query] - code, res = self._request("/query", "POST", payload, content_type=JSON_MIME) - return code, res + _, res = self._request( + "/query", "POST", payload, content_type=JSON_MIME) + + assert res + + return res def fetch( self, - query: typing.Union[dict, list] = None, + query: Union[dict, list, None] = None, *, limit: int = 1000, - last: str = None, + last: Union[str, None] = None, desc: bool = False, + ): """ fetch items from the database. `query` is an optional filter or list of filters. Without filter, it will return the whole db. """ + _, res = self._fetch(query, limit, last, desc) - paging = res.get("paging") - return FetchResponse(paging.get("size"), paging.get("last"), res.get("items")) + paging = res.get("paging") # pyright: ignore + + return FetchResponse(paging.get("size"), + paging.get("last"), + res.get("items")) # pyright: ignore def update( self, updates: dict, key: str, *, - expire_in: int = None, - expire_at: typing.Union[int, float, datetime.datetime] = None, + expire_in: Union[int, None] = None, + expire_at: Union[int, float, datetime.datetime, None] = None, ): """ update an item in the database diff --git a/deta/drive.py b/deta/drive.py index 2d4db38..f50b56b 100644 --- a/deta/drive.py +++ b/deta/drive.py @@ -1,5 +1,5 @@ import os -import typing +from typing import Union, List from io import BufferedIOBase, TextIOBase, RawIOBase, StringIO, BytesIO from urllib.parse import quote_plus @@ -20,7 +20,7 @@ def __init__(self, res: BufferedIOBase): def closed(self): return self.__stream.closed - def read(self, size: int = None): + def read(self, size: Union[int, None] = None): return self.__stream.read(size) def iter_chunks(self, chunk_size: int = 1024): @@ -29,7 +29,7 @@ def iter_chunks(self, chunk_size: int = 1024): if not chunk: break yield chunk - + def iter_lines(self, chunk_size: int = 1024): while True: chunk = self.__stream.readline(chunk_size) @@ -48,14 +48,17 @@ def close(self): class _Drive(_Service): def __init__( self, - name: str = None, - project_key: str = None, - project_id: str = None, - host: str = None, + name: Union[str, None] = None, + project_key: Union[str, None] = None, + project_id: Union[str, None] = None, + host: Union[str, None] = None, ): assert name, "No Drive name provided" host = host or os.getenv("DETA_DRIVE_HOST") or "drive.deta.sh" + assert project_key, "Project key must be provided" + assert project_id, "Project id must be provided" + super().__init__( project_key=project_key, project_id=project_id, @@ -78,10 +81,10 @@ def get(self, name: str): f"/files/download?name={self._quote(name)}", "GET", stream=True ) if res: - return DriveStreamingBody(res) + return DriveStreamingBody(res) # pyright: ignore return None - def delete_many(self, names: typing.List[str]): + def delete_many(self, names: List[str]): """Delete many files from drive in single request. `names` are the names of the files to be deleted. Returns a dict with 'deleted' and 'failed' files. @@ -99,13 +102,18 @@ def delete(self, name: str): Returns the name of the file deleted. """ assert name, "Name not provided or empty" + payload = self.delete_many([name]) - failed = payload.get("failed") + + failed = payload.get("failed") # pyright: ignore + if failed: raise Exception(f"Failed to delete '{name}':{failed[name]}") + return name - def list(self, limit: int = 1000, prefix: str = None, last: str = None): + def list(self, limit: int = 1000, prefix: Union[str, None] = None, + last: Union[str, None] = None): """List file names from drive. `limit` is the limit of number of file names to get, defaults to 1000. `prefix` is the prefix of file names. @@ -122,21 +130,22 @@ def list(self, limit: int = 1000, prefix: str = None, last: str = None): def _start_upload(self, name: str): _, res = self._request(f"/uploads?name={self._quote(name)}", "POST") - return res["upload_id"] + return res["upload_id"] # pyright: ignore def _finish_upload(self, name: str, upload_id: str): self._request(f"/uploads/{upload_id}?name={self._quote(name)}", "PATCH") def _abort_upload(self, name: str, upload_id: str): - self._request(f"/uploads/{upload_id}?name={self._quote(name)}", "DELETE") + self._request( + f"/uploads/{upload_id}?name={self._quote(name)}", "DELETE") def _upload_part( self, name: str, - chunk: bytes, + chunk: Union[bytes, str], upload_id: str, part: int, - content_type: str = None, + content_type: Union[str, None] = None, ): self._request( f"/uploads/{upload_id}/parts?name={self._quote(name)}&part={part}", @@ -146,7 +155,7 @@ def _upload_part( ) def _get_content_stream( - self, data: typing.Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase] + self, data: Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase] ): if isinstance(data, str): return StringIO(data) @@ -157,10 +166,11 @@ def _get_content_stream( def put( self, name: str, - data: typing.Union[str, bytes, TextIOBase, BufferedIOBase, RawIOBase] = None, + data: Union[str, bytes, TextIOBase, + BufferedIOBase, RawIOBase, None] = None, *, - path: str = None, - content_type: str = None, + path: Union[str, None] = None, + content_type: Union[str, None] = None, ) -> str: """Put a file in drive. `name` is the name of the file. @@ -175,13 +185,18 @@ def put( # start upload upload_id = self._start_upload(name) - content_stream = open(path, "rb") if path else self._get_content_stream(data) + if path: + content_stream = open(path, "rb") + else: + assert data + content_stream = self._get_content_stream(data) + part = 1 # upload chunks while True: chunk = content_stream.read(UPLOAD_CHUNK_SIZE) - ## eof stop the loop + # eof stop the loop if not chunk: self._finish_upload(name, upload_id) content_stream.close() diff --git a/deta/service.py b/deta/service.py index b90441b..89015f7 100644 --- a/deta/service.py +++ b/deta/service.py @@ -3,7 +3,7 @@ import json import socket import struct -import typing +from typing import Union import urllib.error from pathlib import Path @@ -33,16 +33,17 @@ def __init__( self.host = host self.timeout = timeout self.keep_alive = keep_alive - self.client = ( - http.client.HTTPSConnection(host, timeout=timeout) if keep_alive else None - ) + self.client = (http.client.HTTPSConnection( + host, timeout=timeout) if keep_alive else None) def _is_socket_closed(self): - if not self.client.sock: + if not self.client or not self.client.sock: return True + fmt = "B" * 7 + "I" * 21 tcp_info = struct.unpack( - fmt, self.client.sock.getsockopt(socket.IPPROTO_TCP, socket.TCP_INFO, 92) + fmt, self.client.sock.getsockopt( + socket.IPPROTO_TCP, socket.TCP_INFO, 92) ) # 8 = CLOSE_WAIT if len(tcp_info) > 0 and tcp_info[0] == 8: @@ -53,16 +54,20 @@ def _request( self, path: str, method: str, - data: typing.Union[str, bytes, dict] = None, - headers: dict = None, - content_type: str = None, + data: Union[str, bytes, dict, None] = None, + headers: Union[dict, None] = None, + content_type: Union[str, None] = None, stream: bool = False, ): + url = self.base_path + path + headers = headers or {} headers["X-Api-Key"] = self.project_key + if content_type: headers["Content-Type"] = content_type + if not self.keep_alive: headers["Connection"] = "close" @@ -85,40 +90,45 @@ def _request( # response res = self._send_request_with_retry(method, url, headers, body) + + assert res + status = res.status if status not in [200, 201, 202, 207]: # need to read the response so subsequent requests can be sent on the client res.read() - if not self.keep_alive: + if not self.keep_alive and self.client: self.client.close() - ## return None if not found + # return None if not found if status == 404: return status, None fp = res.fp if res.fp is not None else '' # FIXME: workaround to fix traceback printing for HTTPError raise urllib.error.HTTPError(url, status, res.reason, res.headers, fp) - ## if stream return the response and client without reading and closing the client + + # if stream return the response and client without reading and closing the client if stream: return status, res - ## return json if application/json - payload = ( - json.loads(res.read()) - if JSON_MIME in res.getheader("content-type") - else res.read() - ) + # return json if application/json + res_content_type = res.getheader("content-type") + if res_content_type and JSON_MIME in res_content_type: + payload = json.loads(res.read()) + else: + payload = res.read() - if not self.keep_alive: + if not self.keep_alive and self.client: self.client.close() + return status, payload def _send_request_with_retry( self, method: str, url: str, - headers: dict = None, - body: typing.Union[str, bytes, dict] = None, + headers: Union[dict, None] = None, + body: Union[str, bytes, dict, None] = None, retry=2, # try at least twice to regain a new connection ): reinitializeConnection = False @@ -129,6 +139,11 @@ def _send_request_with_retry( host=self.host, timeout=self.timeout ) + if headers is None: + headers = {} + + assert self.client + self.client.request( method, url, @@ -137,6 +152,7 @@ def _send_request_with_retry( ) res = self.client.getresponse() return res + except http.client.RemoteDisconnected: reinitializeConnection = True retry -= 1 diff --git a/deta/utils.py b/deta/utils.py index 5bd127d..97a17cb 100644 --- a/deta/utils.py +++ b/deta/utils.py @@ -1,7 +1,9 @@ import os +from typing import Union -def _get_project_key_id(project_key: str = None, project_id: str = None): +def _get_project_key_id(project_key: Union[str, None] = None, + project_id: Union[str, None] = None): project_key = project_key or os.getenv("DETA_PROJECT_KEY", "") if not project_key: