Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ celerybeat.pid
.venv
env/
venv/
venv
ENV/
env.bak/
venv.bak/
Expand Down
8 changes: 8 additions & 0 deletions definitions/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ def main(
# behavior changed incompatibly in py3.3
command_line.parser.error("too few arguments")
Config.get_template_directory = get_template_directory # type: ignore
if options.config is None:
if os.path.isfile("alembic.ini"):
options.config = "alembic.ini"
else:
raise EnvironmentError(
"File alembic.ini does not exist and was not provided in command line. See --help."
)

config = Config(
file_=options.config,
ini_section=options.name,
Expand Down
7 changes: 5 additions & 2 deletions definitions/custom_scripts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@
render_drop_group,
render_drop_sequence,
)
from .schemas import add_table_schema_to_model, compare_schemas
from .schemas import add_table_schema_to_model, compare_for_groups
from .tables import compare_for_encrypted, compare_for_sensitive

__all__ = [
"create_table_schema",
"drop_table_schema",
"render_create_sequence",
"render_drop_sequence",
"add_table_schema_to_model",
"compare_schemas",
"compare_for_groups",
"compare_for_encrypted",
"compare_for_sensitive",
"create_group",
"delete_group",
"render_create_group",
Expand Down
196 changes: 196 additions & 0 deletions definitions/custom_scripts/operations_encrypt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
from enum import IntEnum

from alembic.operations import MigrateOperation, Operations
from sqlalchemy import text


def _get_column_names(schema, table, id_column):
return [
i.column_name
for i in operations.get_bind().execute(
text(
f"SELECT column_name FROM information_schema.columns "
f"WHERE table_schema='{schema}' AND table_name='{name}';"
)
)
if i != id_column
]


class op_type(IntEnum):
ENCRYPT = 0x00
DECRYPT = 0x01


def _common_constructor(self, table_name, key_table_name, id_column, columns, **kw):
_name = table_name.split(".")
if len(_name) != 2:
self.table_name = table_name
else:
self.table_schema, self.table_name = _name
if (sch := kw.get("schema", None)) is not None:
self.table_schema = sch

_key_name = key_table_name.split(".")
if len(_key_name) < 2:
self.key_table_name = key_table_name
else:
self.key_table_schema, self.key_table_name = _key_name
if (sch := kw.get("key_schema", None)) is not None:
self.key_table_schema = sch

self.id_column = id_column
self.columns = columns


@Operations.register_operation("encrypt_table")
class EncryptTableOp(MigrateOperation):
"""fields:
self.table_name: str
self.table_schema: str
self.key_table_name: str
self.key_table_schema: str
self.id_column: str
self.columns: list[str]
"""

def __init__(self, table_name, key_table_name, id_column, columns, **kw):
_name = table_name.split(".")
if len(_name) != 2:
self.table_name = table_name
else:
self.table_schema, self.table_name = _name
if (sch := kw.get("schema", None)) is not None:
self.table_schema = sch

_key_name = key_table_name.split(".")
if len(_key_name) < 2:
self.key_table_name = key_table_name
else:
self.key_table_schema, self.key_table_name = _key_name
if (sch := kw.get("key_schema", None)) is not None:
self.key_table_schema = sch

self.id_column = id_column
self.columns = columns

@classmethod
def encrypt_table(
cls,
operations,
table_name,
key_table_name=None,
id_column="id",
columns=None,
**kw,
):
if columns is None:
columns = _get_column_names(*table_name.split("."), id_column)
if key_table_name is None:
key_table_name = table_name + "_ekeys"
op = EncryptTableOp(table_name, key_table_name, id_column, columns, **kw)
return operations.invoke(op)

def reverse(self):
return DecryptTableOp(
self.table_name, self.key_table_name, self.id_column, self.columns
)


@Operations.register_operation("decrypt_table")
class DecryptTableOp(MigrateOperation):
"same fields as in EncryptTableOp"

__init__ = _common_constructor

@classmethod
def decrypt_table(
cls, operations, table_name, key_table_name, id_column, columns, **kw
):
op = DecryptTableOp(table_name, key_table_name, id_column, columns, **kw)
return operations.invoke(op)

def reverse(self):
return EncryptTableOp(
self.table_name, self.key_table_name, self.id_column, self.columns
)


@Operations.register_operation("encrypt_column")
class EncryptColumnOp(MigrateOperation):
def __init__(self, column_name, key_table_name, id_column):
self.column_name = column_name
self.key_table_name = key_table_name
self.id_column = id_column

@classmethod
def encrypt_column(cls, operations, column_name, key_table_name, id_column):
op = EncryptColumnOp(column_name, key_table_name, id_column)
return operations.invoke(op)

def reverse(self):
return DecryptColumnOp(self.column_name, self.key_table_name, self.id_column)


@Operations.register_operation("decrypt_column")
class DecryptColumnOp(MigrateOperation):
def __init__(self, column_name, key_table_name, id_column):
self.column_name = column_name
self.key_table_name = key_table_name
self.id_column = id_column

@classmethod
def decrypt_column(cls, operations, column_name, key_table_name, id_column):
op = DecryptColumnOp(column_name, key_table_name, id_column)
return operations.invoke(op)

def reverse(self):
return EncryptColumnOp(self.column_name, self.key_table_name, self.id_column)


def _generate_keygen_query(operation) -> str:
return (
f'INSERT INTO "{operation.key_table_schema}".{operation.key_table_name} '
f"SELECT src.{operation.id_column}, encode(gen_random_bytes(32), 'base64'), NOW() "
f'FROM "{operation.table_schema}".{operation.table_name} '
f'AS src LEFT JOIN "{operation.key_table_schema}".{operation.key_table_name} AS dest '
f"ON dest.id = src.{operation.id_column} WHERE dest.id IS NULL;"
)


def _generate_encryption_query(operation, func: op_type) -> str:
set_cols = []
encrypt_cols = []
for colname in operation.columns:
set_cols.append(f"{colname} = sub.enc_{colname}")
encrypt_cols.append(
f"pgp_sym_{func.name}_bytea(dst.{colname}, keys.key) enc_{colname}"
)
return (
f'UPDATE "{operation.table_schema}".{operation.table_name} dst '
f'SET {",".join(set_cols)} FROM (SELECT '
f'dst.{operation.id_column} id,{",".join(encrypt_cols)}'
f' FROM "{operation.table_schema}".{operation.table_name} dst '
f' LEFT JOIN "{operation.key_table_schema}".{operation.key_table_name} keys'
f" ON keys.id=dst.{operation.id_column}) sub "
f"WHERE dst.{operation.id_column}=sub.id;"
)


@Operations.implementation_for(EncryptTableOp)
def encrypt_table(operations, operation):
operations.execute(_generate_keygen_query(operation))
operations.execute(_generate_encryption_query(operation, op_type.ENCRYPT))


@Operations.implementation_for(DecryptTableOp)
def decrypt_table(operations, operation):
operations.execute(_generate_encryption_query(operation, op_type.DECRYPT))


@Operations.implementation_for(EncryptColumnOp)
def encrypt_column(operations, operation): ...


@Operations.implementation_for(DecryptColumnOp)
def decrypt_column(operations, operation): ...
6 changes: 6 additions & 0 deletions definitions/custom_scripts/operations_tables.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from alembic.operations import MigrateOperation, Operations
from alembic.operations.ops import CreateTableOp

# TODO: replace CreateTableOp and DropTableOp implementations
# with custom ones
# Will be possible in alembic 1.17.2+
# till then, we wait


@Operations.register_operation("grant_on_table")
Expand Down
119 changes: 42 additions & 77 deletions definitions/custom_scripts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,104 +13,69 @@
from .operations_tables import GrantRightsOp, RevokeRightsOp


# this function is exported and not used inside this file, do not delete
def add_table_schema_to_model(table_schema, metadata):
metadata.info.setdefault("table_schemas", set()).add(table_schema)


@comparators.dispatch_for("schema")
def compare_schemas(autogen_context, upgrade_ops, schemas):
def compare_for_groups(autogen_context, upgrade_ops, schemas):
environment = "test" if os.getenv("ENVIRONMENT") != "production" else "prod"
project_prefix = os.getenv("SCHEMA_PREFIX", "dwh")
all_conn_schemas = set()
default_pg_schemas = ["pg_toast", "information_schema", "public", "pg_catalog"]
query = text("select schema_name from information_schema.schemata")
query = text("SELECT schema_name FROM information_schema.schemata")
# all schemas in database
all_conn_schemas.update(
[
sch[0]
for sch in autogen_context.connection.execute(query)
if sch[0] not in default_pg_schemas
]
)

# all schemas in code
metadata_schemas = autogen_context.metadata.info.setdefault("table_schemas", set())

# Create/delete new schemas
for sch in metadata_schemas - all_conn_schemas:
upgrade_ops.ops.append(CreateTableSchemaOp(sch))
for render_scope in ["read", "write", "all"]:
group_name = (
f"test_dwh_{sch}_{render_scope}".lower()
if os.getenv("ENVIRONMENT") != "production"
else f"prod_dwh_{sch}_{render_scope}".lower()
)
upgrade_ops.ops.append(CreateGroupOp(group_name))
upgrade_ops.ops.append(GrantOnSchemaOp(group_name, sch))

tables = set(
[
table
for table in autogen_context.metadata.tables.values()
if table.schema == sch
]
)
for table in tables:
for render_scope in ["read", "write", "all"]:
scopes = []
match render_scope:
case "read":
scopes = ["SELECT"]
case "write":
scopes = ["SELECT", "UPDATE", "DELETE", "TRUNCATE", "INSERT"]
case "all":
scopes = ["ALL"]

group_name = (
f"test_dwh_{sch}_{render_scope}".lower()
if os.getenv("ENVIRONMENT") != "production"
else f"prod_dwh_{sch}_{render_scope}".lower()
)
upgrade_ops.ops.append(
GrantRightsOp(
table_name=str(table), scopes=scopes, group_name=group_name
)
)

for sch in all_conn_schemas - metadata_schemas:
upgrade_ops.ops.append(DropTableSchemaOp(sch))
for render_scope in ["read", "write", "all"]:
group_name = (
f"test_dwh_{sch}_{render_scope}".lower()
if os.getenv("ENVIRONMENT") != "production"
else f"prod_dwh_{sch}_{render_scope}".lower()
)
upgrade_ops.ops.append(DeleteGroupOp(group_name))
upgrade_ops.ops.append(RevokeOnSchemaOp(group_name, sch))

query = text(
f"SELECT * FROM information_schema.tables WHERE table_schema='{sch}';"
all_groups_db = set()
all_groups_code = set()
query = text(
"SELECT grantee,table_schema FROM information_schema.role_table_grants "
"WHERE grantee LIKE :pattern"
)
all_groups_db.update(
autogen_context.connection.execute(
query, {"pattern": f"{environment}%{project_prefix}%"}
)
)
for sch in metadata_schemas:
has_regular = any(
table.schema == sch and not table.info.get("sensitive", False)
for table in autogen_context.metadata.tables.values()
)
tables = set(
[
".".join(table[1:3])
for table in autogen_context.connection.execute(query)
]
has_sensitive = any(
table.schema == sch and table.info.get("sensitive", False)
for table in autogen_context.metadata.tables.values()
)
print(tables)
for table in tables:
for render_scope in ["read", "write", "all"]:
scopes = []
match render_scope:
case "read":
scopes = ["SELECT"]
case "write":
scopes = ["SELECT", "UPDATE", "DELETE", "TRUNCATE", "INSERT"]
case "all":
scopes = ["ALL"]

group_name = (
f"test_dwh_{sch}_{render_scope}".lower()
if os.getenv("ENVIRONMENT") != "production"
else f"prod_dwh_{sch}_{render_scope}".lower()
)
upgrade_ops.ops.append(
RevokeRightsOp(
table_name=str(table), scopes=scopes, group_name=group_name
)
)
group_name = f"{environment}%s_{project_prefix}_{sch}_%s".lower()
for render_scope in ["read", "write", "all"]:
if has_regular:
all_groups_code.add((group_name % ("", render_scope), sch))
if has_sensitive:
all_groups_code.add((group_name % ("_sensitive", render_scope), sch))

# for all new required groups
for group, sch in all_groups_code - all_groups_db:
upgrade_ops.ops.append(CreateGroupOp(group))
upgrade_ops.ops.append(GrantOnSchemaOp(group, sch))

# for all groups no longer needed
for group, sch in all_groups_db - all_groups_code:
upgrade_ops.ops.append(DeleteGroupOp(group))
upgrade_ops.ops.append(RevokeOnSchemaOp(group, sch))
Loading
Loading