diff --git a/.editorconfig b/.editorconfig
new file mode 100644
index 0000000..8cc1835
--- /dev/null
+++ b/.editorconfig
@@ -0,0 +1,21 @@
+# EditorConfig is awesome: https://EditorConfig.org
+
+# top-most EditorConfig file
+root = true
+
+# Unix-style newlines with a newline ending every file
+[*]
+end_of_line = lf
+insert_final_newline = true
+
+# 4 space indentation
+[*.py]
+indent_style = space
+indent_size = 4
+charset = utf-8
+trim_trailing_whitespace = true
+
+# 2 space indentation for YAML files
+[*.yml]
+indent_style = space
+indent_size = 2
\ No newline at end of file
diff --git a/.envrc b/.envrc
new file mode 100644
index 0000000..ee3427a
--- /dev/null
+++ b/.envrc
@@ -0,0 +1,7 @@
+# Adds the bin directory to the environment
+PATH_add bin
+
+# Source the necessary environment for all users
+set -o allexport
+source .env
+set +o allexport
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index 5ca2879..ff89b36 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,7 @@
+build/
+dist/
+docs/build
+
 .DS_store
 .eggs
 .tox
@@ -6,7 +10,13 @@
 *.pyc
 .pytest_cache
 .mypy_cache
-.env
 .cache
-*.todo
-mypy.ini
+coverage.xml
+.coverage
+
+.env
+.vscode
+.venv
+.python-version
+
+test.py
\ No newline at end of file
diff --git a/.travis.yml b/.travis.yml
new file mode 100644
index 0000000..d21a930
--- /dev/null
+++ b/.travis.yml
@@ -0,0 +1,19 @@
+dist: bionic
+language: python
+git:
+  depth: 2
+python:
+  - "3.7"
+  - "3.8"
+matrix:
+  include:
+    - python: "3.7"
+      env: TOXENV=lint
+install:
+  - pip install .
+  - pip install tox tox-travis
+script: tox
+after_success:
+  - pip install pytest pytest-asyncio aiohttp coveralls
+  - coverage run --source=airbase setup.py test
+  - coveralls
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
index bde57d5..60c0e48 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2020 Luis Felipe Paris
+Copyright (c) 2022 Luis Felipe Paris
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..c1a7121
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,2 @@
+include LICENSE
+include README.md
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..6e317ac
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,32 @@
+# Colours
+NC=\033[0m\n
+HIGHLIGHT=\033[91m
+
+## usage: print useful commands
+usage:
+	@echo "$(HIGHLIGHT)Choose a command: $(PWD) $(NC)"
+	@bash -c "sed -ne 's/^##//p' ./Makefile | column -t -s ':' | sed -e 's/^/ /'"
+
+## release: Release new version
+release:
+	python setup.py sdist bdist_wheel --universal
+	twine upload ./dist/*
+	make clean
+
+## test: Run tests
+test:
+	tox
+	make clean
+
+## lint: Lint and format
+lint:
+	flake8 .
+	black --line-length 79 --check .
+ +## clean: delete python artifacts +clean: + python -c "import pathlib; [p.unlink() for p in pathlib.Path('.').rglob('*.py[co]')]" + python -c "import pathlib; [p.rmdir() for p in pathlib.Path('.').rglob('pytest_cache')]" + rm -rdf ./dist + rm -rdf ./build + rm -rdf airbase.egg-info \ No newline at end of file diff --git a/README.md b/README.md index f9b1fd8..554a127 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,26 @@ -# Airtable Python Wrapper +# Asynchronous Airtable Python Wrapper +[![Python 3.7](https://img.shields.io/badge/python-3.7-blue.svg)](https://www.python.org/downloads/release/python-370) +[![Python 3.8](https://img.shields.io/badge/python-3.8-blue.svg)](https://www.python.org/downloads/release/python-380) + +[![PyPI version](https://badge.fury.io/py/airtable-async.svg)](https://badge.fury.io/py/airtable-async) +[![PyPI - Downloads](https://img.shields.io/pypi/dm/airtable-async.svg?label=pypi%20downloads)](https://pypi.org/project/airtable-async/) +[![Build Status](https://travis-ci.org/lfparis/airbase.svg?branch=master)](https://travis-ci.org/lfparis/airbase) +[![Coverage Status](https://coveralls.io/repos/github/lfparis/airbase/badge.svg?branch=master)](https://coveralls.io/github/lfparis/airbase?branch=master) + +## Installing +```bash +pip install airtable-async +``` +Requirements: Python 3.7+ + +## Documentation +*coming soon* + +## Example -### Asynchronous -Requires CPython 3.8 ```python import asyncio -from airbase import AirtableAsync as Airtable +from airbase import Airtable api_key = "your Airtable API key found at https://airtable.com/account" base_key = "name or id of a base" @@ -13,6 +29,7 @@ table_key = "name or id of a table in that base" async def main() -> None: async with Airtable(api_key=api_key) as at: + at: Airtable # Get all bases for a user await at.get_bases() @@ -83,7 +100,7 @@ async def main() -> None: "id": "record id", "fields": {"field1": "value1", "field2": "value2"}, } - await table.update_record() + await table.update_record(record) # Update several records in that table records = [ { @@ -99,142 +116,26 @@ async def main() -> None: "fields": {"field1": "value1", "field2": "value2"}, }, ] - await table.update_records() + await table.update_records(records) # Delete a record in that table record = { "id": "record id", } - await table.delete_record() + await table.delete_record(record) # Delete several records in that table records = [ {"id": "record id"}, {"id": "record id"}, {"id": "record id"}, ] - await table.delete_records() + await table.delete_records(records) if __name__ == "__main__": asyncio.run(main()) ``` -### Synchronous -Works in ironpython and cpython 2.7 and beyond -```python -from airbase import Airtable - -api_key = "your Airtable API key found at https://airtable.com/account" -base_key = "id of a base" -table_key = "name of a table in that base" - - -def main() -> None: - # NOT IMPLEMENTED - with Airtable(api_key=api_key) as at: - at = Airtable(api_key=api_key) - - # Get all bases for a user - at.get_bases() - - # NOT IMPLEMENTED - Get one base by name - # base = at.get_base(base_key, key="name") - # Get one base by id - base = at.get_base(base_key) - # NOT IMPLEMENTED - Get one base by either id or name - # base = at.get_base(base_key) - - # Base Attributes - print(base.id) - print(base.name) - print(base.permission_level) - - # Set base logging level (debug, info, warning, error, etc) - # Default is "info" - base.log = "debug" - - # Get all tables for a base - base.get_tables() - - # Get one table by name - table = 
base.get_table(table_key) - # NOT IMPLEMENTED - Get one table by id - # table = at.get_table(table_key, key="id") - # NOT IMPLEMENTED - Get one table by either id or name - # table = at.get_table(table_key) - - # Base Attributes - print(table.base) - print(table.name) - print(table.id) - print(table.primary_field_id) - # print(table.primary_field_name) - print(table.fields) - print(table.views) - - # Get a record in that table - table_record = table.get_record("record_id") - # Get all records in that table - table_records = table.get_records() - # NOT IMPLEMENTED - Get all records in a view in that table - # view_records = table.get_records(view="view id or name") - # Get only certain fields for all records in that table - reduced_fields_records = table.get_records( - filter_by_fields=["field1, field2"] - ) - # Get all records in that table that pass a formula - filtered_records = table.get_records( - filter_by_formula="Airtable Formula" - ) - - # Post a record in that table - record = {"fields": {"field1": "value1", "field2": "value2"}} - table.post_record(record) - # Post several records in that table - records = [ - {"fields": {"field1": "value1", "field2": "value2"}}, - {"fields": {"field1": "value1", "field2": "value2"}}, - {"fields": {"field1": "value1", "field2": "value2"}}, - ] - table.post_records(records) - - # Update a record in that table - record = { - "id": "record id", - "fields": {"field1": "value1", "field2": "value2"}, - } - table.update_record() - # Update several records in that table - records = [ - { - "id": "record id", - "fields": {"field1": "value1", "field2": "value2"}, - }, - { - "id": "record id", - "fields": {"field1": "value1", "field2": "value2"}, - }, - { - "id": "record id", - "fields": {"field1": "value1", "field2": "value2"}, - }, - ] - table.update_records() - - # Delete a record in that table - record = { - "id": "record id", - } - table.delete_record() - # NOT IMPLEMENTED - Delete several records in that table - # records = [ - # {"id": "record id"}, - # {"id": "record id"}, - # {"id": "record id"}, - # ] - # table.delete_records() - - -if __name__ == "__main__": - main() +## License -``` \ No newline at end of file +[MIT](https://opensource.org/licenses/MIT) \ No newline at end of file diff --git a/airbase/__init__.py b/airbase/__init__.py index 8a36149..026e9ac 100644 --- a/airbase/__init__.py +++ b/airbase/__init__.py @@ -1,3 +1,2 @@ from __future__ import absolute_import from .airtable import Airtable # noqa: F401 -from .airtable_async import Airtable as AirtableAsync # noqa: F401 diff --git a/airbase/airtable.py b/airbase/airtable.py index 18be640..6f5663b 100644 --- a/airbase/airtable.py +++ b/airbase/airtable.py @@ -1,87 +1,310 @@ -from __future__ import absolute_import +from __future__ import absolute_import, annotations import os +import re import urllib -from .session import Session -from .utils import Logger +from aiohttp import ( + ClientConnectionError, + ClientConnectorError, + ClientSession, + ClientTimeout, + ContentTypeError, + TCPConnector, + ClientResponse, +) +from asyncio import TimeoutError, sleep +from json.decoder import JSONDecodeError +from typing import Any, Dict, Iterable, List, Optional, Union + +from .decorators import chunkify +from .exceptions import AirbaseResponseException, AirbaseException +from .utils import Logger, HTTPSemaphore from .urls import BASE_URL, META_URL +from .validations import validate_records -logger = Logger.start(__name__) +class BaseAirtable: + retries = 5 + def __init__( + self, + logging_level: 
str = "info",
+        raise_for_status: bool = False,
+        verbose: bool = False,
+    ) -> None:
+        """
+        Airtable Base Class
+
+        Kwargs:
+            logging_level (``string``, default="info"): logging level.
+            raise_for_status (``bool``): Raise if the response status is not in the 200s.
+            verbose (``bool``): Log the stack trace when an error occurs.
+
+        """  # noqa: E501
+        self.logging_level = logging_level
+        self.logger = Logger.start(str(self), level=logging_level)
+        self.raise_for_status = raise_for_status
+        self.verbose = verbose
+
+    def __str__(self):
+        obj = re.search(r"(?<=\.)[\w\d_]*(?='>$)", str(self.__class__))[0]
+        if getattr(self, "name", None):
+            return f"<{obj}:'{getattr(self, 'name')}' at {hex(id(self))}>"
+        else:
+            return f"<{obj} at {hex(id(self))}>"
+
+    def _is_success(self, res: Optional[ClientResponse]) -> bool:
+        if res and res.status >= 200 and res.status < 300:
+            return True
+        else:
+            return False
 
-class Airtable(object):
-    session = Session(base_url=BASE_URL)
+    async def _get_data(self, res: ClientResponse) -> Union[Dict, str, bytes]:
+        try:
+            return await res.json(encoding="utf-8")  # dict
+        # else if raw data
+        except JSONDecodeError:
+            return await res.text(encoding="utf-8")  # string
+        except ContentTypeError:
+            return await res.read()  # bytes
+
+    async def _request(self, *args, **kwargs) -> Optional[ClientResponse]:
+        count = 0
+        while True:
+            try:
+                res = await self._session.request(*args, **kwargs)
+                err = False
+            except (
+                ClientConnectionError,
+                ClientConnectorError,
+                TimeoutError,
+            ):
+                err = True
+
+            if err or res.status in (408, 429, 500, 502, 503, 504):
+                delay = (2**count) * 0.51
+                count += 1
+                if count > self.retries:
+                    # res may not be defined at this point
+                    # res.raise_for_status()
+                    return None
+                else:
+                    await sleep(delay)
+            else:
+                return res
 
-    def __init__(self, api_key=None):
+    async def raise_or_log_error(
+        self,
+        response: Optional[ClientResponse] = None,
+        error_msg: Optional[str] = None,
+    ) -> None:
+        error_msg = (
+            await self.get_error_message(response) if response else error_msg
+        )
+        if self.raise_for_status:
+            if response:
+                raise AirbaseResponseException(error_msg, response=response)
+            else:
+                raise AirbaseException(error_msg)
+        else:
+            self.logger.error(
+                error_msg, exc_info=self.verbose, stack_info=self.verbose
+            )
+
+    async def get_error_message(self, response: ClientResponse) -> str:
+        data = await self._get_data(response)
+
+        error = data.get("error") if data else None
+        error_type = None
+        error_message = None
+        if error:
+            if isinstance(error, dict):
+                error_type = error.get("type")
+                error_message = error.get("message")
+            elif isinstance(error, str):
+                error_type = error
+
+        if "meta" in response.url.parts:
+            obj = response.url.parts[-1]
+            base = (
+                f"'{{'base': '{response.url.parts[4]}'}}'"
+                if len(response.url.parts) > 4
+                else None
+            )
+            resource_info = f" {obj}{' from ' if base else ''}{base if base else ''}"  # noqa: E501
+
+        else:
+            preposition = "from" if response.method.lower() == "get" else "to"
+            base = f"'base': '{response.url.parts[2]}'"
+            table = f"'table': '{response.url.parts[3]}'"
+            resource_info = f" record{'s' if len(response.url.parts) < 5 else ''} {preposition}: '{{{base}, {table}}}'"  # noqa: E501
+
+        if error_type and error_message:
+            message = f"{{'type': '{error_type}', 'message': '{error_message}'}}"  # noqa: E501
+        elif error_type:
+            message = f"{{'type': '{error_type}'}}"
+        elif error_message:
+            message = f"{{'message': '{error_message}'}}"
+        else:
+            message = "No error info"
+
+        return f"{response.status} - {response.reason}: Failed to {response.method.lower()}{resource_info} -> '{message}'"  # noqa: E501
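`HTTPSemaphore` is imported from `.utils`, a module this patch does not show. Only its constructor arguments (`value`, `interval`, `max_calls`) are confirmed by the call sites below; a minimal sketch, assuming it combines a concurrency cap with a "max calls per interval" throttle, might look like this:

```python
# Hypothetical sketch only -- the real .utils.HTTPSemaphore is not shown
# in this diff. Assumes it caps in-flight requests at `value` and allows
# at most `max_calls` acquisitions per `interval` seconds, matching the
# HTTPSemaphore(value=50, interval=1, max_calls=5) call sites below.
import asyncio
import time


class HTTPSemaphore(asyncio.Semaphore):
    def __init__(self, value: int = 50, interval: float = 1, max_calls: int = 5):
        super().__init__(value)
        self.interval = interval
        self.max_calls = max_calls
        self._timestamps: list = []  # times of recent acquisitions

    async def __aenter__(self):
        await self.acquire()
        # drop timestamps that have fallen outside the rate window
        now = time.monotonic()
        self._timestamps = [t for t in self._timestamps if now - t < self.interval]
        if len(self._timestamps) >= self.max_calls:
            # wait until the oldest call leaves the window
            await asyncio.sleep(self.interval - (now - self._timestamps[0]))
        self._timestamps.append(time.monotonic())
        return self

    async def __aexit__(self, *exc):
        self.release()
```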
+
+
+class Airtable(BaseAirtable):
+    def __init__(
+        self, api_key: str = None, timeout: int = 300, **kwargs
+    ) -> None:
         """
         Airtable class for multiple bases
 
         Kwargs:
             api_key (``string``): Airtable API Key.
-        """
+            timeout (``int``): total request timeout in seconds, passed to ClientTimeout. Defaults to 300 seconds (5 min)
+        """  # noqa: E501
+        super().__init__(**kwargs)
         self.api_key = api_key
+        self.timeout = timeout
+        self.semaphore = HTTPSemaphore(value=50, interval=1, max_calls=5)
+        self.open()
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, *err):
+        await self.close()
 
     @property
-    def api_key(self):
+    def api_key(self) -> str:
         if getattr(self, "_api_key", None):
             return self._api_key
 
     @api_key.setter
-    def api_key(self, key):
+    def api_key(self, key: str) -> None:
         self._api_key = key or str(os.environ.get("AIRTABLE_API_KEY"))
-        self.auth = {"Authorization": "Bearer {}".format(self.api_key)}
+        self.auth = {"Authorization": f"Bearer {self.api_key}"}
+
+    def open(self) -> None:
+        conn = TCPConnector(limit=100)
+        timeout = ClientTimeout(total=self.timeout)
+        self._session = ClientSession(
+            connector=conn,
+            headers=self.auth,
+            timeout=timeout,
+            # raise_for_status=self.raise_for_status,
+        )
 
-    @staticmethod
-    def _compose_url(base_id, table_name):
-        """
-        Composes the airtable url.
+    async def close(self) -> None:
+        await self._session.close()
+        self._session = None
+
+    async def get_bases(self) -> Optional[List[Base]]:
+        async with self.semaphore:
+            url = f"{META_URL}/bases"
+            res = await self._request("get", url)
+
+        data = await self._get_data(res)
+        if self._is_success(res):
+            self.bases = [
+                Base(
+                    base["id"],
+                    name=base["name"],
+                    permission_level=base["permissionLevel"],
+                    session=self._session,
+                    logging_level=self.logging_level,
+                    raise_for_status=self.raise_for_status,
+                    verbose=self.verbose,
+                )
+                for base in data["bases"]
+            ]
+            self._bases_by_id = {base.id: base for base in self.bases}
+            self._bases_by_name = {base.name: base for base in self.bases}
+            self.logger.info(f"Fetched: {len(self.bases)} bases")
 
-        Returns:
-            url (``string``): Composed url.
-        """
-        try:
-            table_name = urllib.parse.quote(table_name)
-        except Exception:
-            table_name = urllib.pathname2url(table_name)
-        return "{}/{}/{}".format(BASE_URL, base_id, table_name)
-
-    def get_bases(self):
-        url = "{}/bases".format(META_URL)
-        data, success = self.session.request("get", url, headers=self.auth)
-        if success:
-            self.bases = [
-                Base(
-                    base["id"],
-                    self.auth,
-                    name=base["name"],
-                    permission_level=base["permissionLevel"],
-                    logging_level="info",
+        else:
+            await self.raise_or_log_error(response=res)
+            self.bases = None
+        return self.bases
+
+    async def get_base(
+        self, value: str, key: Optional[str] = None
+    ) -> Optional[Base]:
+        assert key in (None, "id", "name")
+        if not getattr(self, "bases", None):
+            if key == "id":
+                self.logger.info(f"Created Base object with id: {value}")
+                return Base(
+                    base_id=value,
+                    session=self._session,
+                    logging_level=self.logging_level,
+                    raise_for_status=self.raise_for_status,
+                    verbose=self.verbose,
                 )
-                for base in data["bases"]
-            ]
-
-    def get_base(self, base_id, logging_level="info"):
-        return Base(base_id, self.auth, logging_level=logging_level)
-
-    def get_enterprise_account(
-        self, enterprise_account_id, logging_level="info"
-    ):
-        url = "{}/enterpriseAccounts/{}".format(
-            META_URL, enterprise_account_id
-        )
-        data, success = self.session.request("get", url, headers=self.auth)
-        if success:
+            await self.get_bases()
+        base = None  # ensure `base` is defined even when no match is found
+        if self.bases:
+            if key == "name":
+                base = self._bases_by_name.get(value)
+            elif key == "id":
+                base = self._bases_by_id.get(value)
+            else:
+                bases = [
+                    base
+                    for base in self.bases
+                    if base.name == value or base.id == value
+                ]
+                if bases:
+                    base = bases[0]
+        if base:
+            self.logger.info(
+                f"Fetched Base with {key if key else 'value'}: '{value}'"
+            )
+            return base
+        else:
+            error_msg = f"Base with {key if key else 'value'}:'{value}' not found"  # noqa: E501
+            await self.raise_or_log_error(error_msg=error_msg)
+            return None
+
+    async def get_enterprise_account(
+        self, enterprise_account_id
+    ) -> Optional[Account]:
+        url = f"{META_URL}/enterpriseAccounts/{enterprise_account_id}"
+        res = await self._session.request("get", url)
+        if self._is_success(res):
+            data = await self._get_data(res)
+            self.logger.info(f"Fetched Account with id: '{data.get('id')}'")
             return Account(
-                data["id"], self.auth, data, logging_level=logging_level
+                data["id"],
+                data,
+                session=self._session,
+                logging_level=self.logging_level,
+            )
+        else:
+            await self.raise_or_log_error(response=res)
+            return None
+
+    async def get_table(
+        self, base_id: str, table_name: str
+    ) -> Optional[Table]:
+        base = await self.get_base(value=base_id, key="id")
+        if base:
+            self.logger.info(f"Created Table object with name: '{table_name}'")
+            return Table(
+                base,
+                table_name,
+                logging_level=self.logging_level,
+                raise_for_status=self.raise_for_status,
+                verbose=self.verbose,
             )
+        else:
+            error_msg = f"Failed to create Table object with name: {table_name} for base with id:'{base_id}'"  # noqa: E501
+            await self.raise_or_log_error(error_msg=error_msg)
+            return None
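`AirbaseException` and `AirbaseResponseException` are imported from the new `airbase/exceptions.py`, which is also absent from this diff. Judging only from how `raise_or_log_error` uses them above, a minimal sketch could be:

```python
# Hypothetical sketch -- airbase/exceptions.py is not included in this
# diff. Inferred from raise_or_log_error above: the response variant
# keeps a reference to the failed aiohttp response.
from aiohttp import ClientResponse


class AirbaseException(Exception):
    """Generic airbase error."""


class AirbaseResponseException(AirbaseException):
    """Error tied to a specific HTTP response."""

    def __init__(self, message: str, response: ClientResponse = None):
        super().__init__(message)
        self.response = response
```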
 
 
-class Account(Airtable):
+class Account(BaseAirtable):
     def __init__(
-        self, enterprise_account_id, auth, data=None, logging_level="info",
-    ):
+        self, enterprise_account_id, data=None, session=None, **kwargs
+    ) -> None:
         """
         Airtable class for an Enterprise Account.
         https://airtable.com/api/enterprise
@@ -92,10 +315,10 @@ def __init__(
         Kwargs:
             logging_level (``string``, default="info"):
         """
+        super().__init__(**kwargs)
         self.id = enterprise_account_id
-        self.url = "{}/enterpriseAccounts/{}".format(META_URL, self.id)
-        self.auth = auth
-        self.logging_level = logging_level
+        self.url = f"{META_URL}/enterpriseAccounts/{self.id}"
+        self._session = session
         if data:
             self.workspace_ids = data.get("workspaceIds")
             self.user_ids = data.get("userIds")
@@ -103,15 +326,15 @@
             self.created_time = data.get("createdTime")
 
 
-class Base(Airtable):
+class Base(BaseAirtable):
     def __init__(
         self,
         base_id,
-        auth,
         name=None,
         permission_level=None,
-        logging_level="info",
-    ):
+        session=None,
+        **kwargs,
+    ) -> None:
         """
         Airtable class for one base.
 
@@ -122,43 +345,95 @@
             api_key (``string``): Airtable API Key.
             log (``bool``, default=True): If True it logs succesful API calls.
         """
+        super().__init__(**kwargs)
         self.id = base_id
-        self.url = "{}/bases/{}".format(META_URL, self.id)
         self.name = name
         self.permission_level = permission_level
-        self.auth = auth
-        self.log = logging_level
-
-    def get_tables(self):
-        url = "{}/tables".format(self.url)
-        data, success = self.session.request("get", url, headers=self.auth)
-        if success:
-            self.tables = [
-                Table(
+        self.url = f"{META_URL}/bases/{self.id}"
+
+        self._session = session
+        self.semaphore = HTTPSemaphore(value=50, interval=1, max_calls=5)
+
+    async def get_tables(self) -> Optional[List[Table]]:  # noqa: F821
+        async with self.semaphore:
+            url = f"{self.url}/tables"
+            res = await self._request("get", url)
+        if self._is_success(res):
+            data = await self._get_data(res)
+            self.tables = [
+                Table(
+                    self,
+                    table["name"],
+                    table_id=table["id"],
+                    primary_field_id=table["primaryFieldId"],
+                    fields=table["fields"],
+                    views=table["views"],
+                    logging_level=self.logging_level,
+                    raise_for_status=self.raise_for_status,
+                    verbose=self.verbose,
+                )
+                for table in data["tables"]
+            ]
+            self._tables_by_id = {table.id: table for table in self.tables}
+            self._tables_by_name = {
+                table.name: table for table in self.tables
+            }
+            self.logger.info(f"Fetched: {len(self.tables)} tables")
+        else:
+            await self.raise_or_log_error(response=res)
+            self.tables = None
+        return self.tables
+
+    async def get_table(
+        self, value: str, key: Optional[str] = None
+    ) -> Optional[Table]:
+        assert key in (None, "id", "name")
+        if not getattr(self, "tables", None):
+            if key == "name":
+                self.logger.info(f"Created Table object with name: {value}")
+                return Table(
                     self,
-                    table["name"],
-                    table_id=table["id"],
-                    primary_field_id=table["primaryFieldId"],
-                    fields=table["fields"],
-                    views=table["views"],
+                    value,
+                    logging_level=self.logging_level,
+                    raise_for_status=self.raise_for_status,
+                    verbose=self.verbose,
                 )
-                for table in data["tables"]
-            ]
-
-    def get_table(self, table_name):
-        return Table(self, table_name)
+            await self.get_tables()
+        table = None  # ensure `table` is defined even when no match is found
+        if self.tables:
+            if key == "name":
+                table = self._tables_by_name.get(value)
+            elif key == "id":
+                table = self._tables_by_id.get(value)
+            else:
+                tables = [
+                    table
+                    for table in self.tables
+                    if table.name == value or table.id == value
+                ]
+                if tables:
+                    table = tables[0]
+        if table:
+            self.logger.info(
+                f"Fetched Table with {key if key else 'value'}: {value}"
+            )
+            return table
+        else:
+            error_msg = f"Table with {key if key else 'value'}:'{value}' not found"  # noqa: E501
+            await self.raise_or_log_error(error_msg=error_msg)
+            return None
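The `Table` class below decorates `_request_records` with `chunkify`, imported from the new `airbase/decorators.py`, which this diff does not include. Based on the "batches of 10" docstring and the old `_multiple` helper being deleted further down, one plausible sketch — the name and batch size come from the diff, the rest is an assumption:

```python
# Hypothetical sketch -- airbase/decorators.py is not part of this diff.
# Assumes chunkify splits `records` into batches of 10 and runs the
# wrapped coroutine concurrently on each batch, like the old _multiple().
import asyncio
from functools import wraps


def chunkify(func):
    @wraps(func)
    async def wrapper(self, method, records, *args, **kwargs):
        # assumes `records` is indexable (e.g. a list)
        batches = [records[i : i + 10] for i in range(0, len(records), 10)]
        tasks = [func(self, method, batch, *args, **kwargs) for batch in batches]
        return await asyncio.gather(*tasks)

    return wrapper
```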
 
 
-class Table(Airtable):
+class Table(BaseAirtable):
     def __init__(
         self,
-        base,
-        name,
-        table_id=None,
-        primary_field_id=None,
-        fields=None,
-        views=None,
-    ):
+        base: Base,
+        name: str,
+        table_id: str = None,
+        primary_field_id: str = None,
+        fields: List[Dict[str, str]] = None,
+        views: List[Dict[str, str]] = None,
+        **kwargs,
+    ) -> None:
         """
         Airtable class for one table in one base
 
@@ -167,13 +442,14 @@ def __init__(
             name (``string``): Name of target table.
 
         """
+        super().__init__(**kwargs)
         self.base = base
         self.name = name
         self.id = table_id
         self.primary_field_id = primary_field_id
         self.fields = fields
         self.views = views
-        self.url = self._compose_url(self.base.id, self.name)
+        self.url = self._compose_url()
         self.primary_field_name = (
             [
                 field["name"]
@@ -183,9 +459,10 @@ def __init__(
             if self.fields and self.primary_field_id
             else None
         )
+        self._session: ClientSession = base._session
 
     @staticmethod
-    def _basic_log_msg(content):
+    def _basic_log_msg(content: Iterable) -> str:
         """
         Constructs a basic logger message
         """
@@ -194,12 +471,12 @@ def _basic_log_msg(content):
                 plural = "s"
             else:
                 plural = ""
-            message = "{} record{}".format(len(content), plural)
+            message = f"{len(content)} record{plural}"
         else:
             message = "1 record"
         return message
 
-    def _add_record_to_url(self, record_id):
+    def _add_record_to_url(self, record_id: str) -> str:
         """
         Composes the airtable url with a record id
 
@@ -208,9 +485,152 @@ def _add_record_to_url(self, record_id):
         Returns:
             url (``string``): Composed url.
         """
-        return "{}/{}".format(self.url, record_id)
+        return f"{self.url}/{record_id}"
+
+    def _compose_url(self) -> str:
+        """
+        Composes the airtable url.
 
-    def get_record(self, record_id):
+        Returns:
+            url (``string``): Composed url.
+        """
+        return f"{BASE_URL}/{self.base.id}/{urllib.parse.quote(self.name)}"
+
+    def _get_record_primary_key_value_or_id(self, record: dict) -> str:
+        if (
+            record.get("fields")
+            and self.primary_field_name
+            and record["fields"].get(self.primary_field_name)
+        ):
+            return record["fields"][self.primary_field_name]
+        else:
+            return record.get("id")
+
+    async def _request_record(
+        self,
+        method: str,
+        record: Dict,
+        typecast: bool = False,
+        **kwargs,
+    ) -> Dict:
+        """
+        Posts/Gets/Patches/Deletes a record in a table.
+
+        Args:
+            method (``str``): 'get', 'post', 'patch' or 'delete'
+            record (``dictionary``): a record to CRUD.
+        Kwargs:
+            typecast (``bool``, optional): if True, payload can create new options in singleSelect and multipleSelects fields
+        """  # noqa: E501
+
+        operation = method
+        content_type = "application/json"
+        url = self._add_record_to_url(record.get("id"))
+        data = {}
+
+        # CREATE
+        if method == "post":
+            data = {"fields": record["fields"]}
+            url = self.url
+        # READ
+        elif method == "get":
+            operation = "fetch"
+        # UPDATE
+        elif method == "patch":
+            operation = "update"
+            data = {"fields": record["fields"]}
+            url = self._add_record_to_url(record.get("id"))
+        # DELETE
+        elif method == "delete":
+            content_type = "application/x-www-form-urlencoded"
+        else:
+            raise AirbaseException("Invalid HTTP method")
+
+        headers = {"Content-Type": content_type}
+
+        if typecast:
+            data["typecast"] = True
+        async with self.base.semaphore:
+            res = await self._session.request(
+                method, url, json=data, headers=headers
+            )
+        data = await self._get_data(res)
+        message = self._get_record_primary_key_value_or_id(
+            data
+        ) or self._basic_log_msg(data)
+
+        if self._is_success(res):
+            self.logger.info(
+                f"{operation.title()}{'e' if operation[-1] != 'e' else ''}d: {message}"  # noqa: E501
+            )
+        else:
+            await self.raise_or_log_error(response=res)
+        return data
+
+    @chunkify
+    async def _request_records(
+        self,
+        method: str,
+        records: Iterable[Dict],
+        typecast: bool = False,
+        **kwargs,
+    ) -> Dict:
+        """
+        Posts/Patches/Deletes records to a table in batches of 10.
+
+        Args:
+            method (``str``): 'post', 'patch' or 'delete'
+            records (``list``): a list of records (``dictionary``) to CRUD.
+        Kwargs:
+            typecast (``bool``, optional): if True, payload can create new options in singleSelect and multipleSelects fields
+        """  # noqa: E501
+
+        operation = method
+        content_type = "application/json"
+        data = {}
+        params = []
+
+        # CREATE
+        if method == "post":
+            data = {
+                "records": [{"fields": record["fields"]} for record in records]
+            }
+        # UPDATE
+        elif method == "patch":
+            operation = "update"
+            data = {
+                "records": [
+                    {"id": record.get("id"), "fields": record.get("fields")}
+                    for record in records
+                ]
+            }
+        # DELETE
+        elif method == "delete":
+            content_type = "application/x-www-form-urlencoded"
+            params = [("records[]", record.get("id")) for record in records]
+        else:
+            raise AirbaseException("Invalid HTTP method")
+
+        headers = {"Content-Type": content_type}
+        message = self._basic_log_msg(records)
+
+        if typecast:
+            data["typecast"] = True
+        async with self.base.semaphore:
+            res = await self._session.request(
+                method, self.url, json=data, params=params, headers=headers
+            )
+        data = await self._get_data(res)
+
+        if self._is_success(res):
+            self.logger.info(
+                f"{operation.title()}{'e' if operation[-1] != 'e' else ''}d: {message}"  # noqa: E501
+            )
+        else:
+            await self.raise_or_log_error(response=res)
+        return data
+
+    async def get_record(self, record_id: str) -> dict:
         """
         Gets one record from a table.
 
@@ -218,49 +638,49 @@
         Args:
            record_id (``string``): ID of record.
         Returns:
            records (``list``): If succesful, a list of existing records (``dictionary``).
- """ # noqa - url = self._add_record_to_url(record_id) - data, success = self.session.request( - "get", url, headers=self.base.auth + """ # noqa: E501 + return await self._request_record( + method="get", + record={"id": record_id}, ) - if success: - logger.info("Fetched: %s from table: %s", record_id, self.name) - return data - else: - logger.warning( - "Failed to get: %s from table: %s: %s", record_id, self.name, - ) - return - def get_records(self, filter_by_fields=None, filter_by_formula=None): + async def get_records( + self, + view: str = None, + filter_by_fields: list = None, + filter_by_formula: str = None, + ) -> list: """ Gets all records from a table. Kwargs: filter_by_fields (``list``, optional): list of fields(``string``) to return. Minimum 2 fields. filter_by_formula (``str``, optional): literally a formula. + view (``str``, optional): view id or name. Returns: records (``list``): If succesful, a list of existing records (``dictionary``). """ # noqa - params = {} + params: Dict[str, Any] = {} # filters if filter_by_fields: params["fields"] = filter_by_fields if filter_by_formula: params["filterByFormula"] = filter_by_formula + if view: + params["view"] = view records = [] while True: - data, success = self.session.request( - "get", self.url, params=params, headers=self.base.auth - ) - if not success: - logger.warning("Table: %s could not be retreived.", self.name) + async with self.base.semaphore: + res = await self._request("get", self.url, params=params) + if not self._is_success(res): + await self.raise_or_log_error(response=res) break + data = await self._get_data(res) try: records.extend(data["records"]) - except KeyError: + except (AttributeError, KeyError, TypeError): pass # pagination if "offset" in data: @@ -269,12 +689,15 @@ def get_records(self, filter_by_fields=None, filter_by_formula=None): break if len(records) != 0: - logger.info( - "Fetched %s records from table: %s", len(records), self.name, + self.logger.info( + f"Fetched {len(records)} records from table: {self.name}" ) - return records + self.records = records + else: + self.records = [] + return self.records - def post_record(self, record, message=None): + async def post_record(self, record: Dict, typecast: bool = False) -> Dict: """ Adds a record to a table. @@ -283,55 +706,34 @@ def post_record(self, record, message=None): Kwargs: message (``string``, optional): Custom logger message. """ - if not message: - message = self._basic_log_msg(record) - headers = {"Content-Type": "application/json"} - headers.update(self.base.auth) - data = {"fields": record["fields"]} - data, success = self.session.request( - "post", self.url, json_data=data, headers=headers + await validate_records(record, record_id=False) + return await self._request_record( + method="post", + record=record, + typecast=typecast, ) - if success: - logger.info("Posted: %s", message) - return True - else: - logger.warning("Failed to post: %s", message) - def post_records(self, records, message=None): + async def post_records( + self, records: Iterable[Dict], typecast: bool = False + ) -> None: """ Adds records to a table in batches of 10. Args: records (``list``): a list of records (``dictionary``) to post. - Kwargs: - message (``string``, optional): Name to use for logger. 
-        """
-        if message:
-            log_msg = message
-        headers = {"Content-Type": "application/json"}
-        headers.update(self.base.auth)
-        records_iter = (
-            records[i : i + 10] for i in range(0, len(records), 10)
+        Returns:
+            True if successful
+        """  # noqa: E501
+        await validate_records(records, record_id=False)
+        return await self._request_records(
+            method="post",
+            records=records,
+            typecast=typecast,
         )
-        for sub_list in records_iter:
-            if not message:
-                log_msg = self._basic_log_msg(sub_list)
-
-            data = {
-                "records": [
-                    {"fields": record["fields"]} for record in sub_list
-                ]
-            }
-
-            data, success = self.session.request(
-                "post", self.url, json_data=data, headers=headers
-            )
-            if success:
-                logger.info("Posted: %s", log_msg)
-            else:
-                logger.warning("Failed to post: %s", log_msg)
 
-    def update_record(self, record, message=None):
+    async def update_record(
+        self, record: Dict, typecast: bool = False
+    ) -> Dict:
         """
         Updates a record in a table.
 
@@ -342,64 +744,30 @@
         Args:
             record (``dictionary``): Record with updated values.
         Kwargs:
             message (``string``, optional): Name of record to use for logger.
         Returns:
             records (``list``): If succesful, a list of existing records (``dictionary``).
         """  # noqa
-        try:
-            if not message:
-                message = record.get("id")
-            url = self._add_record_to_url(record.get("id"))
-            headers = {"Content-Type": "application/json"}
-            headers.update(self.base.auth)
-            data = {"fields": record["fields"]}
-            data, success = self.session.request(
-                "patch", url, headers=headers, json_data=data
-            )
-            if success:
-                logger.info("Updated: %s ", message)
-                return True
-            else:
-                logger.warning("Failed to update: %s", message)
-        except Exception:
-            logger.warning("Invalid record format provided.")
+        await validate_records(record)
+        return await self._request_record(
+            method="patch",
+            record=record,
+            typecast=typecast,
+        )
 
-    def update_records(self, records, message=None):
+    async def update_records(
+        self, records: Iterable[Dict], typecast: bool = False
+    ) -> Dict:
         """
         Updates records in a table in batches of 10.
 
         Args:
             records (``list``): a list of records (``dictionary``) with updated values.
-        Kwargs:
-            message (``string``, optional): Custom logger message.
-        """  # noqa
-        try:
-            if message:
-                log_msg = message
-            headers = {"Content-Type": "application/json"}
-            headers.update(self.base.auth)
-            records_iter = (
-                records[i : i + 10] for i in range(0, len(records), 10)
-            )
-            for sub_list in records_iter:
-                if not message:
-                    log_msg = self._basic_log_msg(sub_list)
-                data = {
-                    "records": [
-                        {
-                            "id": record.get("id"),
-                            "fields": record.get("fields"),
-                        }
-                        for record in sub_list
-                    ]
-                }
-                data, success = self.session.request(
-                    "patch", self.url, headers=headers, json_data=data
-                )
-                if success:
-                    logger.info("Updated: %s ", log_msg)
-                else:
-                    logger.warning("Failed to update: %s", log_msg)
-        except Exception:
-            logger.warning("Invalid record format provided.")
+        Returns:
+            True if successful
+        """  # noqa: E501
+        await validate_records(records)
+        return await self._request_records(
+            method="patch", records=records, typecast=typecast
+        )
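`validate_records` comes from the new `airbase/validations.py`, also absent from this diff. Its call sites (`record_id=False` for creates, `fields=False` for deletes) suggest a simple shape check; a minimal sketch under those assumptions:

```python
# Hypothetical sketch -- airbase/validations.py is not part of this diff.
# Assumes it checks that each record is a dict carrying the keys the
# given operation needs: "id" unless record_id=False, and "fields"
# unless fields=False.
from typing import Dict, Iterable, Union

from .exceptions import AirbaseException


async def validate_records(
    records: Union[Dict, Iterable[Dict]],
    record_id: bool = True,
    fields: bool = True,
) -> None:
    if isinstance(records, dict):
        records = [records]
    for record in records:
        if not isinstance(record, dict):
            raise AirbaseException(f"Invalid record format: {record!r}")
        if record_id and "id" not in record:
            raise AirbaseException("Record is missing an 'id' key")
        if fields and "fields" not in record:
            raise AirbaseException("Record is missing a 'fields' key")
```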
""" - if not message: - message = record.get("id") - url = self._add_record_to_url(record["id"]) - data, success = self.session.request( - "delete", url, headers=self.base.auth + await validate_records(record, fields=False) + return await self._request_record( + method="delete", + record=record, ) - if success: - logger.info("Deleted: %s", message) - return True - else: - logger.warning("Failed to delete: %s", message) - def delete_records(self, records, message=None): + async def delete_records(self, records: Iterable[Dict]) -> Dict: """ - Deletes records from a table in batches of 10. + Deletes records in a table in batches of 10. Args: - records (``list``): a list of records (``dictionary``) to delete. - Kwargs: - message (``string``, optional): Custom logger message. - """ - raise NotImplementedError - - headers = {"Content-Type": "application/x-www-form-urlencoded"} - headers.update(self.base.auth) - if message: - log_msg = message - records_iter = ( - records[i : i + 10] for i in range(0, len(records), 10) + records (``list``): a list of records (``dictionary``) to delete + Returns: + True if succesful + """ # noqa: E501 + await validate_records(records, fields=False) + return await self._request_records( + method="delete", + records=records, ) - for sub_list in records_iter: - if not message: - log_msg = self._basic_log_msg(sub_list) - - data = { - "records": [{"id": record.get("id")} for record in sub_list] - } - - data, success = self.session.request( - "delete", self.url, urlencode=data, headers=headers - ) - if success: - logger.info("Deleted: %s", log_msg) - else: - logger.warning("Failed to delete: %s", log_msg) diff --git a/airbase/airtable_async.py b/airbase/airtable_async.py deleted file mode 100644 index 0dcddfd..0000000 --- a/airbase/airtable_async.py +++ /dev/null @@ -1,617 +0,0 @@ -from __future__ import absolute_import - -import asyncio -import os -import urllib - -from aiohttp import ( - ClientConnectionError, - ClientConnectorError, - ClientSession, - ContentTypeError, - TCPConnector, - ClientResponse, -) -from json.decoder import JSONDecodeError -from typing import Any, Dict, Iterable, List # Optional, Union - -from .utils import Logger, HTTPSemaphore -from .urls import BASE_URL, META_URL - - -logger = Logger.start(__name__) - - -class BaseAirtable: - retries = 5 - - def _is_success(self, res: ClientResponse) -> bool: - if res.status >= 200 and res.status < 300: - return True - else: - return False - - async def _get_data(self, res: ClientResponse): - try: - return await res.json(encoding="utf-8") # dict - # else if raw data - except JSONDecodeError: - return await res.text(encoding="utf-8") # string - except ContentTypeError: - return await res.read() # bytes - - async def _request(self, *args, **kwargs): - try: - res = await self._session.request(*args, **kwargs) - err = False - except ( - ClientConnectionError, - ClientConnectorError, - asyncio.TimeoutError, - ): - err = True - - count = 1 - step = 5 - while err or res.status in (408, 429, 503, 504): - await asyncio.sleep(0.1 * count ** 2) - - try: - res = await self._session.request(*args, **kwargs) - err = False - except ( - ClientConnectionError, - ClientConnectorError, - asyncio.TimeoutError, - ): - err = True - - if count >= self.retries * step: - break - count += step - if res.status in (408, 429, 503, 504): - # res.raise_for_status() - pass - return res - - -class Airtable(BaseAirtable): - def __init__(self, api_key: str = None): - """ - Airtable class for multiple bases - - Kwargs: - api_key (``string``): 
Airtable API Key. - """ - self.api_key = api_key - self.semaphore = HTTPSemaphore(value=50, interval=1, max_calls=5) - - async def __aenter__(self): - conn = TCPConnector(limit=100) - self._session = ClientSession(connector=conn, headers=self.auth) - return self - - async def __aexit__(self, *err): - await self._session.close() - self._session = None - - @property - def api_key(self): - if getattr(self, "_api_key", None): - return self._api_key - - @api_key.setter - def api_key(self, key: str): - self._api_key = key or str(os.environ.get("AIRTABLE_API_KEY")) - self.auth = {"Authorization": "Bearer {}".format(self.api_key)} - - async def get_bases(self) -> List[Base]: # noqa: F821 - async with self.semaphore: - url = "{}/bases".format(META_URL) - res = await self._request("get", url) - - if self._is_success(res): - data = await self._get_data(res) - self.bases = [ - Base( - base["id"], - name=base["name"], - permission_level=base["permissionLevel"], - session=self._session, - logging_level="info", - ) - for base in data["bases"] - ] - self._bases_by_id = {base.id: base for base in self.bases} - self._bases_by_name = {base.name: base for base in self.bases} - return self.bases - - async def get_base(self, value: str, key: str): - assert key in (None, "id", "name") - if not getattr(self, "bases", None): - await self.get_bases() - if key == "name": - return self._bases_by_name.get(value) - elif key == "id": - return self._bases_by_id.get(value) - else: - bases = [ - base - for base in self.bases - if base.name == value or base.id == value - ] - if bases: - return bases[0] - - async def get_enterprise_account( - self, enterprise_account_id, logging_level="info" - ): - url = "{}/enterpriseAccounts/{}".format( - META_URL, enterprise_account_id - ) - res = await self._session.request("get", url, headers=self.auth) - if Airtable._is_success(res): - data = await Airtable._get_data(res) - return Account( - data["id"], - data, - session=self._session, - logging_level=logging_level, - ) - - -class Account(BaseAirtable): - def __init__( - self, - enterprise_account_id, - data=None, - session=None, - logging_level="info", - ): - """ - Airtable class for an Enterprise Account. - https://airtable.com/api/enterprise - - Args: - enterprise_account_id (``string``): ID of Entreprise Account - - Kwargs: - logging_level (``string``, default="info"): - """ - self.id = enterprise_account_id - self.url = "{}/enterpriseAccounts/{}".format(META_URL, self.id) - self._session = session - self.logging_level = logging_level - if data: - self.workspace_ids = data.get("workspaceIds") - self.user_ids = data.get("userIds") - self.email_domains = data.get("emailDomains") - self.created_time = data.get("createdTime") - - -class Base(BaseAirtable): - def __init__( - self, - base_id, - name=None, - permission_level=None, - session=None, - logging_level="info", - ): - """ - Airtable class for one base. - - Args: - BASE_ID (``string``): ID of target base. - - Kwargs: - api_key (``string``): Airtable API Key. - log (``bool``, default=True): If True it logs succesful API calls. 
- """ - self.id = base_id - self.name = name - self.permission_level = permission_level - self.url = "{}/bases/{}".format(META_URL, self.id) - - self._session = session - self.semaphore = HTTPSemaphore(value=50, interval=1, max_calls=5) - - self.log = logging_level - - async def get_tables(self) -> List[Table]: # noqa: F821 - async with self.semaphore: - url = "{}/tables".format(self.url) - res = await self._request("get", url) - if self._is_success(res): - data = await self._get_data(res) - self.tables = [ - Table( - self, - table["name"], - table_id=table["id"], - primary_field_id=table["primaryFieldId"], - fields=table["fields"], - views=table["views"], - ) - for table in data["tables"] - ] - self._tables_by_id = {table.id: table for table in self.tables} - self._tables_by_name = { - table.name: table for table in self.tables - } - return self.tables - - async def get_table(self, value: str, key: str): - assert key in (None, "id", "name") - if not getattr(self, "tables", None): - await self.get_tables() - if key == "name": - return self._tables_by_name.get(value) - elif key == "id": - return self._tables_by_id.get(value) - else: - tables = [ - table - for table in self.tables - if table.name == value or table.id == value - ] - if tables: - return tables[0] - - -class Table(BaseAirtable): - def __init__( - self, - base: Base, - name: str, - table_id: str = None, - primary_field_id: str = None, - fields: list = None, - views: list = None, - ) -> None: - """ - Airtable class for one table in one base - - Args: - base (``string``): Base class - name (``string``): Name of target table. - - """ - self.base = base - self.name = name - self.id = table_id - self.primary_field_id = primary_field_id - self.fields = fields - self.views = views - self.url = self._compose_url() - self.primary_field_name = ( - [ - field["name"] - for field in self.fields - if field["id"] == self.primary_field_id - ][0] - if self.fields and self.primary_field_id - else None - ) - self._session = base._session - - @staticmethod - def _basic_log_msg(content: Iterable) -> str: - """ - Constructs a basic logger message - """ - if isinstance(content, list): - if len(content) > 1: - plural = "s" - else: - plural = "" - message = "{} record{}".format(len(content), plural) - else: - message = "1 record" - return message - - def _add_record_to_url(self, record_id: str) -> str: - """ - Composes the airtable url with a record id - - Args: - record_id (``string``, optional): ID of target record. - Returns: - url (``string``): Composed url. - """ - return f"{self.url}/{record_id}" - - def _compose_url(self) -> str: - """ - Composes the airtable url. - - Returns: - url (``string``): Composed url. - """ - return f"{BASE_URL}/{self.base.id}/{urllib.parse.quote(self.name)}" - - async def _multiple(self, func, records: list) -> bool: - """ - Posts/Patches/Deletes records to a table in batches of 10. - - Args: - func (``method``): a list of records (``dictionary``) to post. - records (``list``): a list of records (``dictionary``) to post/patch/delete. - Kwargs: - message (``string``, optional): Name to use for logger. - """ # noqa: E501 - records_iter = ( - records[i : i + 10] for i in range(0, len(records), 10) - ) - - tasks = [] - for sub_list in records_iter: - tasks.append(asyncio.create_task(func(sub_list))) - results = await asyncio.gather(*tasks) - if any(not r for r in results): - return False - else: - return True - - async def get_record(self, record_id: str) -> dict: - """ - Gets one record from a table. 
- - Args: - record_id (``string``): ID of record. - Returns: - records (``list``): If succesful, a list of existing records (``dictionary``). - """ # noqa: E501 - url = self._add_record_to_url(record_id) - async with self.base.semaphore: - res = await self._request("get", url) - data = await self._get_data(res) - if self._is_success(res): - val = data["fields"].get(self.primary_field_name) or record_id - logger.info(f"Fetched record: <{val}> from table: {self.name}") - return data - else: - logger.error( - f"{res.status}: Failed to get record: <{record_id}> from table: {self.name} -> {data.get('error')}" # noqa: E501 - ) - return {} - - async def get_records( - self, - view: str = None, - filter_by_fields: list = None, - filter_by_formula: str = None, - ) -> list: - """ - Gets all records from a table. - - Kwargs: - filter_by_fields (``list``, optional): list of fields(``string``) to return. Minimum 2 fields. - filter_by_formula (``str``, optional): literally a formula. - view (``str``, optional): view id or name. - Returns: - records (``list``): If succesful, a list of existing records (``dictionary``). - """ # noqa - params: Dict[str, Any] = {} - - # filters - if filter_by_fields: - params["fields"] = filter_by_fields - if filter_by_formula: - params["filterByFormula"] = filter_by_formula - if view: - params["view"] = view - - records = [] - while True: - async with self.base.semaphore: - res = await self._request("get", self.url, params=params) - if not self._is_success(res): - logger.warning(f"Table: {self.name} could not be retreived.") - break - data = await self._get_data(res) - try: - records.extend(data["records"]) - except (AttributeError, KeyError, TypeError): - pass - # pagination - if "offset" in data: - params["offset"] = data["offset"] - else: - break - - if len(records) != 0: - logger.info( - f"Fetched {len(records)} records from table: {self.name}" - ) - self.records = records - else: - self.records = [] - return self.records - - async def post_record(self, record: dict) -> bool: - """ - Adds a record to a table. - - Args: - record (``dictionary``): Record to post. - Kwargs: - message (``string``, optional): Custom logger message. - """ - message = self._basic_log_msg(record) - headers = {"Content-Type": "application/json"} - data = {"fields": record["fields"]} - async with self.base.semaphore: - res = await self._request( - "post", self.url, json=data, headers=headers - ) - if self._is_success(res): - logger.info(f"Posted: {message}") - return True - else: - data = await self._get_data(res) - logger.error( - f"{res.status}: Failed to post: {message} -> '{data.get('error').get('message')}'" # noqa:E501 - ) - return False - - async def _post_records(self, records: list) -> bool: - headers = {"Content-Type": "application/json"} - message = self._basic_log_msg(records) - - data = { - "records": [{"fields": record["fields"]} for record in records] - } - async with self.base.semaphore: - res = await self._session.request( - "post", self.url, json=data, headers=headers - ) - if self._is_success(res): - logger.info(f"Posted: {message}") - return True - else: - data = await self._get_data(res) - logger.error( - f"{res.status}: Failed to post: {message} -> '{data.get('error').get('message')}'" # noqa:E501 - ) - return False - - async def post_records(self, records: list) -> None: - """ - Adds records to a table in batches of 10. - - Args: - records (``list``): a list of records (``dictionary``) to post. 
- Returns: - True if succesful - """ # noqa: E501 - return await self._multiple(self._post_records, records) - - async def update_record(self, record: dict) -> bool: - """ - Updates a record in a table. - - Args: - record (``dictionary``): Record with updated values. - Kwargs: - message (``string``, optional): Name of record to use for logger. - Returns: - records (``list``): If succesful, a list of existing records (``dictionary``). - """ # noqa - message = record["fields"].get(self.primary_field_name) or record.get( - "id" - ) - url = self._add_record_to_url(record.get("id")) - headers = {"Content-Type": "application/json"} - data = {"fields": record.get("fields")} - async with self.base.semaphore: - res = await self._request("patch", url, json=data, headers=headers) - if self._is_success(res): - logger.info(f"Updated: {message}") - return True - else: - data = await self._get_data(res) - logger.error( - f"{res.status}: Failed to update: {message} -> '{data.get('error').get('message')}'" # noqa:E501 - ) - return False - - async def _update_records(self, records: list) -> bool: - headers = {"Content-Type": "application/json"} - message = self._basic_log_msg(records) - data = { - "records": [ - {"id": record.get("id"), "fields": record.get("fields")} - for record in records - ] - } - async with self.base.semaphore: - res = await self._request( - "patch", self.url, headers=headers, json=data - ) - if self._is_success(res): - logger.info(f"Updated: {message}") - return True - else: - data = await self._get_data(res) - logger.error( - f"{res.status}: Failed to update: {message} -> '{data.get('error').get('message')}'" # noqa:E501 - ) - return False - - async def update_records(self, records: list) -> bool: - """ - Updates records in a table in batches of 10. - - Args: - records (``list``): a list of records (``dictionary``) with updated values. - Returns: - True if succesful - """ # noqa: E501 - return await self._multiple(self._update_records, records) - - async def delete_record(self, record: dict) -> bool: - """ - Deletes a record from a table. - - Args: - record (``dictionary``): Record to remove. - Kwargs: - message (``string``, optional): Custom logger message. - """ - message = record["fields"].get(self.primary_field_name) or record.get( - "id" - ) - url = self._add_record_to_url(record["id"]) - async with self.base.semaphore: - res = await self._session.request("delete", url) - if self._is_success(res): - logger.info(f"Deleted: {message}") - return True - else: - data = await self._get_data(res) - logger.error( - f"{res.status}: Failed to delete: {message} -> '{data.get('error').get('message')}'" # noqa:E501 - ) - return False - - async def _delete_records(self, records: list) -> bool: - """ - Deletes records from a table in batches of 10. - - Args: - records (``list``): a list of records (``dictionary``) to delete. - Kwargs: - message (``string``, optional): Custom logger message. 
- """ - - headers = {"Content-Type": "application/x-www-form-urlencoded"} - message = self._basic_log_msg(records) - - data = {"records[]": [record.get("id") for record in records]} - params = urllib.parse.urlencode(data, True) - - async with self.base.semaphore: - res = await self._request( - "delete", self.url, params=params, headers=headers - ) - if self._is_success(res): - logger.info(f"Deleted: {message}") - return True - else: - data = await self._get_data(res) - logger.error( - f"{res.status}: Failed to delete: {message} -> '{data.get('error').get('message')}'" # noqa:E501 - ) - return False - - async def delete_records(self, records: list) -> bool: - """ - Deletes records in a table in batches of 10. - - Args: - records (``list``): a list of records (``dictionary``) to delete - Returns: - True if succesful - """ # noqa: E501 - return await self._multiple(self._delete_records, records) diff --git a/airbase/archive/tools.py b/airbase/archive/tools.py deleted file mode 100644 index 3aef63e..0000000 --- a/airbase/archive/tools.py +++ /dev/null @@ -1,1251 +0,0 @@ -from __future__ import absolute_import - -import copy -import json - -from datetime import datetime, timezone -from .airtable import Airtable -from .utils import Logger - - -logger = Logger.start(__name__) - - -def elapsed_time(start_time): - """ - """ - NOW = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta = int((NOW - start_time).total_seconds()) + 1 - return timedelta - - -def pretty_print(obj): - """ - """ - print(json.dumps(obj, sort_keys=True, indent=4)) - - -def compose_formula(value_dict): - """ - - Args: - value_dict (``dictionary``): Keys = field_name, Values = [value, equality_or_inequality (``boolean``)] - Returns: - formula(``str``) - """ # noqa - - try: - filter_formula = "AND(" - for i, field in enumerate(value_dict): - value = value_dict[field][0] - equality = value_dict[field][1] - filter_field = "{} = '{}'".format("{" + field + "}", value) - if not equality: - filter_field = "NOT({})".format(filter_field) - filter_formula += filter_field - if i < len(value_dict) - 1: - filter_formula += "," - filter_formula += ")" - return filter_formula - except KeyError: - logger.warning("Error implementing compose_formula()") - - -def compose_time_formula(filter_time_field, interval): - if filter_time_field: - filter_time_field = "{" + filter_time_field + "}" - return "DATETIME_DIFF(NOW(), {}, 'seconds') < {}".format( - filter_time_field, interval - ) - - -def record_exists(record, table, fields): - """ - Checks if a record already exists in a table by looking at matching fields. - - Args: - record (``dictionary``): Record to check if exists in ``table``. - table (``list``): List of records from airtable. - fields (``list``): List of fields (``string``) to check for matching values. - Returns: - existing_record (``dictionary``): If exists. If not returns ``None``. - """ # noqa - filter_data = {key: record["fields"][key] for key in fields} - for i, existing_record in enumerate(table): - fields_found = 0 - for key, value in filter_data.items(): - other_value = existing_record["fields"].get(key) - if other_value == value: - fields_found += 1 - if fields_found == len(fields): - return existing_record, i - return None, None - - -def link_record( - record, table, filters_r, filters_t, field=None, contains=False -): - """ - Links records from another table to a record based on filter criteria. - - Args: - record (``dictionary``): Airtable record. - table (``list``): List of records from airtable. 
- filters_r (``list``): list of fields(``string``) in ``record`` to search in ``table``. - filters_t (``list``): matching fields(``string``) in ``table`` to search in. - Kwargs: - field (``string``, optional): Name of unique field to add linked records to. - contains (``boolean``, optional): Do you want to search all fields that contain ``filter_r``? - Returns: - record (``dictionary``): If exists. If not returns ``None``. - """ # noqa - filters_r = [x.strip() for x in filters_r] - filters_t = [x.strip() for x in filters_t] - table = table or [] - - if contains: - contains_filters_r = [] - for filter_r in filters_r: - for field in record["fields"]: - if filter_r.lower() in field.lower(): - contains_filters_r.append(field) - filters_r = contains_filters_r - - new_record = copy.deepcopy(record) - for filter_r in filters_r: - link_ids = [] - for row in table: - try: - thetas = [ - x.strip() for x in record["fields"][filter_r].split(",") - ] - for theta in thetas: - for filter_t in filters_t: - val = row["fields"][filter_t] - if isinstance(theta, str) and isinstance(val, str): - theta = theta.lower() - val = val.lower() - if theta == val: - link_ids.append(row["id"]) - break - except (KeyError, AttributeError): - pass - - # get target field - if not field or len(filters_r) > 1: - field = filter_r - # link records - if len(link_ids) > 0: - new_record["fields"][field] = sorted(link_ids) - else: - new_record["fields"][field] = None - return new_record - - -def combine_records(record_a, record_b, join_fields=None): - """ - Combines unique information from two records into 1. - - Args: - record_a (``dictionary``): New airtable record. - record_b (``dictionary``): Old airtable record (This will be dictate the ``id``) - Kwargs: - join_fields (``list``, optional): list of fields(``string``) to combine. - Returns: - record (``dictionary``): If succesful, the combined ``record``, else ``record_a``. - """ # noqa - try: - record = {"id": record_b["id"], "fields": {}} - - if join_fields: - keys = join_fields - else: - keys = record_a["fields"] - for key in keys: - field = record_a["fields"][key] - if isinstance(field, list): - field = record_a["fields"][key] - for item in record_b["fields"][key]: - if item not in record_a["fields"][key]: - field.append(item) - elif isinstance(field, str): - field = ( - record_a["fields"][key] + ", " + record_b["fields"][key] - ) - elif isinstance(field, int) or ( - isinstance(field, float) or isinstance(field, tuple) - ): - field = record_a["fields"][key] + record_b["fields"][key] - record["fields"][key] = field - return record - except Exception: - return record_a - - -def filter_record(record_a, record_b, filter_fields=None): - """ - Filters a record for unique information. - - Args: - record_a (``dictionary``): New airtable record. - record_b (``dictionary``): Old airtable record (This will be dictate the ``id``) - Kwargs: - filter_fields (``list``, optional): list of fields(``string``) to filter. - Returns: - record (``dictionary``): If succesful, the filtered ``record``, else ``record_a``. 
- """ # noqa - try: - record = {"id": record_b["id"], "fields": {}} - if filter_fields: - keys = filter_fields - else: - keys = record_a["fields"] - except Exception: - logger.warning("Could not filter record.") - return record_a - - for key in keys: - try: - if record_a["fields"][key] != record_b["fields"][key]: - record["fields"][key] = record_a["fields"][key] - except KeyError: - if record_a["fields"][key]: - record["fields"][key] = record_a["fields"][key] - return record - - -def override_record(record, existing_record, overrides): - """ - Removes fields from record if user has overriden them on airtable. - - Args: - record (``dictionary``): Record from which fields will be removed if overwritten. - existing_record (``dictionary``): Record to check for overrides. - overrides (``list``): List of dictionaries - Each dictionary is composed of two items: 1. The override checkbox field name, 2. The override field name - {"ref_field": "field name", "override_field": "field name"} - Return: - record. - """ # noqa - for override in overrides: - ref_field = override.get("ref_field") - override_field = override.get("override_field") - if existing_record["fields"].get(ref_field): - record["fields"][override_field] = existing_record["fields"][ - override_field - ] - return record - - -def compare_records( - record_a, record_b, method, overrides=None, filter_fields=None -): - """ - Compares a record in a table. - - Args: - record_a (``dictionary``): record to compare - record_b (``dictionary``): record to compare against. - method (``string``): Either "overwrite" or "combine" - Kwargs: - overrides (``list``): List of dictionaries - Each dictionary is composed of two items: 1. The override checkbox field name, 2. The override field name - {"ref_field": "field name", "override_field": "field name"} - filter_fields (``list``, optional): list of fields(``string``) to update. - Returns: - records (``list``): If succesful, a list of existing records (``dictionary``). - """ # noqa - try: - if overrides: - record = override_record(record_a, record_b, overrides) - if method == "overwrite": - record = filter_record( - record_a, record_b, filter_fields=filter_fields - ) - elif method == "combine": - record = combine_records( - record_a, record_b, join_fields=filter_fields - ) - return record - except Exception: - logger.warning("Invalid record format provided.") - - -def graft_fields(record, fields, separator=",", sort=True): - - for field in fields: - value = record["fields"].get(field) - if value: - if separator in value: - value_list = value.split(",") - if sort: - value_list = value_list.sort() - else: - value_list = [value] - record["fields"][field] = value_list - return record - - -def is_record(value): - """ - Checks whether a value is a Record ID or a list of Record IDs - - Args: - value (``obj``): any value retrieved from an airtable record field. - Returns: - (``bool``): True if value is Record ID or a list of Record IDs - """ - if isinstance(value, list) and value: - value = value[0] - return isinstance(value, str) and value[0:3] == "rec" and len(value) == 17 - - -def add_tables_to_target_data( - ref_airtable, target_table_names, extracted_data -): - """ - Checks whether a target table has been added to extracted data, and if not it adds it with a dummy dict. - - Args: - ref_airtable (``object``): Airtable() of back-end trigger table - target_table_name (``str``): Name of back-end target table. 
- extracted_data (``dict``): Extracted data as produced in extract_ref_data() - Returns: - extracted_data (``dict``) - """ # noqa - - for target_table_name in target_table_names: - if target_table_name not in extracted_data["target_info"]: - target_table = Airtable( - ref_airtable.BASE_ID, target_table_name, log=False - ) - target_table = target_table.get_table() - - target_info = { - record["fields"]["Field"]: record["fields"]["Example"] - for record in target_table - if record["fields"]["Field"] - in ("**BASE ID**", "**TABLE NAME**") - } - - try: - target_time_filter = [ - target_record["fields"]["Field"] - for target_record in target_table - if target_record["fields"].get("Filter - Time") - and target_record["fields"].get("Target Table") - and ref_airtable.TABLE_NAME - in target_record["fields"].get("Target Table") - ][0] - except IndexError: - target_time_filter = None - - try: - target_value_filter = [ - target_record["fields"]["Field"] - for target_record in target_table - if target_record["fields"].get("Filter - Value") - and target_record["fields"].get("Target Table") - == ref_airtable.TABLE_NAME - ][0] - except IndexError: - target_value_filter = None - - extracted_data["target_info"][target_table_name] = { - "base_id": target_info["**BASE ID**"], - "table_name": target_info["**TABLE NAME**"], - "trigger_record_id_field": None, - "match_fields": None, - "filter_time_fields": { - "target": target_time_filter, - "trigger": None, - }, - "filter_value_fields": { - "target": target_value_filter, - "trigger": None, - }, - } - - return extracted_data - - -def analyse_value(value): - """ - Analyses an airtable value to see if it needs to be: - 1. flattened from a list to a string - 2. a more detailed error message composed - - Args: - value (``obj``): any value retrieved from an airtable record field. - Returns: - flatten (``bool``): True if value needs to be flattened. - error_msg (``string``): Custom error message if any. - """ - flatten = None - error_msg = None - - if is_record(value): - error_msg = "{}{}{}".format( - "If this field indicates the record id ", - "of the target record, please check 'Flatten' and identify ", - "the name of the target table in 'Target Table.", - ) - - elif isinstance(value, str) and len(value) == 10: - try: - int(value[:4]) - error_msg = "{}{}".format( - "If this field is a date, input should be a ", - "string in YYYY-mm-DD format, i.e. 2019-06-19.", - ) - except ValueError: - pass - - if isinstance(value, list) and len(value) == 1: - flatten = True - - return flatten, error_msg - - -def get_fields(table): - """ - Get all non-empty fields from an airtable table - Args: - table (``list``): List of records retrieved using the get_table method or formatted to match an airtable table. 
- Returns: - fields_table (``list``): List of records, where each record represents a field in the input table and is structured as follows: - { - "fields": { - "Field": , - "Type": , - "Example": , - "Custom Error Message": , - } - } - """ # noqa - start_time = datetime.today() - - retrieved_fields = [] - fields_table = [] - - for record in table: - for field, value in record["fields"].items(): - if field not in retrieved_fields: - retrieved_fields.append(field) - - flatten, error_msg = analyse_value(value) - if flatten: - value = value[0] - - fields_record = { - "fields": { - "Field": field, - "Type": type(value).__name__, - "Example": str(value), - } - } - - if error_msg: - fields_record["fields"]["Custom Error Message"] = error_msg - - fields_table.append(fields_record) - - end_time = datetime.today() - logger.info( - "Retrieved %s fields in: %s", - len(retrieved_fields), - end_time - start_time, - ) - return fields_table - - -def replace_values(field, value): - # Simplify attachement objects - if isinstance(value, list) and isinstance(value[0], dict): - new_value = [{"url": obj["url"]} for obj in value if "url" in obj] - else: - new_value = value - return new_value - - -def replace_fields(record, data): - """ - """ - new_record = {"id": record["id"], "fields": {}} - - field_names = ( - ( - data["fields"][field]["trigger"]["Field"], - data["fields"][field]["target"]["Field"], - ) - for field in data["fields"] - ) - - for trigger_field, target_field in field_names: - error_msgs = [ - data["fields"][trigger_field]["trigger"].get( - "Custom Error Message" - ), - data["fields"][trigger_field]["target"].get( - "Custom Error Message" - ), - ] - - try: - if data["method"] == "push": - field_name = target_field - field_value = record["fields"][trigger_field] - if isinstance(field_value, list) and data["fields"][ - trigger_field - ]["trigger"].get("Flatten"): - if any(field_value): - field_value = ", ".join( - [v for v in field_value if isinstance(v, str)] - ) - else: - field_value = None - - elif data["method"] == "pull": - field_name = trigger_field - field_value = record["fields"][target_field] - if isinstance(field_value, list) and data["fields"][ - trigger_field - ]["target"].get("Flatten"): - if any(field_value): - field_value = ", ".join( - [v for v in field_value if isinstance(v, str)] - ) - else: - field_value = None - - new_record["fields"][field_name] = replace_values( - field_name, field_value - ) - - except KeyError: - pass - - if not ( - data["fields"][trigger_field]["type_match"] - or any("record id" in str(error_msg) for error_msg in error_msgs) - ): - if data["method"] == "push": - field_name = trigger_field - src_type = data["fields"][trigger_field]["trigger"]["Type"] - tgt_type = data["fields"][trigger_field]["target"]["Type"] - tgt_example = data["fields"][trigger_field]["target"][ - "Example" - ] - error_msg = data["fields"][trigger_field]["target"].get( - "Custom Error Message" - ) - elif data["method"] == "pull": - field_name = target_field - src_type = data["fields"][trigger_field]["target"]["Type"] - tgt_type = data["fields"][trigger_field]["trigger"]["Type"] - tgt_example = data["fields"][trigger_field]["trigger"][ - "Example" - ] - error_msg = data["fields"][trigger_field]["trigger"].get( - "Custom Error Message" - ) - message = "<{}>{}{}{}{}".format( - field_name, - " may need to be removed from payload,", - " because data types don't match. ", - "Expecting a <{}>, but received a <{}> ".format( - src_type, tgt_type - ), - "i.e. 
<{}>.".format(tgt_example), - ) - if error_msg: - message += " Clue: {}".format(error_msg) - - logger.warning(message) - - if len(new_record["fields"]) > 0: - return new_record - - -def is_within_time_interval(start_time, interval, filter_field, record): - """ - """ - if filter_field and record: - filter_time = record["fields"].get(filter_field) - if filter_time: - filter_time = datetime.strptime( - filter_time, "%Y-%m-%dT%H:%M:%S.%fZ" - ) - return (start_time - filter_time).total_seconds() <= interval - - -def get_documentation(read_airtable, bridge_airtable): - """ - Documents how an airtable table is structured, in another airtable and includes: - base_id, table name and all non-empty fields. - - Args: - read_airtable (``object``): Airtable() to document. - write_airtable (``object``): Airtable() to write documentation in. - Returns: - None - - """ # noqa - start_time = datetime.today() - logger.info("Began documentation of table: %s ", read_airtable.TABLE_NAME) - read_table = read_airtable.get_table() - if read_table: - read_fields = get_fields(read_table) - read_fields.append( - { - "fields": { - "Field": "**BASE ID**", - "Type": type(read_airtable.BASE_ID).__name__, - "Example": read_airtable.BASE_ID, - "Custom Error Message": "This field is for reference only", - } - } - ) - read_fields.append( - { - "fields": { - "Field": "**TABLE NAME**", - "Type": type(read_airtable.TABLE_NAME).__name__, - "Example": read_airtable.TABLE_NAME, - "Custom Error Message": "This field is for reference only", - } - } - ) - crud_table(bridge_airtable, read_fields, ["Field"]) - - else: - logger.warning( - "Please create one dummy recod in table %s\ - in order to document the table" - ) - - end_time = datetime.today() - logger.info( - "Finished documentation of table %s in: %s", - read_airtable.TABLE_NAME, - end_time - start_time, - ) - - -def get_method_order(method): - """ - """ - if method == "pull": - return "1" - elif method == "grab": - return "2" - elif method == "push": - return "3" - - -def extract_ref_data(ref_airtable): - """ - The user must first establish and define a reference bridge/back-end airtable table that links tables together via one of three operations/methods: - "pull", "push" or "grab". See for more information. - - This function extracts the data from the reference bridge/back-end airtable table so compose_link_data() can then create the guide/manual/roadmap for the function link_tables() - - Args: - ref_airtable (``Airtable``): Airtable - - Returns: - extracted_data (``dict``): a dict including: - 1. linked_data (linked fields by target_table and method) - 2. trigger_info (base_id and table_name) - 3. 
target_info (trigger_record_id_field, match_fields, filter_fields by target_table) - """ # noqa - - start_time = datetime.today() - logger.info( - "STARTED: extracting data from table: %s ", ref_airtable.TABLE_NAME - ) - - ref_table = ref_airtable.get_table() - - extracted_data = {"link_data": {}, "trigger_info": {}, "target_info": {}} - - for ref_record in ref_table: - - trigger_field_name = ref_record["fields"]["Field"] - - # trigger base id - if trigger_field_name == "**BASE ID**": - extracted_data["trigger_info"]["base_id"] = ref_record["fields"][ - "Example" - ] - continue - - # trigger table name - elif trigger_field_name == "**TABLE NAME**": - extracted_data["trigger_info"]["table_name"] = ref_record[ - "fields" - ]["Example"] - continue - - # trigger filter time field - elif ref_record["fields"].get("Filter - Time"): - target_tables = ref_record["fields"]["Target Table"].split(", ") - extracted_data = add_tables_to_target_data( - ref_airtable, target_tables, extracted_data - ) - for target_table in target_tables: - extracted_data["target_info"][target_table][ - "filter_time_fields" - ]["trigger"] = trigger_field_name - continue - - # trigger filter value field - elif ref_record["fields"].get("Filter - Value"): - target_tables = ref_record["fields"]["Target Table"].split(", ") - extracted_data = add_tables_to_target_data( - ref_airtable, target_tables, extracted_data - ) - for target_table in target_tables: - extracted_data["target_info"][target_table][ - "filter_value_fields" - ]["trigger"] = trigger_field_name - - # unique target - record id - elif ref_record["fields"].get("Unique Target - Record ID"): - target_tables = ref_record["fields"]["Target Table"].split(", ") - extracted_data = add_tables_to_target_data( - ref_airtable, target_tables, extracted_data - ) - for target_table in target_tables: - extracted_data["target_info"][target_table][ - "trigger_record_id_field" - ] = trigger_field_name - - # generator filtering only linked fields - linked_fields = ( - field - for field in ref_record["fields"] - if field[:7] in ("PUSH - ", "PULL - ", "GRAB - ") - ) - - for linked_field in linked_fields: - linked_table_name = linked_field[7:] - - extracted_data = add_tables_to_target_data( - ref_airtable, [linked_table_name], extracted_data - ) - - method = linked_field[:4].lower() - - order = get_method_order(linked_field[:4].lower()) - ordered_table_name = order + " - " + linked_table_name - - # create datum if linked_table_name not in link_data - if ordered_table_name not in extracted_data["link_data"]: - extracted_data["link_data"][ordered_table_name] = { - "fields": {}, - "method": method, - } - target_airtable = Airtable( - ref_airtable.BASE_ID, linked_table_name, log=False - ) - target_record = target_airtable.get_record( - ref_record["fields"][linked_field][0] - ) - - extracted_data["link_data"][ordered_table_name]["fields"][ - trigger_field_name - ] = { - "trigger": ref_record["fields"], - "target": target_record["fields"], - "type_match": ref_record["fields"]["Type"] - == target_record["fields"]["Type"], - } - - # unique target - match fields - if ref_record["fields"].get("Unique Target - Match Field"): - target_tables = ref_record["fields"]["Target Table"].split( - ", " - ) - extracted_data = add_tables_to_target_data( - ref_airtable, target_tables, extracted_data - ) - for target_table in target_tables: - extracted_data["target_info"][target_table][ - "match_fields" - ] = { - "trigger": trigger_field_name, - "target": target_record["fields"]["Field"], - } - - # if field is a 
list force type match - if ref_record["fields"].get("Flatten"): - extracted_data["link_data"][ordered_table_name]["fields"][ - trigger_field_name - ]["type_match"] = True - - # pretty_print(extracted_data) - - end_time = datetime.today() - logger.info( - "FINISHED: extracting data from table: %s in: %s", - ref_airtable.TABLE_NAME, - end_time - start_time, - ) - - return extracted_data - - -def compose_link_data(extracted_data, log=True): - """ - The user must first establish and define a reference bridge/back-end airtable table that links tables together via one of three operations/methods: - "pull", "push" or "grab". See for more information. - - This function creates a dictionary that serves as a guide/manual/roadmap for the function link_tables() - - Args: - extracted_data (``dict``): a dict including - - Kwargs: - log (``bool``, default=True): Print to logger if succesful - - Returns: - link_data (``dict``): a dict detailing how to connect the tables. - i.e. - link_data = { - "order - target_ref_table_name": { - "trigger": { - "base_id": , - "table_name": , - }, - "target": { - "base_id": , - "table_name": , - "ref_field": or None, - }, - "method": , - "match_fields": { - "trigger": , - "target": - }, - "filter_time_fields": { - "trigger": , - "target": - }, - "fields": { - "field_1_name": { - "trigger": trigger_field_info_1, - "target": tgt_field_info_1, - "type_match": True, - }, - "field_2_name": { - "trigger": trigger_field_info_2, - "target": tgt_field_info_2, - "type_match": True, - }, - "field_n-1_name": { - "trigger": trigger_field_info_n-1, - "target": tgt_field_info_n-1, - "type_match": True, - }, - "field_n_name": { - "trigger": trigger_field_info_n, - "target": tgt_field_info_n, - "type_match": True, - }, - }, - }, - } - - """ # noqa - start_time = datetime.today() - logger.info( - "STARTED: getting link data of table: %s ", - extracted_data["trigger_info"]["table_name"], - ) - - link_data = extracted_data["link_data"] - trigger_info = extracted_data["trigger_info"] - - remove_list = [] - - for ordered_table_name in link_data: - try: - target_info = extracted_data["target_info"][ordered_table_name[4:]] - - link_data[ordered_table_name]["trigger"] = { - "base_id": trigger_info["base_id"], - "table_name": trigger_info["table_name"], - "record_id_field": target_info["trigger_record_id_field"], - } - - link_data[ordered_table_name]["target"] = { - "base_id": target_info["base_id"], - "table_name": target_info["table_name"], - } - - link_data[ordered_table_name]["match_fields"] = target_info[ - "match_fields" - ] - - link_data[ordered_table_name]["filter_time_fields"] = target_info[ - "filter_time_fields" - ] - - link_data[ordered_table_name]["filter_value_fields"] = target_info[ - "filter_value_fields" - ] - except KeyError: - remove_list.append(ordered_table_name) - logger.warning( - "No matching data or unique record id found for table: %s", - ordered_table_name[4:], - ) - - for item in remove_list: - link_data.pop(item) - - if log: - pretty_print(link_data) - - end_time = datetime.today() - logger.info( - "FINISHED: getting link data of table %s in: %s", - extracted_data["trigger_info"]["table_name"], - end_time - start_time, - ) - - return link_data - - -def _sort_data_by_method(data): - """ - """ - if data["method"] == "push": - origin = "trigger" - dest = "target" - - elif data["method"] == "pull": - origin = "target" - dest = "trigger" - - origin_airtable = Airtable( - data[origin]["base_id"], data[origin]["table_name"], log=False - ) - dest_airtable = 
Airtable(data[dest]["base_id"], data[dest]["table_name"]) - filter_time_field = data["filter_time_fields"].get(origin) - filter_value_field = data["filter_value_fields"].get("trigger") - - if data["match_fields"]: - dest_match_fields = [data["match_fields"].get(dest)] - else: - dest_match_fields = None - - return ( - origin_airtable, - dest_airtable, - filter_time_field, - filter_value_field, - dest_match_fields, - ) - - -def _grab(data, interval, start_time): - """ - """ - origin_airtable = Airtable( - data["target"]["base_id"], data["target"]["table_name"], log=False - ) - dest_airtable = Airtable( - data["trigger"]["base_id"], data["trigger"]["table_name"] - ) - - filter_time_field = data["filter_time_fields"].get("trigger") - - logger.info( - "Start: Grabbing data from table: <%s> to table: <%s>", - origin_airtable.TABLE_NAME, - dest_airtable.TABLE_NAME, - ) - - grab_fields = { - field: data["fields"][field]["target"]["Field"] - for field in data["fields"] - } - - interval += elapsed_time(start_time) - formula = compose_time_formula(filter_time_field, interval) - - dest_table = dest_airtable.get_table(filter_by_formula=formula) or [] - - for dest_record in dest_table: - for grab_field, origin_field in grab_fields.items(): - origin_record_ids = dest_record["fields"].get(grab_field) - if origin_record_ids: - origin_record_ids = origin_record_ids.split(", ") - if is_record(origin_record_ids): - payload = "" - for i, origin_record_id in enumerate(origin_record_ids): - origin_record = origin_airtable.get_record( - origin_record_id - ) - - grab_value = origin_record["fields"].get(origin_field) - if grab_value: - payload += grab_value - if i < len(origin_record_ids) - 1: - payload += ", " - - record = { - "id": dest_record["id"], - "fields": {grab_field: payload}, - } - - dest_airtable.update_record(record, message=record["id"]) - - -def _post(data, interval, start_time): - """ - """ - - if data["method"] in ("push", "pull"): - ( - origin_airtable, - dest_airtable, - filter_time_field, - filter_value_field, - dest_match_fields, - ) = _sort_data_by_method(data) - - else: - _grab(data, interval) - return - - logger.info( - "Start: %sing data from table: <%s> to table: <%s>", - data["method"].title(), - origin_airtable.TABLE_NAME, - dest_airtable.TABLE_NAME, - ) - - formula = None - - if filter_time_field: - interval += elapsed_time(start_time) - formula_time = compose_time_formula(filter_time_field, interval) - formula = formula_time - - if filter_value_field: - formula_value = compose_formula({filter_value_field: ["", False]}) - formula = formula_value - - if filter_time_field and filter_value_field: - formula = "AND({},{})".format(formula_time, formula_value) - - origin_table = origin_airtable.get_table(filter_by_formula=formula) - - if origin_table: - origin_table = [ - replace_fields(record, data) for record in origin_table - ] - crud_table( - dest_airtable, - origin_table, - dest_match_fields, - update=False, - delete=False, - ) - - -def _patch(data, interval, start_time): - """ - """ - if data["method"] in ("push", "pull"): - ( - origin_airtable, - dest_airtable, - filter_time_field, - filter_value_field, - dest_match_fields, - ) = _sort_data_by_method(data) - - else: - _grab(data, interval) - return - - logger.info( - "Start: %sing data from table: <%s> to table: <%s>", - data["method"].title(), - origin_airtable.TABLE_NAME, - dest_airtable.TABLE_NAME, - ) - - formula_time = None - formula_value = None - - if filter_time_field: - interval += elapsed_time(start_time) - formula_time = 
compose_time_formula(filter_time_field, interval) - - if filter_value_field: - formula_value = compose_formula({filter_value_field: ["", False]}) - - origin_table = origin_airtable.get_table(filter_by_formula=formula_time) - - if origin_table: - origin_table = {record["id"]: record for record in origin_table} - - trigger_airtable = Airtable( - data["trigger"]["base_id"], - data["trigger"]["table_name"], - log=False, - ) - - trigger_table = ( - trigger_airtable.get_table(filter_by_formula=formula_value) or [] - ) - - # record ids generator - record_ids = ( - ( - trigger_record["id"], - trigger_record["fields"].get( - data["trigger"]["record_id_field"] - ), - ) - for trigger_record in trigger_table - if trigger_record["fields"].get(data["trigger"]["record_id_field"]) - ) - - for trigger_id, target_id in record_ids: - if isinstance(target_id, list): - target_id = target_id[0] - - if data["method"] == "push": - origin_record_id = trigger_id - dest_record_id = target_id - - elif data["method"] == "pull": - origin_record_id = target_id - dest_record_id = trigger_id - - origin_record = origin_table.get(origin_record_id) - - if origin_record: - dest_record = dest_airtable.get_record(dest_record_id) - origin_record = replace_fields(origin_record, data) - - if origin_record: - origin_record = compare_records( - origin_record, dest_record, "overwrite" - ) - if origin_record["fields"]: - dest_airtable.update_record( - origin_record, message=dest_record_id - ) - - -def link_tables(link_data, interval, start_time): - """ - Link two tables based on relationships established in a back end airtable and defined via link_data. - - Args: - link_data (``dict``): a dict detailing how to connect the tables. See compose_link_data() method. - Returns: - None - - """ # noqa - - for i in range(3): - for table, data in link_data.items(): - if table[0] == str(i + 1): - this_start_time = datetime.today() - - if data["method"] in ("push", "pull"): - if data["trigger"]["record_id_field"]: - _patch(data, interval, start_time) - else: - _post(data, interval, start_time) - - elif data["method"] == "grab": - _grab(data, interval, start_time) - - end_time = datetime.today() - logger.info( - "Finished linking tables in: %s", - end_time - this_start_time, - ) - - -def crud_table( - existing_airtable, - new_table, - match_fields, - overrides=None, - update=True, - delete=True, -): - """ - Create, Read, Update and Delete records for airtable table - - Args: - existing_airtable (``Airtable``): Airtable - new_table (``list``): list of new records. Each record should be a ``dictionary``. - match_fields (``list``): List of field names (``string``) via which new records will match with old records. - Kwargs: - overrides (``list``): List of dictionaries - Each dictionary is composed of two items: 1. The override checkbox field name, 2. The override field name - {"ref_field": "field name", "override_field": "field name"} - update (``bool``, default = ``True``): If True records will be updated. - delete (``bool``, default = ``True``): If True records that are not in the new table will be deleted. 
-    Returns:
-        None
-
-    """  # noqa
-    start_time = datetime.today()
-
-    existing_table = existing_airtable.get_table()
-
-    updated_indices = []
-    for record in new_table:
-        record_name = record["fields"][list(record["fields"].keys())[0]]
-        if existing_table:
-            existing_record, existing_index = record_exists(
-                record, existing_table, match_fields
-            )
-            if existing_record:
-                if update:
-                    updated_indices.append(existing_index)
-
-                    record = compare_records(
-                        record,
-                        existing_record,
-                        "overwrite",
-                        overrides=overrides,
-                    )
-                    if record["fields"]:
-                        existing_airtable.update_record(record)
-            else:
-                existing_airtable.post_record(record, message=record_name)
-        else:
-            existing_airtable.post_record(record, message=record_name)
-
-    if existing_table:
-        updated_indices = set(updated_indices)
-        all_indices = {i for i in range(len(existing_table))}
-        dead_indices = all_indices - updated_indices
-        if len(dead_indices) > 0:
-            for index in dead_indices:
-                dead_record = existing_table[index]
-                record_name = dead_record["fields"][
-                    list(dead_record["fields"].keys())[0]
-                ]
-                if delete:
-                    existing_airtable.delete_record(
-                        dead_record, message=record_name
-                    )
-
-    end_time = datetime.today()
-    logger.info(
-        "CRUDed %s records in: %s", len(new_table), end_time - start_time
-    )
diff --git a/airbase/decorators.py b/airbase/decorators.py
new file mode 100644
index 0000000..f3fc5f8
--- /dev/null
+++ b/airbase/decorators.py
@@ -0,0 +1,36 @@
+from asyncio import create_task, gather
+from functools import wraps
+from typing import Callable
+
+MAX_CHUNK_SIZE = 10
+
+
+def chunkify(func: Callable):
+    """Batch ``records`` into chunks of 10 (the Airtable batch limit)."""
+    @wraps(func)
+    async def inner(self, *args, **kwargs):
+        records = kwargs["records"]
+        method = kwargs["method"]
+        typecast = kwargs.get("typecast") or False
+
+        records_iter = (
+            records[i : i + MAX_CHUNK_SIZE]
+            for i in range(0, len(records), MAX_CHUNK_SIZE)
+        )
+
+        tasks = []
+        for sub_list in records_iter:
+            tasks.append(
+                create_task(func(self, method, sub_list, typecast))
+            )
+        task_return_values = await gather(*tasks)
+
+        unpacked_results = []
+        for task_return_value in task_return_values:
+            if task_return_value.get("records"):
+                unpacked_results.extend(task_return_value["records"])
+            else:
+                unpacked_results.append(task_return_value)
+        return unpacked_results
+
+    return inner
diff --git a/airbase/exceptions.py b/airbase/exceptions.py
new file mode 100644
index 0000000..20e931c
--- /dev/null
+++ b/airbase/exceptions.py
@@ -0,0 +1,16 @@
+from aiohttp import ClientResponse
+from typing import Optional
+
+
+class AirbaseException(Exception):
+    """AirbaseException"""
+
+    pass
+
+
+class AirbaseResponseException(AirbaseException):
+    """AirbaseResponseException"""
+
+    def __init__(self, *args, response: Optional[ClientResponse] = None):
+        super().__init__(*args)
+        self.response = response
diff --git a/airbase/py.typed b/airbase/py.typed
new file mode 100644
index 0000000..e69de29
diff --git a/airbase/session/__init__.py b/airbase/session/__init__.py
deleted file mode 100644
index a21936c..0000000
--- a/airbase/session/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from __future__ import absolute_import
-from .session import Session  # noqa: F401
diff --git a/airbase/session/session.py b/airbase/session/session.py
deleted file mode 100644
index 9926f3e..0000000
--- a/airbase/session/session.py
+++ /dev/null
@@ -1,336 +0,0 @@
-from __future__ import absolute_import
-
-import json
-import sys
-
-try:
-    from requests import codes
-    from requests import Session as _Session
-    from requests.adapters import HTTPAdapter
-    from requests.exceptions import
ConnectionError, Timeout - - SUCCESS_CODES = ( - codes.ok, - codes.created, - codes.accepted, - codes.partial_content, - ) - -except AttributeError: - raise AttributeError( - "Stack frames are disabled, please enable stack frames.\ - If in pyRevit, place the following at the top of your file: \ - '__fullframeengine__ = True' and reload pyRevit." - ) - -try: - from System.Net import ( - SecurityProtocolType, - ServicePointManager, - WebRequest, - ) - from System.IO import File, StreamReader - from System.Text.Encoding import UTF8 - - # from System.Threading import Tasks - - ServicePointManager.SecurityProtocol = SecurityProtocolType.Tls12 - - SUCCESS_CODES = ("OK", "Created", "Accepted", "Partial Content") -except ImportError: - pass - -from ..utils import Logger # noqa: E402 - -logger = Logger.start(__name__) - - -class Request(object): - def __init__(self, request, stream=False, message=""): - self.response = request - self.stream = stream - self.message = message - - @property - def data(self): - if not getattr(self, "_data", None): - if sys.implementation.name == "ironpython": - pass - else: # if sys.implementation.name == "cpython" - # if response is a json object - try: - self._data = self.response.json() - # else if raw data - except json.decoder.JSONDecodeError: - self._data = self.response.content - return self._data - - @property - def success(self): - if not getattr(self, "_success", None): - if sys.implementation.name == "ironpython": - pass - else: # if sys.implementation.name == "cpython" - self._success = self.response.status_code in SUCCESS_CODES - - if not self._success: - self._log_error() - - return self._success - - def _log_error(self): - if sys.implementation.name == "ironpython": - error_msg = "" - else: # if sys.implementation.name == "cpython" - try: - response_error = ( - json.loads(self.response.text).get("error") - or json.loads(self.response.text).get("errors") - or self.response.json().get("message") - ) - except json.decoder.JSONDecodeError: - response_error = self.response.text - if response_error: - if isinstance(response_error, list): - error_msg = ", ".join( - [error["detail"] for error in response_error] - ) - elif isinstance(response_error, dict): - error_msg = response_error.get( - "message" - ) or response_error.get("type") - else: - error_msg = str(response_error) - else: - error_msg = self.response.status_code - - logger.warning( - "Failed to {} - ERROR: {}".format(self.message, error_msg) - ) - - -class Session(object): - def __init__(self, timeout=2, max_retries=3, base_url=None): - """ - Kwargs: - timeout (``int``, default=2): maximum time for one request in minutes. - max_retries (``int``, default=3): maximum number of retries. - base_url (``str``, optional): Base URL for this Session - """ # noqa:E501 - try: - self.session = _Session() - if base_url: - adapter = HTTPAdapter(max_retries=max_retries) - self.session.mount(base_url, adapter) - self.timeout = int(timeout * 60) # in secs - self.success_codes = (codes.ok, codes.created, codes.accepted) - except Exception: - self.session = None - self.timeout = int(timeout * 60 * 1000) # in ms - self.success_codes = ("OK", "Created", "Accepted") - - @staticmethod - def _add_url_params(url, params): - """ - Appends an encoded dict as url parameters to the call API url - Args: - url (``str``): uri for API call. - params (``dict``): dictionary of request uri parameters. 
- Returns: - url (``str``): url with params - """ - url_params = "" - count = 0 - for key, value in params.items(): - if count == 0: - url_params += "?" - else: - url_params += "&" - url_params += key + "=" - url_params += str(params[key]) - count += 1 - return url + url_params - - @staticmethod - def _url_encode(data): - """ - Encodes a dict into a url encoded string. - Args: - data (``dict``): source data - Returns: - urlencode (``str``): url encoded string - """ - urlencode = "" - count = len(data) - for key, value in data.items(): - urlencode += key + "=" + str(value) - if count != 1: - urlencode += "&" - count -= 1 - return urlencode - - def _request_cpython(self, *args, **kwargs): - method, url = args - headers = kwargs.get("headers") - params = kwargs.get("params") - json_data = kwargs.get("json_data") - byte_data = kwargs.get("byte_data") - urlencode = kwargs.get("urlencode") - filepath = kwargs.get("filepath") - stream = kwargs.get("stream") - - try: - if headers: - self.session.headers = headers - - # get file contents as bytes - if filepath: - with open(filepath, "rb") as fp: - data = fp.read() - # else raw bytes - elif byte_data: - data = byte_data - # else urlencode - elif urlencode: - data = urlencode - else: - data = None - - return self.session.request( - method.lower(), - url, - params=params, - json=json_data, - data=data, - timeout=self.timeout, - stream=stream, - ) - - except (ConnectionError, Timeout) as e: - raise e - - def _request_ironython(self, *args, **kwargs): - method, url = args - headers = kwargs.get("headers") - params = kwargs.get("params") - json_data = kwargs.get("json_data") - byte_data = kwargs.get("byte_data") - urlencode = kwargs.get("urlencode") - filepath = kwargs.get("filepath") - - try: - # prepare params - if params: - url = self._add_url_params(url, params) - - web_request = WebRequest.Create(url) - web_request.Method = method.upper() - web_request.Timeout = self.timeout - - # prepare headers - if headers: - for key, value in headers.items(): - if key == "Content-Type": - web_request.ContentType = value - elif key == "Content-Length": - web_request.ContentLength = value - else: - web_request.Headers.Add(key, value) - - byte_arrays = [] - if json_data: - byte_arrays.append( - UTF8.GetBytes(json.dumps(json_data, ensure_ascii=False)) - ) - if filepath: - byte_arrays.append(File.ReadAllBytes(filepath)) - if byte_data: - pass - # TODO - Add byte input for System.Net - if urlencode: - byte_arrays.append(UTF8.GetBytes(self._url_encode(urlencode))) - - for byte_array in byte_arrays: - web_request.ContentLength = byte_array.Length - with web_request.GetRequestStream() as req_stream: - req_stream.Write(byte_array, 0, byte_array.Length) - try: - with web_request.GetResponse() as response: - success = response.StatusDescription in SUCCESS_CODES - - with response.GetResponseStream() as response_stream: - with StreamReader(response_stream) as stream_reader: - data = json.loads(stream_reader.ReadToEnd()) - except SystemError: - return None, None - finally: - web_request.Abort() - - except Exception as e: - raise e - - return data, success - - def request( - self, - method, - url, - headers=None, - params=None, - json_data=None, - byte_data=None, - urlencode=None, - filepath=None, - stream=False, - message="", - ): - """ - Request wrapper for cpython and ironpython. - Args: - method (``str``): api method. - url (``str``): uri for API call. - Kwargs: - headers (``dict``, optional): dictionary of request headers. 
- params (``dict``, optional): dictionary of request uri parameters. - json_data (``json``, optional): request body if Content-Type is json. - urlencode (``dict``, optional): request body if Content-Type is urlencoded. - filepath (``str``, optional): filepath of object to upload. - stream (``bool``, default=False) whether to sream content of not - message (``str``, optional): filepath of object to upload. - - Returns: - data (``json``): Body of response. - success (``bool``): True if response returned a accepted, created or ok status code. - """ # noqa:E501 - if sys.implementation.name == "ironpython": - return self._request_ironpython( - method, - url, - headers=headers, - params=params, - json_data=json_data, - byte_data=byte_data, - urlencode=urlencode, - filepath=filepath, - stream=stream, - ) - else: # if sys.implementation.name == "cpython" - response = Request( - self._request_cpython( - method, - url, - headers=headers, - params=params, - json_data=json_data, - byte_data=byte_data, - urlencode=urlencode, - filepath=filepath, - stream=stream, - ), - message=message, - ) - return response.data, response.success - - -if __name__ == "__main__": - pass diff --git a/airbase/tools.py b/airbase/tools.py index 72cda0c..936781c 100644 --- a/airbase/tools.py +++ b/airbase/tools.py @@ -1,6 +1,8 @@ from __future__ import absolute_import -import copy +import pandas as pd + +from typing import List from .utils import Logger @@ -8,7 +10,133 @@ logger = Logger.start(__name__) -def is_record(value): +async def compare( + df_1: pd.DataFrame, df_2: pd.DataFrame, primary_keys: List[str] +) -> pd.DataFrame: + """ + Compare two pandas DataFrames and return a DataFrame with rows sorted by CRUD operation + + Args: + df_1 (``pd.DataFrame``): DataFrame to compare (i.e. new payload) + df_2 (``pd.DataFrame``): DataFrame to compare against (i.e. 
existing data)
+        primary_keys (``list``): List of ``str`` names of the primary keys on which to join the DataFrames
+
+    Returns:
+        pd.DataFrame
+    """  # noqa: E501
+
+    # list of headers
+    df_1_headers: List[str] = df_1.columns.values.tolist()
+    df_2_headers: List[str] = df_2.columns.values.tolist()
+
+    # combined and overlapping headers for dropping columns later
+    combined_headers: set = set(df_1_headers) | set(df_2_headers)
+    overlapping_headers: set = set(df_1_headers) & set(df_2_headers)
+    # list of overlapping_headers without primary keys
+    reduced_headers_a: List[str] = list(
+        overlapping_headers - set(primary_keys)
+    )
+    # list of combined_headers without primary keys
+    reduced_headers_b: List[str] = list(combined_headers - set(primary_keys))
+
+    # full outer join of both DataFrames on the primary_keys
+    # where the left is df_1 (with suffix '_x' applied to its column names)
+    # and the right is df_2 (with suffix '_y' applied to its column names)
+    combined_df: pd.DataFrame = df_1.merge(
+        df_2, how="outer", on=primary_keys, indicator=True
+    )
+
+    # rows to be created are found by looking at unique rows in df_1 (left)
+    create_df: pd.DataFrame = (
+        combined_df.loc[lambda x: x["_merge"] == "left_only"]
+        .drop(
+            columns=[f"{other_header}_y" for other_header in reduced_headers_a]
+        )  # df_2 columns are dropped
+        .rename(
+            columns={
+                f"{other_header}_x": other_header
+                for other_header in reduced_headers_a
+            }
+        )  # df_1 columns are renamed back to original (without suffix)
+        .drop(columns=["_merge"])  # _merge indicator column is dropped
+    )
+
+    # rows to be deleted are found by looking at unique rows in df_2 (right)
+    delete_df: pd.DataFrame = (
+        combined_df.loc[lambda x: x["_merge"] == "right_only"]
+        .drop(
+            columns=[f"{other_header}_x" for other_header in reduced_headers_a]
+        )  # df_1 columns are dropped
+        .rename(
+            columns={
+                f"{other_header}_y": other_header
+                for other_header in reduced_headers_a
+            }
+        )  # df_2 columns are renamed back to original (without suffix)
+        .drop(columns=["_merge"])  # _merge indicator column is dropped
+    )
+
+    # rows to be updated are found by looking at rows in df_1 (left)
+    # that differ from df_2 (right) minus the rows to be created (create_df)
+
+    # full outer join of both DataFrames on all shared columns (no join
+    # keys are given, so rows must match in every column to be considered
+    # the same; no suffixes are produced because no columns overlap)
+    df_1_unique_rows: pd.DataFrame = (
+        df_1.merge(df_2, how="outer", indicator=True, sort=True)
+        .loc[
+            lambda x: x["_merge"] == "left_only"
+        ]  # only df_1 unique rows are kept
+        .drop(columns=["_merge"])  # _merge indicator column is dropped
+    )
+
+    # full outer join of both DataFrames on the primary_keys
+    # where the left is df_1_unique_rows
+    # (with suffix '_x' applied to its column names)
+    # and the right is create_df
+    # (with suffix '_y' applied to its column names)
+    update_df: pd.DataFrame = (
+        df_1_unique_rows.merge(
+            create_df, indicator=True, how="outer", on=primary_keys
+        )
+        .loc[
+            lambda x: x["_merge"] != "both"
+        ]  # only df_1_unique_rows unique rows are kept
+        .drop(
+            columns=[f"{other_header}_y" for other_header in reduced_headers_b]
+        )  # create_df columns are dropped
+        .rename(
+            columns={
+                f"{other_header}_x": other_header
+                for other_header in reduced_headers_b
+            }
+        )  # df_1_unique_rows columns are renamed back to original
+        .drop(columns=["_merge"])  # _merge indicator column is dropped
+    )
+
+    # insert crud_type for each
+    create_df.insert(
+        loc=len(df_1_headers),
column="crud_type", + value=["create"] * create_df.shape[0], + ) + + update_df.insert( + loc=len(df_1_headers), + column="crud_type", + value=["update"] * update_df.shape[0], + ) + + delete_df.insert( + loc=len(df_2_headers), + column="crud_type", + value=["delete"] * delete_df.shape[0], + ) + + return create_df.append(update_df, sort=True).append(delete_df, sort=True) + + +async def is_record(value): """ Checks whether a value is a Record ID or a list of Record IDs @@ -22,7 +150,7 @@ def is_record(value): return isinstance(value, str) and value[0:3] == "rec" and len(value) == 17 -def get_primary_keys_as_hashable(record, primary_keys): +async def get_primary_keys_as_hashable(record, primary_keys): hashable_keys = [] for key in primary_keys: val = record["fields"].get(key) @@ -33,13 +161,13 @@ def get_primary_keys_as_hashable(record, primary_keys): return tuple(hashable_keys) if hashable_keys else None -def graft_fields(record, fields, separator=",", sort=True): +async def graft_fields(record, fields, separator=",", sort=True): for field in fields: value = record["fields"].get(field) if value: if separator in value: - value_list = value.split(",") + value_list = value.split(separator) if sort: value_list = value_list.sort() else: @@ -48,8 +176,11 @@ def graft_fields(record, fields, separator=",", sort=True): return record -def link_tables( - table_a, table_b, fields_to_link_in_a, primary_key_b, +async def link_tables( + table_a, + table_b, + fields_to_link_in_a, + primary_key_b, ): """ Links records from another table to a record based on filter criteria. @@ -69,9 +200,9 @@ def link_tables( if record_b["fields"].get(primary_key_b) } - new_table = [] + # new_table = [] for record_a in table_a: - new_record = copy.deepcopy(record_a) + # new_record = copy.deepcopy(record_a) for field_to_link in fields_to_link_in_a: field_to_link = field_to_link.strip() val = record_a["fields"][field_to_link] @@ -79,16 +210,16 @@ def link_tables( continue keys = (x.strip() for x in val.split(",")) - new_record["fields"][field_to_link] = [ + record_a["fields"][field_to_link] = [ table_b_by_primary_key.get(key) for key in keys if table_b_by_primary_key.get(key) ] - new_table.append(new_record) - return new_table + # new_table.append(new_record) + return table_a -def combine_records(record_a, record_b, join_fields=None): +async def combine_records(record_a, record_b, join_fields=None): """ Combines unique information from two records into 1. @@ -128,7 +259,7 @@ def combine_records(record_a, record_b, join_fields=None): return record_a -def filter_record(record_a, record_b, filter_fields=None): +async def filter_record(record_a, record_b, filter_fields=None): """ Filters a record for unique information. 
@@ -152,15 +283,44 @@ def filter_record(record_a, record_b, filter_fields=None): for key in keys: try: - if record_a["fields"][key] != record_b["fields"][key]: + if isinstance(record_a["fields"][key], list): + if ( + isinstance(record_a["fields"][key][0], dict) + and "url" in record_a["fields"][key][0] + ): + record_a_items = set( + [ + item["url"].split("/")[-1] + for item in record_a["fields"][key] + ] + ) + record_b_items = set( + [ + item["url"].split("/")[-1] + for item in record_b["fields"][key] + ] + ) + else: + record_a_items = set( + item for item in record_a["fields"][key] + ) + record_b_items = set( + item for item in record_b["fields"][key] + ) + diff = record_a_items - record_b_items + if len(diff) != 0: + record["fields"][key] = record_a["fields"][key] + + elif record_a["fields"][key] != record_b["fields"][key]: record["fields"][key] = record_a["fields"][key] - except KeyError: + + except (KeyError, IndexError): if record_a["fields"][key]: record["fields"][key] = record_a["fields"][key] return record -def override_record(record, existing_record, overrides): +async def override_record(record, existing_record, overrides): """ Removes fields from record if user has overriden them on airtable. @@ -183,14 +343,14 @@ def override_record(record, existing_record, overrides): return record -def compare_records( +async def compare_records( record_a, record_b, method, overrides=None, filter_fields=None ): """ Compares a record in a table. Args: - record_a (``dictionary``): record to compare + record_a (``dictionary``): record to compare record_b (``dictionary``): record to compare against. method (``string``): Either "overwrite" or "combine" Kwargs: @@ -217,7 +377,7 @@ def compare_records( logger.warning("Invalid record format provided.") -def replace_values(field, value): +async def replace_values(field, value): # Simplify attachement objects if isinstance(value, list) and isinstance(value[0], dict): new_value = [{"url": obj["url"]} for obj in value if "url" in obj] diff --git a/airbase/tools_async.py b/airbase/tools_async.py deleted file mode 100644 index 4452388..0000000 --- a/airbase/tools_async.py +++ /dev/null @@ -1,354 +0,0 @@ -from __future__ import absolute_import - -import pandas as pd - -from typing import List - -from .utils import Logger - - -logger = Logger.start(__name__) - - -async def compare( - df_1: pd.DataFrame, df_2: pd.DataFrame, primary_keys: List[str] -) -> pd.DataFrame: - """ - Compare two pandas DataFrames and return a DataFrame with rows sorted by CRUD operation - - Args: - df_1 (``pd.DataFrame``): DataFrame to compare (i.e. new payload) - df_2 (``pd.DataFrame``): DataFrame to compare against (i.e. 
existing data) - primary_keys(``list``): List of `str`` of the name of the primary keys to join the DataFrames - - Returns: - pd.DataFrame - """ # noqa: E501 - - # list of headers - df_1_headers: List[str] = df_1.columns.values.tolist() - df_2_headers: List[str] = df_2.columns.values.tolist() - - # combined and overlapping headers for dropping columns later - combined_headers: set = set(df_1_headers) | set(df_2_headers) - overlapping_headers: set = set(df_1_headers) & set(df_2_headers) - # list of overlapping_headers without primary keys - reduced_headers_a: List[str] = list( - overlapping_headers - set(primary_keys) - ) - # list of combined_headers without primary keys - reduced_headers_b: List[str] = list(combined_headers - set(primary_keys)) - - # full outer join of both DataFrames on the primary_keys - # where the left is df_1 (with suffix '_x' applied to its column names) - # and the right is df_2 (with suffix '_y' applied to its column names) - combined_df: pd.DataFrame = df_1.merge( - df_2, how="outer", on=primary_keys, indicator=True - ) - - # rows to be created are found by looking at unique rows in df_1 (left) - create_df: pd.DataFrame = ( - combined_df.loc[lambda x: x["_merge"] == "left_only"] - .drop( - columns=[f"{other_header}_y" for other_header in reduced_headers_a] - ) # df_2 columns are dropped - .rename( - columns={ - f"{other_header}_x": other_header - for other_header in reduced_headers_a - } - ) # df_1 columns are renamed back to original (without suffix) - .drop(columns=["_merge"]) # _merge indicator column is dropped - ) - - # rows to be deleted are found by looking at unique rows in df_2 (right) - delete_df: pd.DataFrame = ( - combined_df.loc[lambda x: x["_merge"] == "right_only"] - .drop( - columns=[f"{other_header}_x" for other_header in reduced_headers_a] - ) # df_1 columns are dropped - .rename( - columns={ - f"{other_header}_y": other_header - for other_header in reduced_headers_a - } - ) # df_1 columns are renamed back to original (without suffix) - .drop(columns=["_merge"]) # _merge indicator column is dropped - ) - - # rows to be updated are found by looking at rows in df_1 (left) - # that differ from df_2 (right) minus the rows to be created (create_df) - - # full outer join of both DataFrames on no keys - # where the left is df_1 (with suffix '_x' applied to its column names) - # and the right is df_2 (with suffix '_y' applied to its column names) - df_1_unique_rows: pd.DataFrame = ( - df_1.merge(df_2, how="outer", indicator=True, sort=True) - .loc[ - lambda x: x["_merge"] == "left_only" - ] # only df_1 unique rows are kept - .drop(columns=["_merge"]) # _merge indicator column is dropped - ) - - # full outer join of both DataFrames on no primary keys - # where the left is df_1_unique_rows - # (with suffix '_x' applied to its column names) - # and the right is create_df - # (with suffix '_y' applied to its column names) - update_df: pd.DataFrame = ( - df_1_unique_rows.merge( - create_df, indicator=True, how="outer", on=primary_keys - ) - .loc[ - lambda x: x["_merge"] != "both" - ] # only df_1_unique_rows unique rows are kept - .drop( - columns=[f"{other_header}_y" for other_header in reduced_headers_b] - ) # create_df columns are dropped - .rename( - columns={ - f"{other_header}_x": other_header - for other_header in reduced_headers_b - } - ) # df_1_unique_rows columns are renamed back to original - .drop(columns=["_merge"]) # _merge indicator column is dropped - ) - - # insert crud_type for each - create_df.insert( - loc=len(df_1_headers), - 
column="crud_type", - value=["create"] * create_df.shape[0], - ) - - update_df.insert( - loc=len(df_1_headers), - column="crud_type", - value=["update"] * update_df.shape[0], - ) - - delete_df.insert( - loc=len(df_2_headers), - column="crud_type", - value=["delete"] * delete_df.shape[0], - ) - - return create_df.append(update_df, sort=True).append(delete_df, sort=True) - - -async def is_record(value): - """ - Checks whether a value is a Record ID or a list of Record IDs - - Args: - value (``obj``): any value retrieved from an airtable record field. - Returns: - (``bool``): True if value is Record ID or a list of Record IDs - """ - if isinstance(value, list) and value: - value = value[0] - return isinstance(value, str) and value[0:3] == "rec" and len(value) == 17 - - -async def get_primary_keys_as_hashable(record, primary_keys): - hashable_keys = [] - for key in primary_keys: - val = record["fields"].get(key) - if isinstance(val, list): - val = tuple(val) - if val: - hashable_keys.append(val) - return tuple(hashable_keys) if hashable_keys else None - - -async def graft_fields(record, fields, separator=",", sort=True): - - for field in fields: - value = record["fields"].get(field) - if value: - if separator in value: - value_list = value.split(",") - if sort: - value_list = value_list.sort() - else: - value_list = [value] - record["fields"][field] = value_list - return record - - -async def link_tables( - table_a, table_b, fields_to_link_in_a, primary_key_b, -): - """ - Links records from another table to a record based on filter criteria. - - Args: - table_a (``list``): List of records. - table_b (``list``): List of records to link to. - fields_to_link_a (``list``): list of fields(``string``) in ``table_a`` to search in ``table_b``. - primary_key_b (``str``): key to search in ``table_b`` - Returns: - record (``dictionary``): If exists. If not returns ``None``. - """ # noqa: E501 - primary_key_b = primary_key_b.strip() - table_b_by_primary_key = { - record_b["fields"].get(primary_key_b): record_b["id"] - for record_b in table_b - if record_b["fields"].get(primary_key_b) - } - - # new_table = [] - for record_a in table_a: - # new_record = copy.deepcopy(record_a) - for field_to_link in fields_to_link_in_a: - field_to_link = field_to_link.strip() - val = record_a["fields"][field_to_link] - if not val: - continue - - keys = (x.strip() for x in val.split(",")) - record_a["fields"][field_to_link] = [ - table_b_by_primary_key.get(key) - for key in keys - if table_b_by_primary_key.get(key) - ] - # new_table.append(new_record) - return table_a - - -async def combine_records(record_a, record_b, join_fields=None): - """ - Combines unique information from two records into 1. - - Args: - record_a (``dictionary``): New airtable record. - record_b (``dictionary``): Old airtable record (This will be dictate the ``id``) - Kwargs: - join_fields (``list``, optional): list of fields(``string``) to combine. - Returns: - record (``dictionary``): If succesful, the combined ``record``, else ``record_a``. 
- """ # noqa - try: - record = {"id": record_b["id"], "fields": {}} - - if join_fields: - keys = join_fields - else: - keys = record_a["fields"] - for key in keys: - field = record_a["fields"][key] - if isinstance(field, list): - field = record_a["fields"][key] - for item in record_b["fields"][key]: - if item not in record_a["fields"][key]: - field.append(item) - elif isinstance(field, str): - field = ( - record_a["fields"][key] + ", " + record_b["fields"][key] - ) - elif isinstance(field, int) or ( - isinstance(field, float) or isinstance(field, tuple) - ): - field = record_a["fields"][key] + record_b["fields"][key] - record["fields"][key] = field - return record - except Exception: - return record_a - - -async def filter_record(record_a, record_b, filter_fields=None): - """ - Filters a record for unique information. - - Args: - record_a (``dictionary``): New airtable record. - record_b (``dictionary``): Old airtable record (This will be dictate the ``id``) - Kwargs: - filter_fields (``list``, optional): list of fields(``string``) to filter. - Returns: - record (``dictionary``): If succesful, the filtered ``record``, else ``record_a``. - """ # noqa - try: - record = {"id": record_b["id"], "fields": {}} - if filter_fields: - keys = filter_fields - else: - keys = record_a["fields"] - except Exception: - logger.warning("Could not filter record.") - return record_a - - for key in keys: - try: - if record_a["fields"][key] != record_b["fields"][key]: - record["fields"][key] = record_a["fields"][key] - except KeyError: - if record_a["fields"][key]: - record["fields"][key] = record_a["fields"][key] - return record - - -async def override_record(record, existing_record, overrides): - """ - Removes fields from record if user has overriden them on airtable. - - Args: - record (``dictionary``): Record from which fields will be removed if overwritten. - existing_record (``dictionary``): Record to check for overrides. - overrides (``list``): List of dictionaries - Each dictionary is composed of two items: 1. The override checkbox field name, 2. The override field name - {"ref_field": "field name", "override_field": "field name"} - Return: - record. - """ # noqa - for override in overrides: - ref_field = override.get("ref_field") - override_field = override.get("override_field") - if existing_record["fields"].get(ref_field): - record["fields"][override_field] = existing_record["fields"][ - override_field - ] - return record - - -async def compare_records( - record_a, record_b, method, overrides=None, filter_fields=None -): - """ - Compares a record in a table. - - Args: - record_a (``dictionary``): record to compare - record_b (``dictionary``): record to compare against. - method (``string``): Either "overwrite" or "combine" - Kwargs: - overrides (``list``): List of dictionaries - Each dictionary is composed of two items: 1. The override checkbox field name, 2. The override field name - {"ref_field": "field name", "override_field": "field name"} - filter_fields (``list``, optional): list of fields(``string``) to update. - Returns: - records (``list``): If succesful, a list of existing records (``dictionary``). 
- """ # noqa - try: - if overrides: - record = override_record(record_a, record_b, overrides) - if method == "overwrite": - record = filter_record( - record_a, record_b, filter_fields=filter_fields - ) - elif method == "combine": - record = combine_records( - record_a, record_b, join_fields=filter_fields - ) - return record - except Exception: - logger.warning("Invalid record format provided.") - - -async def replace_values(field, value): - # Simplify attachement objects - if isinstance(value, list) and isinstance(value[0], dict): - new_value = [{"url": obj["url"]} for obj in value if "url" in obj] - else: - new_value = value - return new_value diff --git a/airbase/urls.py b/airbase/urls.py index d2e42a9..22f14a5 100644 --- a/airbase/urls.py +++ b/airbase/urls.py @@ -1,2 +1,2 @@ BASE_URL = "https://api.airtable.com/v0" -META_URL = "{}/meta".format(BASE_URL) +META_URL = f"{BASE_URL}/meta" diff --git a/airbase/utils/__init__.py b/airbase/utils/__init__.py index 1e3ff0e..eebded9 100644 --- a/airbase/utils/__init__.py +++ b/airbase/utils/__init__.py @@ -1,10 +1,9 @@ from __future__ import absolute_import -import sys - from collections import deque from json import dumps from pprint import pformat +from typing import Any, Union try: from collections.abc import Iterable, Mapping @@ -12,12 +11,10 @@ from collections import Iterable, Mapping from .logger import Logger # noqa - -if sys.version_info >= (3, 7): - from .semaphore import HTTPSemaphore # noqa: F401 +from .semaphore import HTTPSemaphore # noqa: F401 -def pretty_print(obj, sort=True, _print=True): +def pretty_print(obj: Any, sort: bool = True, _print: bool = True) -> str: """ """ try: if isinstance(obj, Mapping): @@ -31,14 +28,15 @@ def pretty_print(obj, sort=True, _print=True): if _print: print(output) + return output -def _pretty_print(obj, sort=True): +def _pretty_print(obj: Union[Mapping, Iterable], sort: bool = True) -> str: return dumps(obj, sort_keys=sort, indent=4, ensure_ascii=False) -def _obj_to_dict(obj): +def _obj_to_dict(obj: Any) -> dict: try: d = {pformat(obj): _clean(obj.__dict__)} except AttributeError: @@ -46,7 +44,9 @@ def _obj_to_dict(obj): return d -def _clean(obj, is_mapping=True): +def _clean( + obj: Union[Mapping, Iterable], is_mapping=True +) -> Union[Mapping, Iterable]: if is_mapping: clean_data = {} iterable = obj.items() diff --git a/airbase/utils/logger.py b/airbase/utils/logger.py index cdf6bd0..c336f46 100644 --- a/airbase/utils/logger.py +++ b/airbase/utils/logger.py @@ -3,7 +3,7 @@ class Logger(object): @staticmethod - def start(name, level="info"): + def start(name, level : str = "info"): LOG_FORMAT = "%(asctime)s (%(name)s): %(levelname)s - %(message)s" DATE_FORMAT = "%Y-%m-%d %H:%M:%S (UTC/GMT %z)" diff --git a/airbase/utils/semaphore.py b/airbase/utils/semaphore.py index ff156a8..ff1c6ec 100644 --- a/airbase/utils/semaphore.py +++ b/airbase/utils/semaphore.py @@ -11,6 +11,8 @@ class HTTPSemaphore(BoundedSemaphore): + """ """ + def __init__( self, value: int = 10, diff --git a/airbase/validations.py b/airbase/validations.py index 67970ea..7f829f4 100644 --- a/airbase/validations.py +++ b/airbase/validations.py @@ -1,3 +1,9 @@ +from __future__ import absolute_import + +from typing import Dict, Iterable, Union + +from .exceptions import AirbaseException + FIELD_TYPES = ( "singleLineText", "email", @@ -28,10 +34,56 @@ PERMISSION_LEVELS = ("read", "comment", "edit", "create") +async def validate_records( + records: Union[Iterable[Dict], Dict], record_id=True, fields=True +) -> None: + """ + Validates a 
+    Raises an AirbaseException if invalid.
+
+    Args:
+        records (``dict`` or ``list``): a record or a list of records
+    """
+    if isinstance(records, list) and records:
+        for r in records:
+            # propagate the validation flags to each element
+            await validate_records(r, record_id=record_id, fields=fields)
+
+    elif isinstance(records, dict):
+        if record_id:
+            if records.get("id"):
+                if not isinstance(records["id"], str):
+                    raise AirbaseException(
+                        "Invalid Type: record['id'] must be a string."
+                    )
+                elif not (
+                    records["id"][0:3] == "rec" and len(records["id"]) == 17
+                ):
+                    raise AirbaseException(
+                        "Invalid Record ID: record['id'] must be a string in the following format: 'rec[a-zA-Z0-9]{14}'."  # noqa: E501
+                    )
+            else:
+                raise AirbaseException(
+                    "Invalid Record: record must include a key 'id' with its corresponding record id value."  # noqa: E501
+                )
+
+        if fields:
+            if records.get("fields"):
+                if not isinstance(records["fields"], dict):
+                    raise AirbaseException(
+                        "Invalid Type: record['fields'] must be a dictionary."
+                    )
+            else:
+                raise AirbaseException(
+                    "Invalid Record: record must include a key 'fields' with its corresponding field names and values."  # noqa: E501
+                )
+
+    else:
+        raise AirbaseException("Invalid Type: record must be a dictionary.")
+
+
 def is_value_acceptable(val, field_type):
     assert (
         field_type in FIELD_TYPES
-    ), "{} is not an acceptable field type".format(field_type)
+    ), f"{field_type} is not an acceptable field type"

     if isinstance(val, str) and field_type in (
         "singleLineText",
diff --git a/cruder/__init__.py b/cruder/__init__.py
deleted file mode 100644
index f6857d5..0000000
--- a/cruder/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from __future__ import absolute_import
-from .cruder import Cruder  # noqa:F401
diff --git a/cruder/cruder_async.py b/cruder/cruder_async.py
deleted file mode 100644
index 2eb9df5..0000000
--- a/cruder/cruder_async.py
+++ /dev/null
@@ -1,484 +0,0 @@
-from __future__ import absolute_import
-
-import asyncio
-import pandas as pd
-
-import we.airflow.plugins.utils.snowflake as sf
-
-from airflow.contrib.hooks.snowflake_hook import SnowflakeHook
-from gspread import authorize
-from typing import Iterable, Optional
-
-from .airtable import AirtableAsync as Airtable
-from .airtable import tools_async as at_tools
-from .airtable.utils import Logger
-
-
-logger = Logger.start(__name__)
-
-
-class Cruder:
-    @staticmethod
-    def _fmt_col_names(
-        column: str, prefix: str, style: str, abbrs: list
-    ) -> str:
-        """Format Column Names"""
-        style = style.lower()
-        col_name = ""
-        if prefix:
-            col_name += prefix
-        for part in column.split("_"):
-            if style == "upper" or (abbrs and part.lower() in abbrs):
-                part = part.upper()
-            elif style in ("camel", "title"):
-                part = part.title()
-            if style != "camel":
-                col_name += part + " "
-        return col_name.strip()
-
-    @staticmethod
-    async def get_csv_data(
-        filepath: str,
-        prefix: str = None,
-        style: str = None,
-        abbrs: list = None,
-    ) -> pd.DataFrame:
-        """
-        Get CSV file as a pandas DataFrame.
-
-        Args:
-            filepath: full path to csv file.
- - Kwargs: - prefix (default=None): desired prefix - style (default=None): "lower", "upper", "camel" or "title" - abbrs (default=None): list of lowercase abbrs to make upper case - - Return: - data: If succesful, pandas DataFrame - """ # noqa: E501 - df = pd.read_csv(filepath, encoding="utf-8") - - if style: - df.rename( - columns=lambda x: Cruder._fmt_col_names( - x, prefix, style, abbrs - ), - inplace=True, - ) - - logger.info(f"Fetched {df.shape[0]} rows/records from {filepath}") - return df - - @staticmethod - async def get_snowflake_data( - filepath: str, - sf_hook: SnowflakeHook, - prefix: str = None, - style: str = None, - abbrs: list = None, - ) -> pd.DataFrame: - """ - Get SQL formatted Snowflake query. - - Args: - filepath: full path to SQL Query. - sf_hook: Snowflake Hook - - Kwargs: - prefix (efault=None): desired prefix - style (default=None): "lower", "upper", "camel" or "title" - abbrs (default=None): list of lowercase abbrs to make upper case - - Return: - data: If succesful, pandas DataFrame - """ # noqa: E501 - - with open(filepath, "r") as fp: - query = fp.read() - - df = sf.get_tbl_query(sf_hook=sf_hook, query=query, return_raw=False) - - if style: - df.rename( - columns=lambda x: Cruder._fmt_col_names( - x, prefix, style, abbrs - ), - inplace=True, - ) - logger.info(f"Fetched {df.shape[0]} rows/records from Snowflake") - return df - - @staticmethod - async def convert_to_df( - source: Optional[Iterable] = None, format: Optional[str] = None - ) -> pd.DataFrame: - """ - Get SQL query formatted for airtable - - Args: - df: pandas DataFrame - - Kwargs: - output (default="records"): "records" or "rows" - - Return: - data: If succesful, list of records or rows - """ # noqa: E501 - assert source in ( - "records", - "rows", - ), "{} is not an acceptable output type".format(source) - - if isinstance(source, dict): - pass - - elif isinstance(source, list): - pass - - @staticmethod - async def convert_df(df: pd.DataFrame, output: str = "records") -> list: - """ - Get SQL query formatted for airtable - - Args: - df: pandas DataFrame - - Kwargs: - output (default="records"): "records" or "rows" - - Return: - data: If succesful, list of records or rows - """ # noqa: E501 - assert output in ( - "records", - "rows", - ), "{} is not an acceptable output type".format(output) - - if output == "records": - df = df.where(pd.notnull(df), None) - records = [{"fields": fields} for fields in df.to_dict("records")] - return records - - elif output == "rows": - headers = df.columns.values.tolist() - rows = df.values.tolist() - rows.insert(0, headers) - return rows - - @staticmethod - async def crud(data: list, dest: str, **kwargs) -> None: - """ - Post Snowflake Query to Airtable or Sheets - - Args: - data: List of records or rows to CRUD - dest: airtable or sheets - - Kwargs: - credentials (``oauth2client.service_account.ServiceAccountCredentials`` or ``string``): Google Credentials Object for Service Account or Airtable API Key - type(``string``): Type of CRUD operation ("overwrite", "update", "partial" or "append") - target(``dict``): - table (``string``): Airtable Table Name or ID - wks_index (``int``): Worksheet Index - id (``string``): Airtable Base ID or Spreadheet Key - primary_keys (``list``): List of field names to be used as primary keys - links (``list``, optional): List of lowercase abbrs to make upper case - table (``string``): Name of Airtable Table to link to - primary_key (``string``): Name of field to use as a primary key in link table - fields (``list``): List of fields to 
link - arrays (``list``, optional): list of fields to turn into arrays for Multiple Select - overrides (``list``, optional): List of dictionaries with two entries: - override_field (``string``): Name of field that user can override - ref_field (``string``): Name of field (checkbox) that flags a user override - value_input_option (``string``, default="USER_ENTERED"): Sheets input style ("RAW" or "USER_ENTERED") - """ # noqa: E501 - if dest == "airtable": - await Cruder._at_crud(data, **kwargs) - elif dest == "sheets": - Cruder._sh_crud(data, **kwargs) - - @staticmethod - async def _at_get_linked_tables( - base: object, records: list, link: dict - ) -> None: - linked_table = await base.get_table(link["table"], key="name") - link["table"] = await linked_table.get_records() - await at_tools.link_tables( - records, link["table"], link["fields"], link["primary_key"], - ) - - @staticmethod - async def _at_sort_record( - record: dict, - primary_keys: list, - arrays: list, - overrides: list, - existing_records: list, - existing_records_indices_by_primary_key: dict, - post_records: list, - update_records: list, - existing_indices: list, - ) -> None: - # turn array values into lists - if arrays: - record = await at_tools.graft_fields(record, arrays) - - # check for existing via primary keys - hashable_keys = await at_tools.get_primary_keys_as_hashable( # noqa: E501 - record, primary_keys - ) - existing_index = ( - existing_records_indices_by_primary_key.get(hashable_keys) - if hashable_keys - else None - ) - - # if record exists - if existing_index is not None: - # add index to existing indices list - existing_indices.append(existing_index) - existing_record = existing_records[existing_index] - - # remove overriden fields - if overrides: - record = await at_tools.override_record( - record, existing_record, overrides - ) - - # filter records to only keep new data - record = await at_tools.filter_record(record, existing_record) - - # if new data, then append to update_records - if record["fields"]: - update_records.append(record) - - # else append to post_records - else: - post_records.append(record) - - @staticmethod - async def _at_crud(records: list, **kwargs) -> None: - """ - Post Snowflake Query to Airtable - - Args: - records: List of records to CRUD - - Kwargs: - credentials (``string``): Airtable API Key - type(``string``): Type of CRUD operation ("overwrite", "update", "partial" or "append") - target(``dict``): - table (``string``): Airtable Table Name or ID - id (``string``): Airtable Base ID - primary_keys (``list``): List of field names to be used as primary keys - links (``list``, optional): List of lowercase abbrs to make upper case - table (``string``): Name of Airtable Table to link to - primary_key (``string``): Name of field to use as a primary key in link table - fields (``list``): List of fields to link - arrays (``list``, optional): list of fields to turn into arrays for Multiple Select - overrides (``list``, optional): List of dictionaries with two entries: - override_field (``string``): Name of field that user can override - ref_field (``string``): Name of field (checkbox) that flags a user override - """ # noqa: E501 - - api_key = kwargs.get("credentials") - mode = kwargs.get("type") - table_key = kwargs.get("target").get("table") - base_id = kwargs.get("id") - primary_keys = kwargs.get("primary_keys") - links = kwargs.get("links") or [] - arrays = kwargs.get("arrays") - prefix = kwargs.get("prefix") - overrides = kwargs.get("overrides") - - # Create Airtable() instance for this 
base & table - async with Airtable(api_key=api_key) as at: - base = await at.get_base(base_id) - await base.get_tables() - table = [ - table - for table in base.tables - if table.name == table_key or table.id == table_key - ][0] - - # Get records in that table - existing_records = await table.get_records() - - # Get linked tables - await asyncio.gather( - *[ - Cruder._at_get_linked_tables(base, records, link) - for link in links - ], - return_exceptions=False, - ) - - # If records in table - if existing_records: - post_records = [] # Records to post - update_records = [] # Records to update - delete_records = [] # Records to delete - existing_indices = [] # Indices of records in existing_records - - existing_records_indices_by_primary_key = {} - for i, existing_record in enumerate(existing_records): - hashable_keys = await at_tools.get_primary_keys_as_hashable( # noqa: E501 - existing_record, primary_keys - ) - if hashable_keys: - existing_records_indices_by_primary_key[ - hashable_keys - ] = i - - await asyncio.gather( - *[ - Cruder._at_sort_record( - record, - primary_keys, - arrays, - overrides, - existing_records, - existing_records_indices_by_primary_key, - post_records, - update_records, - existing_indices, - ) - for record in records - ], - return_exceptions=False, - ) - - # create new records - await table.post_records(post_records) - # update existing records - await table.update_records(update_records) - - # Get dead record indices - all_indices = set(range(len(existing_records))) - existing_indices = set(existing_indices) - dead_indices = all_indices - existing_indices - - if len(dead_indices) > 0: - # loop through records ro delete - for index in dead_indices: - # get dead record - dead_record = existing_records[index] - - if mode == "overwrite": - delete_records.append(dead_record) - continue - if mode == "update": - del_field = "Delete" - if prefix: - del_field = "AUTO_Delete" - - if not dead_record["fields"].get(del_field): - record = { - "id": dead_record["id"], - "fields": {del_field: True}, - } - delete_records.append(record) - - if delete_records: - # delete record - if mode == "overwrite": - await table.delete_records(delete_records) - # flag manual deletion via 'Delete' checkbox field - elif mode == "update": - await table.update_records(delete_records) - - # If no records in table - else: - if arrays: - # turn array values into lists - records = [ - await at_tools.graft_fields(record, arrays) - for record in records - ] - await table.post_records(records) - - @staticmethod - def _sh_crud(rows: list, **kwargs) -> None: - """ - Post Snowflake Query to Sheets - - Args: - rows: List of rows to CRUD - - Kwargs: - credentials (``oauth2client.service_account.ServiceAccountCredentials``): Google Credentials Object for Service Account - type(``string``): Type of CRUD operation ("overwrite" or "append") - target(``dict``): - wks_index (``int``): Worksheet Index - id (``string``): Spreadheet Key - value_input_option (``string``, default="USER_ENTERED"): Sheets input style ("RAW" or "USER_ENTERED") - """ # noqa: E501 - credentials = kwargs.get("credentials") - sh_key = kwargs.get("id") - wks_index = kwargs.get("target").get("wks_index") - txn_type = kwargs.get("type") - value_input_option = kwargs.get("value_input_option") or "USER_ENTERED" - - n_of_rows = len(rows) - n_of_columns = len(rows[0]) - row_offset = 0 - - # get the work sheet - client = authorize(credentials) - sh = client.open_by_key(sh_key) - wks = sh.get_worksheet(wks_index) - - # overwrite - if txn_type == 
"overwrite": - wks.resize(rows=n_of_rows, cols=n_of_columns) - cell_list = wks.range(1, 1, n_of_rows, n_of_columns) - - # append - elif txn_type == "append": - rows = rows[1:] - n_of_rows = len(rows) - row_count = wks.row_count - col_count = wks.col_count - if n_of_columns > col_count: - col_count = n_of_columns - wks.resize(rows=row_count + n_of_rows, cols=col_count) - wks = sh.get_worksheet(wks_index) - cell_list = wks.range( - row_count + 1, 1, row_count + n_of_rows, n_of_columns - ) - row_offset = row_count - - # TODO elif txn_type in ("update", "partial") - else: - return - - for i, row in enumerate(rows): - for j, item in enumerate(row): - cell_list[i * n_of_columns + j].value = item if item else "" - - row_chunk = 1500 - max_cells = n_of_columns * row_chunk - for i in range(int(len(cell_list) / max_cells + 1)): - min_index = i * max_cells - max_index = i * max_cells + max_cells - 1 - if max_index > len(cell_list): - max_index = len(cell_list) - - sublist = cell_list[min_index:max_index] - try: - wks.update_cells( - sublist, value_input_option=value_input_option - ) - except TypeError: - wks.update_cells(sublist) - start_row = int(min_index / n_of_columns) + row_offset + 1 - end_row = start_row + row_chunk - 1 - if end_row > n_of_rows + row_offset: - end_row = n_of_rows + row_offset - logger.info(f"Posted rows {start_row} to {end_row} to Sheets") - - if max_index == len(cell_list) - 1: - break - - logger.info("Completed posting to Sheets") diff --git a/cruder/cruder_sync.py b/cruder/cruder_sync.py deleted file mode 100644 index 9f1dea0..0000000 --- a/cruder/cruder_sync.py +++ /dev/null @@ -1,416 +0,0 @@ -from __future__ import absolute_import - -import pandas as pd - -import we.airflow.plugins.utils.snowflake as sf - -from airflow.contrib.hooks.snowflake_hook import SnowflakeHook -from gspread import authorize - -from .airtable import Airtable -from .airtable import tools as at_tools -from .airtable.utils import Logger - - -logger = Logger.start(__name__) - - -class Cruder: - @staticmethod - def _fmt_col_names(column: str, prefix: str, style: str, abbrs: list): - """Format Column Names""" - style = style.lower() - col_name = "" - if prefix: - col_name += prefix - for part in column.split("_"): - if style == "upper" or (abbrs and part.lower() in abbrs): - part = part.upper() - elif style in ("camel", "title"): - part = part.title() - if style != "camel": - col_name += part + " " - return col_name.strip() - - @staticmethod - def get_csv_data( - filepath: str, - prefix: str = None, - style: str = None, - abbrs: list = None, - ) -> pd.DataFrame: - """ - Get CSV file as a pandas DataFrame. - - Args: - filepath: full path to csv file. - - Kwargs: - prefix (default=None): desired prefix - style (default=None): "lower", "upper", "camel" or "title" - abbrs (default=None): list of lowercase abbrs to make upper case - - Return: - data: If succesful, pandas DataFrame - """ # noqa: E501 - df = pd.read_csv(filepath, encoding="utf-8") - - if style: - df.rename( - columns=lambda x: Cruder._fmt_col_names( - x, prefix, style, abbrs - ), - inplace=True, - ) - - logger.info(f"Fetched {df.shape[0]} rows/records from {filepath}") - return df - - @staticmethod - def get_snowflake_data( - filepath: str, - sf_hook: SnowflakeHook, - prefix: str = None, - style: str = None, - abbrs: list = None, - ) -> pd.DataFrame: - """ - Get SQL formatted Snowflake query. - - Args: - filepath: full path to SQL Query. 
- sf_hook: Snowflake Hook - - Kwargs: - prefix (efault=None): desired prefix - style (default=None): "lower", "upper", "camel" or "title" - abbrs (default=None): list of lowercase abbrs to make upper case - - Return: - data: If succesful, pandas DataFrame - """ # noqa: E501 - - with open(filepath, "r") as fp: - query = fp.read() - - df = sf.get_tbl_query(sf_hook=sf_hook, query=query, return_raw=False) - - if style: - df.rename( - columns=lambda x: Cruder._fmt_col_names( - x, prefix, style, abbrs - ), - inplace=True, - ) - logger.info(f"Fetched {df.shape[0]} rows/records from Snowflake") - return df - - @staticmethod - def convert_df(df: pd.DataFrame, output: str = "records") -> list: - """ - Get SQL query formatted for airtable - - Args: - df: pandas DataFrame - - Kwargs: - output (default="records"): "records" or "rows" - - Return: - data: If succesful, list of records or rows - """ # noqa: E501 - assert output in ( - "records", - "rows", - ), "{} is not an acceptable output type".format(output) - - if output == "records": - df = df.where(pd.notnull(df), None) - records = [{"fields": fields} for fields in df.to_dict("records")] - return records - - elif output == "rows": - headers = df.columns.values.tolist() - rows = df.values.tolist() - rows.insert(0, headers) - return rows - - @staticmethod - def crud(data: list, dest: str, **kwargs) -> None: - """ - Post Snowflake Query to Airtable or Sheets - - Args: - data: List of records or rows to CRUD - dest: airtable or sheets - - Kwargs: - credentials (``oauth2client.service_account.ServiceAccountCredentials`` or ``string``): Google Credentials Object for Service Account or Airtable API Key - type(``string``): Type of CRUD operation ("overwrite", "update", "partial" or "append") - target(``dict``): - table (``string``): Airtable Table Name or ID - wks_index (``int``): Worksheet Index - id (``string``): Airtable Base ID or Spreadheet Key - primary_keys (``list``): List of field names to be used as primary keys - links (``list``, optional): List of lowercase abbrs to make upper case - table (``string``): Name of Airtable Table to link to - primary_key (``string``): Name of field to use as a primary key in link table - fields (``list``): List of fields to link - arrays (``list``, optional): list of fields to turn into arrays for Multiple Select - overrides (``list``, optional): List of dictionaries with two entries: - override_field (``string``): Name of field that user can override - ref_field (``string``): Name of field (checkbox) that flags a user override - value_input_option (``string``, default="USER_ENTERED"): Sheets input style ("RAW" or "USER_ENTERED") - """ # noqa: E501 - if dest == "airtable": - Cruder._at_crud(data, **kwargs) - elif dest == "sheets": - Cruder._sh_crud(data, **kwargs) - - @staticmethod - def _at_crud(records: list, **kwargs) -> None: - """ - Post Snowflake Query to Airtable - - Args: - records: List of records to CRUD - - Kwargs: - credentials (``string``): Airtable API Key - type(``string``): Type of CRUD operation ("overwrite", "update", "partial" or "append") - target(``dict``): - table (``string``): Airtable Table Name or ID - id (``string``): Airtable Base ID - primary_keys (``list``): List of field names to be used as primary keys - links (``list``, optional): List of lowercase abbrs to make upper case - table (``string``): Name of Airtable Table to link to - primary_key (``string``): Name of field to use as a primary key in link table - fields (``list``): List of fields to link - arrays (``list``, optional): 
list of fields to turn into arrays for Multiple Select - overrides (``list``, optional): List of dictionaries with two entries: - override_field (``string``): Name of field that user can override - ref_field (``string``): Name of field (checkbox) that flags a user override - """ # noqa: E501 - - api_key = kwargs.get("credentials") - mode = kwargs.get("type") - table_key = kwargs.get("target").get("table") - base_id = kwargs.get("id") - primary_keys = kwargs.get("primary_keys") - links = kwargs.get("links") or [] - arrays = kwargs.get("arrays") - prefix = kwargs.get("prefix") - overrides = kwargs.get("overrides") - - # Create Airtable() instance for this base & table - at = Airtable(api_key=api_key) - base = at.get_base(base_id) - base.get_tables() - table = [ - table - for table in base.tables - if table.name == table_key or table.id == table_key - ][0] - - # Get records in that table - existing_records = table.get_records() - - # Get linked tables - for link in links: - linked_table = base.get_table(link["table"]) - link["table"] = linked_table.get_records() - records = at_tools.link_tables( - records, link["table"], link["fields"], link["primary_key"], - ) - - # If records in table - if existing_records: - post_records = [] # Records to post - update_records = [] # Records to update - existing_indices = [] # Indices of records in existing_records - - existing_records_indices_by_primary_key = {} - for i, existing_record in enumerate(existing_records): - hashable_keys = at_tools.get_primary_keys_as_hashable( - existing_record, primary_keys - ) - if hashable_keys: - existing_records_indices_by_primary_key[hashable_keys] = i - - for record in records: - - # turn array values into lists - if arrays: - record = at_tools.graft_fields(record, arrays) - - # check for existing via primary keys - hashable_keys = at_tools.get_primary_keys_as_hashable( - record, primary_keys - ) - existing_index = ( - existing_records_indices_by_primary_key.get(hashable_keys) - if hashable_keys - else None - ) - - # if record exists - if existing_index is not None: - # add index to existing indices list - existing_indices.append(existing_index) - existing_record = existing_records[existing_index] - - # remove overriden fields - if overrides: - record = at_tools.override_record( - record, existing_record, overrides - ) - - # filter records to only keep new data - record = at_tools.filter_record(record, existing_record) - - # if new data, then append to update_records - if record["fields"]: - update_records.append(record) - - # else append to post_records - else: - post_records.append(record) - - # create new records - table.post_records(post_records) - # update existing records - table.update_records(update_records) - - # Get dead record indices - all_indices = set(range(len(existing_records))) - existing_indices = set(existing_indices) - dead_indices = all_indices - existing_indices - - if len(dead_indices) > 0: - # loop through records ro delete - for index in dead_indices: - # get dead record - dead_record = existing_records[index] - - record_name = dead_record["fields"].get( - table.primary_field_name - ) - # get record's first field - if not record_name: - field_names = list(dead_record["fields"].keys()) - if field_names: - record_name = dead_record["fields"].get( - field_names[0] - ) - else: - record_name = dead_record["id"] - - # delete record - if mode == "overwrite": - table.delete_record(dead_record, message=record_name) - # flag manual deletion via 'Delete' checkbox field - elif mode == "update": - 
del_field = "Delete" - if prefix: - del_field = "AUTO_Delete" - - if not dead_record["fields"].get(del_field): - record = { - "id": dead_record["id"], - "fields": {del_field: True}, - } - table.update_record(record, message=record_name) - - # If no records in table - else: - if arrays: - # turn array values into lists - records = [ - at_tools.graft_fields(record, arrays) for record in records - ] - table.post_records(records) - - @staticmethod - def _sh_crud(rows: list, **kwargs) -> None: - """ - Post Snowflake Query to Sheets - - Args: - rows: List of rows to CRUD - - Kwargs: - credentials (``oauth2client.service_account.ServiceAccountCredentials``): Google Credentials Object for Service Account - type(``string``): Type of CRUD operation ("overwrite" or "append") - target(``dict``): - wks_index (``int``): Worksheet Index - id (``string``): Spreadheet Key - value_input_option (``string``, default="USER_ENTERED"): Sheets input style ("RAW" or "USER_ENTERED") - """ # noqa: E501 - credentials = kwargs.get("credentials") - sh_key = kwargs.get("id") - wks_index = kwargs.get("target").get("wks_index") - txn_type = kwargs.get("type") - value_input_option = kwargs.get("value_input_option") or "USER_ENTERED" - - n_of_rows = len(rows) - n_of_columns = len(rows[0]) - row_offset = 0 - - # get the work sheet - client = authorize(credentials) - sh = client.open_by_key(sh_key) - wks = sh.get_worksheet(wks_index) - - # overwrite - if txn_type == "overwrite": - wks.resize(rows=n_of_rows, cols=n_of_columns) - cell_list = wks.range(1, 1, n_of_rows, n_of_columns) - - # append - elif txn_type == "append": - rows = rows[1:] - n_of_rows = len(rows) - row_count = wks.row_count - col_count = wks.col_count - if n_of_columns > col_count: - col_count = n_of_columns - wks.resize(rows=row_count + n_of_rows, cols=col_count) - wks = sh.get_worksheet(wks_index) - cell_list = wks.range( - row_count + 1, 1, row_count + n_of_rows, n_of_columns - ) - row_offset = row_count - - # TODO elif txn_type in ("update", "partial") - else: - return - - for i, row in enumerate(rows): - for j, item in enumerate(row): - cell_list[i * n_of_columns + j].value = item if item else "" - - row_chunk = 1500 - max_cells = n_of_columns * row_chunk - for i in range(int(len(cell_list) / max_cells + 1)): - min_index = i * max_cells - max_index = i * max_cells + max_cells - 1 - if max_index > len(cell_list): - max_index = len(cell_list) - - sublist = cell_list[min_index:max_index] - try: - wks.update_cells( - sublist, value_input_option=value_input_option - ) - except TypeError: - wks.update_cells(sublist) - start_row = int(min_index / n_of_columns) + row_offset + 1 - end_row = start_row + row_chunk - 1 - if end_row > n_of_rows + row_offset: - end_row = n_of_rows + row_offset - logger.info(f"Posted rows {start_row} to {end_row} to Sheets") - - if max_index == len(cell_list) - 1: - break - - logger.info("Completed posting to Sheets") diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..1215375 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +ignore_missing_imports = True \ No newline at end of file diff --git a/requirements.dev.txt b/requirements.dev.txt new file mode 100644 index 0000000..66522ac --- /dev/null +++ b/requirements.dev.txt @@ -0,0 +1,14 @@ +pytest +pytest-cov +pytest-asyncio +coveralls +aiohttp +flake8 +tox-travis + +wheel +twine + +# Only installed in environments that can handle it; don't try running the `lint` +# tox environment on Python < 3.6. 
+black ; python_version >="3.6"
\ No newline at end of file
diff --git a/setup.cfg b/setup.cfg
index 224a779..9d08805 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,2 +1,5 @@
+[bdist_wheel]
+universal=1
+
 [metadata]
 description-file = README.md
\ No newline at end of file
diff --git a/setup.py b/setup.py
index ccca9ee..b89b300 100644
--- a/setup.py
+++ b/setup.py
@@ -1,16 +1,18 @@
-from distutils.core import setup
+from setuptools import setup

 setup(
-    name="airbase",
-    packages=["airbase"],
-    description="A async Python API Wrapper for the Airtable API",
+    name="airtable-async",
+    packages=["airbase", "airbase.utils"],
+    description="An asynchronous Python API Wrapper for the Airtable API",
     author="Luis Felipe Paris",
     author_email="lfparis@gmail.com",
     url="https://github.com/lfparis/airbase",
-    download_url="",
-    version="0.0.1",
-    install_requires=["aiohttp", "pandas"],
-    python_requires="!=2.7.*, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*, !=3.7.*",  # noqa: E501
+    version="0.0.1b12",
+    install_requires=["aiohttp"],
+    extras_require={"tools": ["pandas"]},
+    package_data={"airbase": ["py.typed"]},
+    zip_safe=False,
+    python_requires=">=3.7",
     keywords=["airtable", "api", "async", "async.io"],
     license="The MIT License (MIT)",
     classifiers=[
@@ -20,7 +22,9 @@
         "Intended Audience :: Developers",
         "Programming Language :: Python",
         "Topic :: Software Development",
+        "Programming Language :: Python :: 3.7",
         "Programming Language :: Python :: 3.8",
+        "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: Implementation :: CPython",
     ],
 )
diff --git a/tests/test_airtable.py b/tests/test_airtable.py
new file mode 100644
index 0000000..b0f5e99
--- /dev/null
+++ b/tests/test_airtable.py
@@ -0,0 +1,12 @@
+import pytest
+
+from airbase.airtable import Airtable, Base  # noqa: F401
+
+
+@pytest.mark.asyncio
+async def test_airtable() -> None:
+    async with Airtable() as at:
+        # Get all bases for a user
+        await at.get_bases()
+        assert getattr(at, "bases", None)
+        assert isinstance(at.bases[0], Base)
diff --git a/tests/test_api_key.py b/tests/test_api_key.py
new file mode 100644
index 0000000..cc9cae5
--- /dev/null
+++ b/tests/test_api_key.py
@@ -0,0 +1,12 @@
+import os
+
+import pytest
+
+from airbase.airtable import Airtable, Base  # noqa: F401
+
+
+@pytest.mark.asyncio
+async def test_api_key() -> None:
+    async with Airtable() as at:
+        # The API key should be picked up from the environment
+        assert at.api_key == os.environ["AIRTABLE_API_KEY"]
diff --git a/todo b/todo
new file mode 100644
index 0000000..a730855
--- /dev/null
+++ b/todo
@@ -0,0 +1,9 @@
+Todos:
+    Table():
+        self.get_records():
+            [] filter_by_fields: convert array to urlencoded string
+            Logger, level, etc.
+
+    Airtable():
+        self.semaphore should be restricted to calls to META_URL,
+        and each base should have its own semaphore.
\ No newline at end of file
diff --git a/tox.ini b/tox.ini
new file mode 100644
index 0000000..d799240
--- /dev/null
+++ b/tox.ini
@@ -0,0 +1,30 @@
+[tox]
+envlist = py37,py38,lint
+
+[flake8]
+filename = *.py
+count = True
+# Per Black formatter documentation
+ignore = E203, E266, E501, W503
+select = B,C,E,F,W,T4,B9
+max-line-length = 79
+max-complexity = 15
+exclude =
+    .venv
+    .eggs
+    .tox
+
+# addopts and testpaths are pytest options, not tox options,
+# so they belong in a [pytest] section
+[pytest]
+addopts = -v
+testpaths = tests
+
+[testenv]
+passenv = AIRTABLE_API_KEY
+deps = -r requirements.dev.txt
+commands = pytest
+
+[testenv:lint]
+basepython = python3.7
+commands =
+    black --line-length 79 --diff airbase tests
+    flake8
\ No newline at end of file
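A quick usage sketch of the new `validate_records` helper added to `airbase/validations.py` above. The import paths follow the diff (`airbase.validations`, with `AirbaseException` coming from `airbase.exceptions`); the record dictionaries themselves are made-up examples, and this assumes the package is installed as configured in `setup.py`:

```python
import asyncio

from airbase.exceptions import AirbaseException
from airbase.validations import validate_records


async def main() -> None:
    # A well-formed record: a 17-character id ("rec" + 14 chars)
    # and a "fields" dictionary
    good = {"id": "rec" + "a" * 14, "fields": {"Name": "Widget"}}
    await validate_records(good)  # passes silently

    # Lists are validated element by element
    await validate_records([good, good])

    # A record missing its "fields" key now raises,
    # instead of constructing an exception that is never raised
    try:
        await validate_records({"id": "rec" + "a" * 14})
    except AirbaseException as exc:
        print(exc)


asyncio.run(main())
```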
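Likewise, a minimal sketch of `is_value_acceptable` from the same module. Only the guard visible in the diff is exercised here (the `assert` against `FIELD_TYPES` and its f-string message); the rest of the function body is not shown in this hunk, so the happy-path return value is deliberately not relied upon:

```python
from airbase.validations import FIELD_TYPES, is_value_acceptable

# FIELD_TYPES is the tuple of accepted Airtable field type names
assert "singleLineText" in FIELD_TYPES

# An unknown field type trips the assert guard
try:
    is_value_acceptable("hello", "notAFieldType")
except AssertionError as exc:
    print(exc)  # -> notAFieldType is not an acceptable field type
```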