From 7d63449ef1caa264f4268b5725c458130ba6fb40 Mon Sep 17 00:00:00 2001 From: Manatsawin Hanmongkolchai Date: Tue, 7 Jun 2022 21:14:45 +0700 Subject: [PATCH 01/10] Add deterministic encryption --- README.md | 6 + tink_fields/__init__.py | 3 +- tink_fields/fields.py | 130 ++++++++++++++---- tink_fields/test/models.py | 12 ++ tink_fields/test/settings/sqlite.py | 4 + tink_fields/test/test_fields.py | 37 ++++- .../test/test_plaintext_daead_keyset.json | 1 + 7 files changed, 164 insertions(+), 29 deletions(-) create mode 100644 tink_fields/test/test_plaintext_daead_keyset.json diff --git a/README.md b/README.md index d393344..802ceea 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,12 @@ class AnotherModel(models.Model): Supported field types include: `EncryptedCharField`, `EncryptedTextField`, `EncryptedDateField`, `EncryptedDateTimeField`, `EncryptedEmailField`, and `EncryptedIntegerField`. +## Deterministic Encryption + +`DeterministicEncryptedCharField` provides support for [Deterministic AEAD](https://developers.google.com/tink/deterministic-aead) which means value in the field can be queried with exact matches. However, unlike normal AEAD encryption, an attacker can verify that two messages are equal. + +Deterministic encryption requires key of type `AES-SIV` and supports Associated Data. + ### Associated Data The encrypted fields make use of `Authenticated Encryption With Associated Data (AEAD)` which offers confidentiality and integrity within the same mode of operation. This allows the caller to specify a cleartext fragment named `additional authenticated data (aad)` to the encryption and decryption operations and receive cryptographic guarantees that the ciphertext data has not been tampered with. 
diff --git a/tink_fields/__init__.py b/tink_fields/__init__.py index f4fbd52..9b9dd43 100644 --- a/tink_fields/__init__.py +++ b/tink_fields/__init__.py @@ -1,4 +1,5 @@ from .fields import * # noqa -from tink import aead +from tink import aead, daead aead.register() +daead.register() diff --git a/tink_fields/fields.py b/tink_fields/fields.py index 01a181b..9c575bc 100644 --- a/tink_fields/fields.py +++ b/tink_fields/fields.py @@ -1,5 +1,5 @@ from functools import lru_cache -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING from django.db import models from django.core.exceptions import FieldError, ImproperlyConfigured from tink import ( @@ -8,12 +8,15 @@ read_keyset_handle, JsonKeysetReader, aead, + daead, ) from django.conf import settings from dataclasses import dataclass from os.path import exists from django.utils.encoding import force_bytes, force_str -from django.db.backends.base.base import BaseDatabaseWrapper + +if TYPE_CHECKING: + from django.db.backends.base.base import BaseDatabaseWrapper __all__ = [ @@ -24,6 +27,10 @@ "EncryptedIntegerField", "EncryptedDateField", "EncryptedDateTimeField", + "DeterministicEncryptedField", + "DeterministicEncryptedCharField", + "DeterministicEncryptedEmailField", + "DeterministicEncryptedIntegerField", ] @@ -39,14 +46,14 @@ def validate(self): if not exists(self.path): raise ImproperlyConfigured(f"Keyset {self.path} does not exist") - + if not self.cleartext and self.master_key_aead is None: - raise ImproperlyConfigured(f"Encrypted keysets must specify `master_key_aead`") - + raise ImproperlyConfigured( + f"Encrypted keysets must specify `master_key_aead`" + ) -class EncryptedField(models.Field): - """A field that uses Tink primitives to protect the confidentiality and integrity of data""" +class BaseEncryptedField(models.Field): _unsupported_properties = ["primary_key", "db_index", "unique"] _internal_type = "BinaryField" @@ -65,7 +72,7 @@ def 
__init__(self, *args, **kwargs): self._keyset_handle = self._get_tink_keyset_handle() self._aad_callback = kwargs.pop("aad_callback", lambda x: b"") - super(EncryptedField, self).__init__(*args, **kwargs) + super(BaseEncryptedField, self).__init__(*args, **kwargs) def _get_config(self) -> Dict[str, Any]: config = getattr(settings, "TINK_FIELDS_CONFIG", None) @@ -93,14 +100,32 @@ def _get_tink_keyset_handle(self) -> KeysetHandle: return cleartext_keyset_handle.read(reader) return read_keyset_handle(reader, keyset_config.master_key_aead) + def get_internal_type(self) -> str: + return self._internal_type + + @property + @lru_cache(maxsize=None) + def validators(self): + # Temporarily pretend to be whatever type of field we're masquerading + # as, for purposes of constructing validators (needed for + # IntegerField and subclasses). + self.__dict__["_internal_type"] = super( + BaseEncryptedField, self + ).get_internal_type() + try: + return super(BaseEncryptedField, self).validators + finally: + del self.__dict__["_internal_type"] + + +class EncryptedField(BaseEncryptedField): + """A field that uses Tink primitives to protect the confidentiality and integrity of data""" + @lru_cache(maxsize=None) def _get_aead_primitive(self) -> aead.Aead: return self._keyset_handle.primitive(aead.Aead) - def get_internal_type(self) -> str: - return self._internal_type - - def get_db_prep_save(self, value: Any, connection: BaseDatabaseWrapper) -> Any: + def get_db_prep_save(self, value: Any, connection: "BaseDatabaseWrapper") -> Any: val = super(EncryptedField, self).get_db_prep_save(value, connection) if val is not None: return connection.Database.Binary( @@ -119,19 +144,39 @@ def from_db_value(self, value, expression, connection, *args): ) ) - @property + +class DeterministicEncryptedField(BaseEncryptedField): + """Field that is similar to EncryptedField, but support exact match lookups""" + + _unsupported_properties = [] + @lru_cache(maxsize=None) - def validators(self): - # 
Temporarily pretend to be whatever type of field we're masquerading - # as, for purposes of constructing validators (needed for - # IntegerField and subclasses). - self.__dict__["_internal_type"] = super( - EncryptedField, self - ).get_internal_type() - try: - return super(EncryptedField, self).validators - finally: - del self.__dict__["_internal_type"] + def _get_daead_primitive(self) -> daead.DeterministicAead: + return self._keyset_handle.primitive(daead.DeterministicAead) + + def get_db_prep_value( + self, value: Any, connection: "BaseDatabaseWrappr", prepared=False + ) -> Any: + + val = super(DeterministicEncryptedField, self).get_db_prep_value( + value, connection, prepared + ) + if val is not None: + return connection.Database.Binary( + self._get_daead_primitive().encrypt_deterministically( + force_bytes(val), self._aad_callback(self) + ) + ) + + def from_db_value(self, value, expression, connection, *args): + if value is not None: + return self.to_python( + force_str( + self._get_daead_primitive().decrypt_deterministically( + bytes(value), self._aad_callback(self) + ) + ) + ) def get_prep_lookup(self): @@ -143,12 +188,29 @@ def get_prep_lookup(self): ) +lookup_allowlist = { + (object, "isnull"), + (DeterministicEncryptedField, "exact"), +} + + +def is_lookup_allowed(cls, name) -> bool: + for item in lookup_allowlist: + if item[1] == name and isinstance(cls, item[0]): + return True + return False + + +# Override all lookups except in lookup_allowlist to get_prep_lookup for name, lookup in models.Field.class_lookups.items(): - if name != "isnull": + for cls in (EncryptedField, DeterministicEncryptedField): + if not is_lookup_allowed(cls, name): + continue + lookup_class = type( - "EncryptedField" + name, (lookup,), {"get_prep_lookup": get_prep_lookup} + cls.__name__ + name, (lookup,), {"get_prep_lookup": get_prep_lookup} ) - EncryptedField.register_lookup(lookup_class) + cls.register_lookup(lookup_class) class EncryptedTextField(EncryptedField, 
models.TextField): @@ -173,3 +235,17 @@ class EncryptedDateField(EncryptedField, models.DateField): class EncryptedDateTimeField(EncryptedField, models.DateTimeField): pass + + +class DeterministicEncryptedCharField(DeterministicEncryptedField, models.CharField): + pass + + +class DeterministicEncryptedEmailField(DeterministicEncryptedField, models.EmailField): + pass + + +class DeterministicEncryptedIntegerField( + DeterministicEncryptedField, models.IntegerField +): + pass diff --git a/tink_fields/test/models.py b/tink_fields/test/models.py index 61e917b..68f6b3e 100644 --- a/tink_fields/test/models.py +++ b/tink_fields/test/models.py @@ -41,3 +41,15 @@ class EncryptedCharWithFixedAad(models.Model): class EncryptedCharWithAlternateKeyset(models.Model): value = fields.EncryptedCharField(max_length=25, keyset="alternate") + + +class DeterministicEncryptedChar(models.Model): + value = fields.DeterministicEncryptedCharField(max_length=25, keyset="daead") + + +class DeterministicEncryptedEmail(models.Model): + value = fields.DeterministicEncryptedEmailField(keyset="daead") + + +class DeterministicEncryptedInt(models.Model): + value = fields.DeterministicEncryptedIntegerField(keyset="daead") diff --git a/tink_fields/test/settings/sqlite.py b/tink_fields/test/settings/sqlite.py index 5ca7642..27c3ae4 100644 --- a/tink_fields/test/settings/sqlite.py +++ b/tink_fields/test/settings/sqlite.py @@ -26,4 +26,8 @@ "cleartext": True, "path": os.path.join(HERE, "../test_plaintext_keyset.json"), }, + "daead": { + "cleartext": True, + "path": os.path.join(HERE, "../test_plaintext_daead_keyset.json"), + }, } diff --git a/tink_fields/test/test_fields.py b/tink_fields/test/test_fields.py index db2c159..7e95858 100644 --- a/tink_fields/test/test_fields.py +++ b/tink_fields/test/test_fields.py @@ -1,6 +1,6 @@ from datetime import date, datetime -from django.db import connection, models as dj_models +from django.db import connection from django.utils.encoding import force_bytes, 
force_str import pytest @@ -41,3 +41,38 @@ def test_insert(self, db, model, vals): ] assert list(map(field.to_python, data)) == [vals[0]] + + +@pytest.mark.parametrize( + "model,vals", + [ + (models.DeterministicEncryptedChar, ["one", "two"]), + (models.DeterministicEncryptedEmail, ["a@example.com", "b@example.com"]), + (models.DeterministicEncryptedInt, [1, 2]), + ], +) +class TestDeterministicEncryptedFieldQueries(object): + def test_insert(self, db, model, vals): + """Data stored in DB is actually encrypted.""" + field = model._meta.get_field("value") + aad_callback = getattr(field, "_aad_callback") + model.objects.create(value=vals[0]) + with connection.cursor() as cur: + cur.execute("SELECT value FROM %s" % model._meta.db_table) + data = [ + force_str( + field._get_daead_primitive().decrypt_deterministically( + force_bytes(r[0]), aad_callback(field) + ) + ) + for r in cur.fetchall() + ] + + assert list(map(field.to_python, data)) == [vals[0]] + + def test_search(self, db, model, vals): + model.objects.create(value=vals[0]) + model.objects.create(value=vals[1]) + out = model.objects.filter(value=vals[0]) + assert len(out) == 1 + assert out[0].value == vals[0] diff --git a/tink_fields/test/test_plaintext_daead_keyset.json b/tink_fields/test/test_plaintext_daead_keyset.json new file mode 100644 index 0000000..d214367 --- /dev/null +++ b/tink_fields/test/test_plaintext_daead_keyset.json @@ -0,0 +1 @@ +{"primaryKeyId":508056547,"key":[{"keyData":{"typeUrl":"type.googleapis.com/google.crypto.tink.AesSivKey","value":"EkDc2ZTmEZO2wrwmfEBWTEwoRd2WrDqPikE8rseHs3Nx/exobkxiQEZtPwTM37iNdwVvSouyDLGWUjO3T3D8v0LC","keyMaterialType":"SYMMETRIC"},"status":"ENABLED","keyId":508056547,"outputPrefixType":"TINK"}]} From 00a42ae560be3b93f3853fea21fa94ed90106040 Mon Sep 17 00:00:00 2001 From: Manatsawin Hanmongkolchai Date: Wed, 8 Jun 2022 23:18:00 +0700 Subject: [PATCH 02/10] Add encrypted binary field, fix memory leak with lru_cache --- tink_fields/fields.py | 42 
++++++++++++++++++++------------- tink_fields/test/models.py | 8 +++++++ tink_fields/test/test_fields.py | 26 +++++++++++++++----- 3 files changed, 54 insertions(+), 22 deletions(-) diff --git a/tink_fields/fields.py b/tink_fields/fields.py index 9c575bc..860c27e 100644 --- a/tink_fields/fields.py +++ b/tink_fields/fields.py @@ -1,4 +1,4 @@ -from functools import lru_cache +from django.utils.functional import cached_property from typing import Any, Callable, Dict, Optional, TYPE_CHECKING from django.db import models from django.core.exceptions import FieldError, ImproperlyConfigured @@ -27,6 +27,7 @@ "EncryptedIntegerField", "EncryptedDateField", "EncryptedDateTimeField", + "EncryptedBinaryField", "DeterministicEncryptedField", "DeterministicEncryptedCharField", "DeterministicEncryptedEmailField", @@ -103,8 +104,7 @@ def _get_tink_keyset_handle(self) -> KeysetHandle: def get_internal_type(self) -> str: return self._internal_type - @property - @lru_cache(maxsize=None) + @cached_property def validators(self): # Temporarily pretend to be whatever type of field we're masquerading # as, for purposes of constructing validators (needed for @@ -117,19 +117,25 @@ def validators(self): finally: del self.__dict__["_internal_type"] + def to_python_prepare(self, value: bytes) -> Any: + if isinstance(self, models.BinaryField): + return value + + return force_str(value) + class EncryptedField(BaseEncryptedField): """A field that uses Tink primitives to protect the confidentiality and integrity of data""" - @lru_cache(maxsize=None) - def _get_aead_primitive(self) -> aead.Aead: + @cached_property + def _aead_primitive(self) -> aead.Aead: return self._keyset_handle.primitive(aead.Aead) def get_db_prep_save(self, value: Any, connection: "BaseDatabaseWrapper") -> Any: val = super(EncryptedField, self).get_db_prep_save(value, connection) if val is not None: return connection.Database.Binary( - self._get_aead_primitive().encrypt( + self._aead_primitive.encrypt( force_bytes(val), 
self._aad_callback(self) ) ) @@ -137,8 +143,8 @@ def get_db_prep_save(self, value: Any, connection: "BaseDatabaseWrapper") -> Any def from_db_value(self, value, expression, connection, *args): if value is not None: return self.to_python( - force_str( - self._get_aead_primitive().decrypt( + self.to_python_prepare( + self._aead_primitive.decrypt( bytes(value), self._aad_callback(self) ) ) @@ -150,8 +156,8 @@ class DeterministicEncryptedField(BaseEncryptedField): _unsupported_properties = [] - @lru_cache(maxsize=None) - def _get_daead_primitive(self) -> daead.DeterministicAead: + @cached_property + def _daead_primitive(self) -> daead.DeterministicAead: return self._keyset_handle.primitive(daead.DeterministicAead) def get_db_prep_value( @@ -163,7 +169,7 @@ def get_db_prep_value( ) if val is not None: return connection.Database.Binary( - self._get_daead_primitive().encrypt_deterministically( + self._daead_primitive.encrypt_deterministically( force_bytes(val), self._aad_callback(self) ) ) @@ -171,8 +177,8 @@ def get_db_prep_value( def from_db_value(self, value, expression, connection, *args): if value is not None: return self.to_python( - force_str( - self._get_daead_primitive().decrypt_deterministically( + self.to_python_prepare( + self._daead_primitive.decrypt_deterministically( bytes(value), self._aad_callback(self) ) ) @@ -196,7 +202,7 @@ def get_prep_lookup(self): def is_lookup_allowed(cls, name) -> bool: for item in lookup_allowlist: - if item[1] == name and isinstance(cls, item[0]): + if item[1] == name and issubclass(cls, item[0]): return True return False @@ -204,11 +210,11 @@ def is_lookup_allowed(cls, name) -> bool: # Override all lookups except in lookup_allowlist to get_prep_lookup for name, lookup in models.Field.class_lookups.items(): for cls in (EncryptedField, DeterministicEncryptedField): - if not is_lookup_allowed(cls, name): + if is_lookup_allowed(cls, name): continue lookup_class = type( - cls.__name__ + name, (lookup,), {"get_prep_lookup": 
get_prep_lookup} + cls.__name__ + "__" + name, (lookup,), {"get_prep_lookup": get_prep_lookup} ) cls.register_lookup(lookup_class) @@ -237,6 +243,10 @@ class EncryptedDateTimeField(EncryptedField, models.DateTimeField): pass +class EncryptedBinaryField(EncryptedField, models.BinaryField): + """Encrypted raw binary data, must be under 2^32 bytes (4.295GB)""" + + class DeterministicEncryptedCharField(DeterministicEncryptedField, models.CharField): pass diff --git a/tink_fields/test/models.py b/tink_fields/test/models.py index 68f6b3e..9752015 100644 --- a/tink_fields/test/models.py +++ b/tink_fields/test/models.py @@ -27,6 +27,10 @@ class EncryptedDateTime(models.Model): value = fields.EncryptedDateTimeField() +class EncryptedBinary(models.Model): + value = fields.EncryptedBinaryField() + + class EncryptedNullable(models.Model): value = fields.EncryptedIntegerField(null=True) @@ -53,3 +57,7 @@ class DeterministicEncryptedEmail(models.Model): class DeterministicEncryptedInt(models.Model): value = fields.DeterministicEncryptedIntegerField(keyset="daead") + + +class DeterministicEncryptedIntNullable(models.Model): + value = fields.DeterministicEncryptedIntegerField(keyset="daead", null=True) diff --git a/tink_fields/test/test_fields.py b/tink_fields/test/test_fields.py index 7e95858..bef6676 100644 --- a/tink_fields/test/test_fields.py +++ b/tink_fields/test/test_fields.py @@ -1,7 +1,7 @@ from datetime import date, datetime from django.db import connection -from django.utils.encoding import force_bytes, force_str +from django.utils.encoding import force_bytes import pytest from . 
import models @@ -21,6 +21,7 @@ [datetime(2015, 2, 5, 15), datetime(2015, 2, 8, 16)], ), (models.EncryptedCharWithAlternateKeyset, ["foo", "bar"]), + (models.EncryptedBinary, [b"1234", b"asdf"]), ], ) class TestEncryptedFieldQueries(object): @@ -32,15 +33,23 @@ def test_insert(self, db, model, vals): with connection.cursor() as cur: cur.execute("SELECT value FROM %s" % model._meta.db_table) data = [ - force_str( - field._get_aead_primitive().decrypt( + field.to_python_prepare( + field._aead_primitive.decrypt( force_bytes(r[0]), aad_callback(field) ) ) for r in cur.fetchall() ] - assert list(map(field.to_python, data)) == [vals[0]] + if model is models.EncryptedBinary: + assert list([bytes(field.to_python(item)) for item in data]) == [vals[0]] + else: + assert list(map(field.to_python, data)) == [vals[0]] + + +def test_encrypted_nullable(db): + models.EncryptedNullable(value=None).save() + assert models.EncryptedNullable.objects.get(value__isnull=True) @pytest.mark.parametrize( @@ -60,8 +69,8 @@ def test_insert(self, db, model, vals): with connection.cursor() as cur: cur.execute("SELECT value FROM %s" % model._meta.db_table) data = [ - force_str( - field._get_daead_primitive().decrypt_deterministically( + field.to_python_prepare( + field._daead_primitive.decrypt_deterministically( force_bytes(r[0]), aad_callback(field) ) ) @@ -76,3 +85,8 @@ def test_search(self, db, model, vals): out = model.objects.filter(value=vals[0]) assert len(out) == 1 assert out[0].value == vals[0] + + +def test_encrypted_deterministic_nullable(db): + models.DeterministicEncryptedIntNullable(value=None).save() + assert models.DeterministicEncryptedIntNullable.objects.get(value=None) From e047f79732289fb5660e338d8f304ffda25bc509 Mon Sep 17 00:00:00 2001 From: Manatsawin Hanmongkolchai Date: Sun, 12 Jun 2022 16:31:04 +0700 Subject: [PATCH 03/10] Add database key storage --- tink_fields/config.py | 62 ++++++++ tink_fields/fields.py | 66 ++------ tink_fields/models.py | 228 
++++++++++++++++++++++++++++ tink_fields/test/models.py | 8 + tink_fields/test/settings/base.py | 1 + tink_fields/test/settings/sqlite.py | 6 + tink_fields/test/test_fields.py | 16 ++ tink_fields/test/test_models.py | 53 +++++++ 8 files changed, 389 insertions(+), 51 deletions(-) create mode 100644 tink_fields/config.py create mode 100644 tink_fields/models.py create mode 100644 tink_fields/test/test_models.py diff --git a/tink_fields/config.py b/tink_fields/config.py new file mode 100644 index 0000000..3306a97 --- /dev/null +++ b/tink_fields/config.py @@ -0,0 +1,62 @@ +from dataclasses import dataclass +from os.path import exists +from typing import Optional, TypeVar, Type, TYPE_CHECKING +from django.utils.functional import cached_property + +from django.core.exceptions import ImproperlyConfigured +from tink import ( + aead, + KeysetHandle, + JsonKeysetReader, + cleartext_keyset_handle, + read_keyset_handle, +) + +if TYPE_CHECKING: + from .models import Keyset + +P = TypeVar("P") + + +@dataclass +class KeysetConfig: + path: Optional[str] = None + db_name: Optional[str] = None + master_key_aead: Optional[aead.Aead] = None + cleartext: bool = False + + def validate(self): + if not self.path and not self.db_name: + raise ImproperlyConfigured("Keyset path or db_name must be set") + if self.db_name and self.path: + raise ImproperlyConfigured("Only one of keyset path or db_name must be set") + + if self.path: + if not exists(self.path): + raise ImproperlyConfigured(f"Keyset {self.path} does not exist") + + if not self.cleartext and self.master_key_aead is None: + raise ImproperlyConfigured( + f"Encrypted keysets must specify `master_key_aead`" + ) + + def primitive(self, cls: Type[P]) -> P: + if self.path: + return self._load_from_path.primitive(cls) + if self.db_name: + return self._load_from_db.primitive + + @cached_property + def _load_from_path(self) -> KeysetHandle: + with open(self.path, "r") as f: + reader = JsonKeysetReader(f.read()) + if self.cleartext: + 
return cleartext_keyset_handle.read(reader) + return read_keyset_handle(reader, self.master_key_aead) + + @cached_property + def _load_from_db(self) -> "Keyset": + from .models import Keyset + + keyset = Keyset.objects.get(name=self.db_name) + return keyset diff --git a/tink_fields/fields.py b/tink_fields/fields.py index 860c27e..17f4a09 100644 --- a/tink_fields/fields.py +++ b/tink_fields/fields.py @@ -1,20 +1,17 @@ -from django.utils.functional import cached_property -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING from django.db import models from django.core.exceptions import FieldError, ImproperlyConfigured +from django.utils.functional import cached_property from tink import ( KeysetHandle, - cleartext_keyset_handle, - read_keyset_handle, - JsonKeysetReader, aead, daead, ) from django.conf import settings -from dataclasses import dataclass -from os.path import exists from django.utils.encoding import force_bytes, force_str +from tink_fields.config import KeysetConfig + if TYPE_CHECKING: from django.db.backends.base.base import BaseDatabaseWrapper @@ -35,31 +32,12 @@ ] -@dataclass -class KeysetConfig: - path: str - master_key_aead: Optional[aead.Aead] = None - cleartext: bool = False - - def validate(self): - if not self.path: - raise ImproperlyConfigured("Keyset path cannot be None or empty") - - if not exists(self.path): - raise ImproperlyConfigured(f"Keyset {self.path} does not exist") - - if not self.cleartext and self.master_key_aead is None: - raise ImproperlyConfigured( - f"Encrypted keysets must specify `master_key_aead`" - ) - - class BaseEncryptedField(models.Field): _unsupported_properties = ["primary_key", "db_index", "unique"] _internal_type = "BinaryField" _keyset: str - _keyset_handle: KeysetHandle + _keyset_config: KeysetConfig _aad_callback: Callable[[models.Field], bytes] def __init__(self, *args, **kwargs): @@ -70,22 +48,17 @@ def __init__(self, *args, **kwargs): ) self._keyset 
= kwargs.pop("keyset", "default") - self._keyset_handle = self._get_tink_keyset_handle() + self._keyset_config = self._get_config() self._aad_callback = kwargs.pop("aad_callback", lambda x: b"") super(BaseEncryptedField, self).__init__(*args, **kwargs) - def _get_config(self) -> Dict[str, Any]: + def _get_config(self) -> KeysetConfig: config = getattr(settings, "TINK_FIELDS_CONFIG", None) if config is None: raise ImproperlyConfigured( f"Could not find `TINK_FIELDS_CONFIG` attribute in settings" ) - return config - - def _get_tink_keyset_handle(self) -> KeysetHandle: - """Read the configuration for the requested keyset and return a respective keyset handle""" - config = self._get_config() if self._keyset not in config: raise ImproperlyConfigured( @@ -93,13 +66,9 @@ def _get_tink_keyset_handle(self) -> KeysetHandle: ) keyset_config = KeysetConfig(**config[self._keyset]) - keyset_config.validate() + keyset_config.validate() # TODO: Reuse config - with open(keyset_config.path, "r") as f: - reader = JsonKeysetReader(f.read()) - if keyset_config.cleartext: - return cleartext_keyset_handle.read(reader) - return read_keyset_handle(reader, keyset_config.master_key_aead) + return keyset_config def get_internal_type(self) -> str: return self._internal_type @@ -129,24 +98,20 @@ class EncryptedField(BaseEncryptedField): @cached_property def _aead_primitive(self) -> aead.Aead: - return self._keyset_handle.primitive(aead.Aead) + return self._keyset_config.primitive(aead.Aead) def get_db_prep_save(self, value: Any, connection: "BaseDatabaseWrapper") -> Any: val = super(EncryptedField, self).get_db_prep_save(value, connection) if val is not None: return connection.Database.Binary( - self._aead_primitive.encrypt( - force_bytes(val), self._aad_callback(self) - ) + self._aead_primitive.encrypt(force_bytes(val), self._aad_callback(self)) ) def from_db_value(self, value, expression, connection, *args): if value is not None: return self.to_python( self.to_python_prepare( - 
self._aead_primitive.decrypt( - bytes(value), self._aad_callback(self) - ) + self._aead_primitive.decrypt(bytes(value), self._aad_callback(self)) ) ) @@ -158,12 +123,11 @@ class DeterministicEncryptedField(BaseEncryptedField): @cached_property def _daead_primitive(self) -> daead.DeterministicAead: - return self._keyset_handle.primitive(daead.DeterministicAead) + return self._keyset_config.primitive(daead.DeterministicAead) def get_db_prep_value( - self, value: Any, connection: "BaseDatabaseWrappr", prepared=False + self, value: Any, connection: "BaseDatabaseWrapper", prepared=False ) -> Any: - val = super(DeterministicEncryptedField, self).get_db_prep_value( value, connection, prepared ) @@ -196,7 +160,7 @@ def get_prep_lookup(self): lookup_allowlist = { (object, "isnull"), - (DeterministicEncryptedField, "exact"), + (DeterministicEncryptedField, "exact"), # TODO: Support key rotation } diff --git a/tink_fields/models.py b/tink_fields/models.py new file mode 100644 index 0000000..730ac18 --- /dev/null +++ b/tink_fields/models.py @@ -0,0 +1,228 @@ +from typing import List, TypeVar, Type, Optional, Set, Union + +from django.db import models +from django.db import transaction +from django.utils.functional import cached_property +from tink.proto import tink_pb2 + +from .fields import EncryptedBinaryField +from tink.core import PrimitiveSet, Registry, KeyManager +from tink.core._primitive_set import Entry +from tink.core import crypto_format + +__all__ = ["Keyset", "Key"] +P = TypeVar("P") + + +class Keyset(models.Model): + name = models.CharField(max_length=100, unique=True) + type_url = models.CharField(max_length=250) + + @classmethod + def create(cls, name: str, key_template: tink_pb2.KeyTemplate) -> "Keyset": + """Create a keyset with one primary key""" + instance = cls(name=name, type_url=key_template.type_url) + instance.save() + + key = instance.generate_key(key_template) + key.is_primary = True + key.save() + + return instance + + def generate_key(self, 
key_template: tink_pb2.KeyTemplate) -> "Key": + """Create and save a key""" + key_data = self.key_manager().new_key_data(key_template) + key = Key.create_from_keydata(self, key_data, key_template.output_prefix_type) + + return key + + def set_primary_key(self, key: Union["Key", int]): + key_id = key + if isinstance(key, Key): + key_id = key.pk + + with transaction.atomic(): + self.key_set.update(is_primary=False) + self.key_set.filter(pk=key_id).update(is_primary=True) + + def key_manager(self) -> "KeyManager": + return Registry.key_manager(self.type_url) + + @property + def primitive(self) -> P: + """Get primitive of the stored type""" + return Registry.wrap( + self.primitive_set, + self.key_manager().primitive_class(), + ) + + @cached_property + def primitive_set(self) -> PrimitiveSet: + return _DatabasePrimitiveKeyset(self, self.key_manager().primitive_class()) + + +class _DatabasePrimitiveKeyset(PrimitiveSet[P]): + _keyset: Keyset + + def __init__(self, keyset: Keyset, primitive_class: Type[P]): + super().__init__(primitive_class) + self._keyset = keyset + del self._primary + + def primitive_from_identifier(self, identifier: bytes) -> List[Entry]: + for key in ( + self._keyset.key_set.filter(output_prefix=identifier) + .exclude(id__in=list(self._all_cached_key_ids())) + .all() + ): + self._add_key_to_cache(key) + + return super().primitive_from_identifier(identifier) + + def entry_by_id(self, identifier: bytes, key_id: int) -> Optional[Entry]: + primitives = self._primitives.get(identifier, []) + for item in primitives: + if item.key_id == key_id: + return item + + def _add_key_to_cache(self, key: "Key"): + if not self.entry_by_id(key.output_prefix, key.id): + entries = self._primitives.setdefault(key.output_prefix, []) + entries.append(key.entry) + + def _add_entry_to_cache(self, entry: Entry): + entries = self._primitives.setdefault(entry.identifier, []) + entries.append(entry) # XXX: This does not check for dupes + + def _all_cached_key_ids(self) -> 
Set[int]: + out = set() + for keys in self._primitives.values(): + for key in keys: + out.add(key.key_id) + + return out + + def all(self) -> List[List[Entry]]: + for key in self._keyset.key_set.exclude( + id__in=list(self._all_cached_key_ids()) + ).all(): + self._add_key_to_cache(key) + + return super().all() + + def add_primitive(self, primitive: P, key: tink_pb2.Keyset.Key) -> Entry: + assert isinstance(primitive, self._primitive_class) + + key = Key.from_key(self._keyset, key) + key.save() + + self._add_key_to_cache(key) + + return key.entry + + def set_primary(self, entry: Entry) -> None: + self._keyset.set_primary_key(entry.key_id) + + def primary(self) -> Entry: + key = self._keyset.key_set.get(is_primary=True) + entry = self.entry_by_id(key.output_prefix, key.id) + if entry: + return entry + + self._add_key_to_cache(key) + return key.entry + + +class Key(models.Model): + """Key instance in a keyset. + + It is expected that Key is immutable except for is_primary, status field""" + + keyset = models.ForeignKey(Keyset, on_delete=models.CASCADE, editable=False) + is_primary = models.BooleanField() + + # Serialized KeyData + key_data = EncryptedBinaryField(editable=False) + status = models.PositiveIntegerField(choices=tink_pb2.KeyStatusType.items()) + output_prefix_type = models.PositiveIntegerField( + choices=tink_pb2.OutputPrefixType.items() + ) + # Output prefix can be derived from output_prefix_type + id, however + # we store it here to be able to lookup without parsing Tink format ourselves + output_prefix = models.BinaryField(editable=False) + + @cached_property + def key(self) -> tink_pb2.Keyset.Key: + return tink_pb2.Keyset.Key( + key_data=self.key_data_pb, + status=self.status, + key_id=self.id, + output_prefix_type=self.output_prefix_type, + ) + + @property + def key_data_pb(self) -> tink_pb2.KeyData: + out = tink_pb2.KeyData() + out.ParseFromString(self.key_data) + return out + + @property + def primitive(self) -> P: + return Registry.primitive( + 
self.key_data_pb, + self.keyset.key_manager().primitive_class(), + ) + + @cached_property + def entry(self) -> Entry: + return Entry( + primitive=self.primitive, + identifier=self.output_prefix, + status=self.status, + output_prefix_type=self.output_prefix_type, + key_id=self.id, + ) + + @classmethod + def from_key(cls, keyset: "Keyset", key: tink_pb2.Keyset.Key): + return cls( + id=key.key_id, + is_primary=False, + keyset=keyset, + data=key.key_data.SerializeToString(), + status=key.status, + output_prefix=crypto_format.output_prefix(key), + output_prefix_type=key.output_prefix_type, + ) + + @classmethod + def create_from_keydata( + cls, + keyset: "Keyset", + keydata: tink_pb2.KeyData, + output_prefix_type: tink_pb2.OutputPrefixType, + ): + with transaction.atomic(): + out = cls( + keyset=keyset, + is_primary=False, + key_data=keydata.SerializeToString(), + status=tink_pb2.ENABLED, + output_prefix_type=output_prefix_type, + ) + out.save() + out.refresh_from_db() + out.output_prefix = crypto_format.output_prefix(out.key) + out.save() + + return out + + class Meta: + constraints = [ + models.UniqueConstraint( + name="one_primary_per_keyset", + fields=("keyset", "is_primary"), + condition=models.Q(is_primary=True), + ), + ] diff --git a/tink_fields/test/models.py b/tink_fields/test/models.py index 9752015..10c22d7 100644 --- a/tink_fields/test/models.py +++ b/tink_fields/test/models.py @@ -35,6 +35,10 @@ class EncryptedNullable(models.Model): value = fields.EncryptedIntegerField(null=True) +class EncryptedIntEnvelope(models.Model): + value = fields.EncryptedIntegerField(keyset="db_aead") + + def sample_aad_provider(instance) -> bytes: return force_bytes(instance.__class__.__name__) @@ -61,3 +65,7 @@ class DeterministicEncryptedInt(models.Model): class DeterministicEncryptedIntNullable(models.Model): value = fields.DeterministicEncryptedIntegerField(keyset="daead", null=True) + + +class DeterministicEncryptedIntEnvelope(models.Model): + value = 
fields.DeterministicEncryptedIntegerField(keyset="db_daead") diff --git a/tink_fields/test/settings/base.py b/tink_fields/test/settings/base.py index bc23f2c..7e5db41 100644 --- a/tink_fields/test/settings/base.py +++ b/tink_fields/test/settings/base.py @@ -1,4 +1,5 @@ INSTALLED_APPS = [ + "tink_fields", "tink_fields.test", ] diff --git a/tink_fields/test/settings/sqlite.py b/tink_fields/test/settings/sqlite.py index 27c3ae4..111ae60 100644 --- a/tink_fields/test/settings/sqlite.py +++ b/tink_fields/test/settings/sqlite.py @@ -30,4 +30,10 @@ "cleartext": True, "path": os.path.join(HERE, "../test_plaintext_daead_keyset.json"), }, + "db_aead": { + "db_name": "aead", + }, + "db_daead": { + "db_name": "daead", + }, } diff --git a/tink_fields/test/test_fields.py b/tink_fields/test/test_fields.py index bef6676..5589d2e 100644 --- a/tink_fields/test/test_fields.py +++ b/tink_fields/test/test_fields.py @@ -3,10 +3,24 @@ from django.db import connection from django.utils.encoding import force_bytes import pytest +from tink import aead, daead from . 
import models +@pytest.fixture(autouse=True) +def configured_db_keyset(db): + for name, template in { + "aead": aead.aead_key_templates.AES256_GCM, + "daead": daead.deterministic_aead_key_templates.AES256_SIV, + }.items(): + model = Keyset.create(name=name, key_template=template) + model.save() + + +from ..models import Keyset + + @pytest.mark.parametrize( "model,vals", [ @@ -21,6 +35,7 @@ [datetime(2015, 2, 5, 15), datetime(2015, 2, 8, 16)], ), (models.EncryptedCharWithAlternateKeyset, ["foo", "bar"]), + (models.EncryptedIntEnvelope, [1, 2]), (models.EncryptedBinary, [b"1234", b"asdf"]), ], ) @@ -58,6 +73,7 @@ def test_encrypted_nullable(db): (models.DeterministicEncryptedChar, ["one", "two"]), (models.DeterministicEncryptedEmail, ["a@example.com", "b@example.com"]), (models.DeterministicEncryptedInt, [1, 2]), + (models.DeterministicEncryptedIntEnvelope, [1, 2]), ], ) class TestDeterministicEncryptedFieldQueries(object): diff --git a/tink_fields/test/test_models.py b/tink_fields/test/test_models.py new file mode 100644 index 0000000..aa051e3 --- /dev/null +++ b/tink_fields/test/test_models.py @@ -0,0 +1,53 @@ +from tink import aead, daead +from tink_fields.models import Keyset + +TEST_DATA = b"hello world" +ASSOC_DATA = b"ad" + + +def test_aead_encrypt(db): + key_template = aead.aead_key_templates.AES256_GCM + ks = Keyset.create("aead", key_template) + + ciphertext = ks.primitive.encrypt(TEST_DATA, ASSOC_DATA) + + assert ks.primitive.decrypt(ciphertext, ASSOC_DATA) == TEST_DATA + + +def test_aead_decrypt_non_primary(db): + key_template = aead.aead_key_templates.AES256_GCM + ks = Keyset.create("aead", key_template) + + ciphertext = ks.primitive.encrypt(TEST_DATA, ASSOC_DATA) + + secondary_key = ks.generate_key(key_template) + ks.set_primary_key(secondary_key) + secondary_key.refresh_from_db() + assert secondary_key.is_primary + + assert ks.primitive.decrypt(ciphertext, ASSOC_DATA) == TEST_DATA + + +def test_daead_encrypt(db): + key_template = 
daead.deterministic_aead_key_templates.AES256_SIV + ks = Keyset.create("daead", key_template) + + ciphertext = ks.primitive.encrypt_deterministically(TEST_DATA, ASSOC_DATA) + + assert ks.primitive.encrypt_deterministically(TEST_DATA, ASSOC_DATA) == ciphertext + assert ks.primitive.decrypt_deterministically(ciphertext, ASSOC_DATA) == TEST_DATA + + +def test_aead_decrypt_non_primary(db): + key_template = daead.deterministic_aead_key_templates.AES256_SIV + ks = Keyset.create("daead", key_template) + + ciphertext = ks.primitive.encrypt_deterministically(TEST_DATA, ASSOC_DATA) + + secondary_key = ks.generate_key(key_template) + ks.set_primary_key(secondary_key) + secondary_key.refresh_from_db() + assert secondary_key.is_primary + + assert ks.primitive.encrypt_deterministically(TEST_DATA, ASSOC_DATA) != ciphertext + assert ks.primitive.decrypt_deterministically(ciphertext, ASSOC_DATA) == TEST_DATA From 53f6818b71415369417a65f7336638266f9c06e5 Mon Sep 17 00:00:00 2001 From: Manatsawin Hanmongkolchai Date: Sun, 12 Jun 2022 17:25:05 +0700 Subject: [PATCH 04/10] Add migrations and management command --- tink_fields/management/__init__.py | 0 tink_fields/management/commands/__init__.py | 0 tink_fields/management/commands/tink.py | 143 ++++++++++++++++++++ tink_fields/migrations/0001_initial.py | 87 ++++++++++++ tink_fields/migrations/__init__.py | 0 tink_fields/models.py | 32 ++++- 6 files changed, 257 insertions(+), 5 deletions(-) create mode 100644 tink_fields/management/__init__.py create mode 100644 tink_fields/management/commands/__init__.py create mode 100644 tink_fields/management/commands/tink.py create mode 100644 tink_fields/migrations/0001_initial.py create mode 100644 tink_fields/migrations/__init__.py diff --git a/tink_fields/management/__init__.py b/tink_fields/management/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tink_fields/management/commands/__init__.py b/tink_fields/management/commands/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/tink_fields/management/commands/tink.py b/tink_fields/management/commands/tink.py new file mode 100644 index 0000000..adc3b05 --- /dev/null +++ b/tink_fields/management/commands/tink.py @@ -0,0 +1,143 @@ +from django.core.management.base import BaseCommand, CommandError, CommandParser +from tink.proto import tink_pb2 +from google.protobuf import json_format + +from tink_fields.models import Keyset, Key +from tink.aead import aead_key_templates +from tink.daead import deterministic_aead_key_templates + + +def get_key_template_by_name(name: str) -> tink_pb2.KeyTemplate: + template = getattr(aead_key_templates, name, None) + if template: + return template + + return getattr(deterministic_aead_key_templates, name) + + +class Command(BaseCommand): + help = "Tink key management" + + def add_arguments(self, parser: CommandParser): + subparsers = parser.add_subparsers(dest="subcommand", required=True) + + create_keyset = subparsers.add_parser("create-keyset", help="Create new keyset") + create_keyset.add_argument("name", help="Key name") + create_keyset.add_argument( + "template", + help="Key template (see tinkey list-key-templates)", + ) + + create_key = subparsers.add_parser( + "create-key", help="Create a non-primary key in a keyset" + ) + create_key.add_argument("name", help="Keyset name") + create_key.add_argument( + "template", help="Key template (see tinkey list-key-templates)" + ) + + promote_key = subparsers.add_parser( + "promote-key", help="Promote key to primary in a keyset" + ) + promote_key.add_argument("name", help="Keyset name") + promote_key.add_argument("id", help="Key ID", type=int) + + list_keys = subparsers.add_parser("list-keys", help="List keys a keyset") + list_keys.add_argument("name", help="Keyset name") + + delete_keyset = subparsers.add_parser( + "delete-keyset", help="Delete keyset and all associated keys" + ) + delete_keyset.add_argument("name", help="Keyset name") + + unsafe_export_keyset = subparsers.add_parser( 
+ "unsafe-export-keyset", + help="Export keyset (INCLUDING KEY MATERIALS) as JSON", + ) + unsafe_export_keyset.add_argument("name", help="Keyset name") + + export_keyset_info = subparsers.add_parser( + "export-keyset-info", + help="Export keyset info as JSON", + ) + export_keyset_info.add_argument("name", help="Keyset name") + + def handle(self, *args, **options): + if options["subcommand"] == "create-keyset": + return self.create_keyset(*args, **options) + elif options["subcommand"] == "create-key": + return self.create_key(*args, **options) + elif options["subcommand"] == "promote-key": + return self.promote_key(*args, **options) + elif options["subcommand"] == "list-keys": + return self.list_keys(*args, **options) + elif options["subcommand"] == "unsafe-export-keyset": + return self.unsafe_export_keyset(*args, **options) + elif options["subcommand"] == "export-keyset-info": + return self.export_keyset_info(*args, **options) + elif options["subcommand"] == "delete-keyset": + return self.delete_keyset(*args, **options) + else: + raise CommandError("invalid subcommand") + + def create_keyset(self, name: str, template: str, *args, **options): + keyset = Keyset.create(name, get_key_template_by_name(template)) + self.stdout.write(self.style.SUCCESS(f"Created keyset {keyset.id}")) + + def create_key(self, name: str, template: str, *args, **options): + try: + keyset = Keyset.objects.get(name=name) + except Keyset.DoesNotExist: + raise CommandError(f'Keyset "{name}" not found') + + key = keyset.generate_key(get_key_template_by_name(template)) + self.stdout.write(self.style.SUCCESS(f"Created key {key.id}")) + + def promote_key(self, name: str, id: int, *args, **options): + try: + keyset = Keyset.objects.get(name=name) + except Keyset.DoesNotExist: + raise CommandError(f'Keyset "{name}" not found') + + try: + key = keyset.key_set.get(pk=id) + except Key.DoesNotExist: + raise CommandError(f"Key ID {id} not found in keyset") + + keyset.set_primary_key(key) + 
self.stdout.write(f"Key {key.pk} promoted to primary") + + def list_keys(self, name: str, *args, **options): + try: + keyset = Keyset.objects.get(name=name) + self.stdout.write(f"Key type: {keyset.type_url}") + self.stdout.write("") + self.stdout.write("ID\tPrimary\tStatus\tPrefix") + for key in keyset.key_set.all(): + self.stdout.write( + f"{key.id}\t{'Y' if key.is_primary else 'N'}\t{tink_pb2.KeyStatusType.Name(key.status)}\t{tink_pb2.OutputPrefixType.Name(key.output_prefix_type)}" + ) + except Keyset.DoesNotExist: + raise CommandError(f'Keyset "{name}" not found') + + def delete_keyset(self, name: str, *args, **options): + try: + keyset = Keyset.objects.get(name=name) + keyset.delete() + self.stdout.write(self.style.SUCCESS(f'Deleted keyset "{name}"')) + except Keyset.DoesNotExist: + raise CommandError(f'Keyset "{name}" not found') + + def unsafe_export_keyset(self, name: str, *args, **options): + try: + keyset = Keyset.objects.get(name=name) + self.stdout.write(json_format.MessageToJson(keyset.export_keyset())) + except Keyset.DoesNotExist: + raise CommandError(f'Keyset "{name}" not found') + + def export_keyset_info(self, name: str, *args, **options): + try: + keyset = Keyset.objects.get(name=name) + self.stdout.write(json_format.MessageToJson(keyset.export_keyset_info())) + except Keyset.DoesNotExist: + raise CommandError(f'Keyset "{name}" not found') diff --git a/tink_fields/migrations/0001_initial.py b/tink_fields/migrations/0001_initial.py new file mode 100644 index 0000000..2f0b951 --- /dev/null +++ b/tink_fields/migrations/0001_initial.py @@ -0,0 +1,87 @@ +# Generated by Django 3.2.13 on 2022-06-12 09:54 + +from django.db import migrations, models +import django.db.models.deletion +import tink_fields.fields + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [] + + operations = [ + migrations.CreateModel( + name="Keyset", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + 
verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=100, unique=True)), + ("type_url", models.CharField(max_length=250)), + ], + ), + migrations.CreateModel( + name="Key", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("is_primary", models.BooleanField()), + ("key_data", tink_fields.fields.EncryptedBinaryField()), + ( + "status", + models.PositiveIntegerField( + choices=[ + ("UNKNOWN_STATUS", 0), + ("ENABLED", 1), + ("DISABLED", 2), + ("DESTROYED", 3), + ] + ), + ), + ( + "output_prefix_type", + models.PositiveIntegerField( + choices=[ + ("UNKNOWN_PREFIX", 0), + ("TINK", 1), + ("LEGACY", 2), + ("RAW", 3), + ("CRUNCHY", 4), + ] + ), + ), + ("output_prefix", models.BinaryField()), + ( + "keyset", + models.ForeignKey( + editable=False, + on_delete=django.db.models.deletion.CASCADE, + to="tink_fields.keyset", + ), + ), + ], + ), + migrations.AddConstraint( + model_name="key", + constraint=models.UniqueConstraint( + condition=models.Q(("is_primary", True)), + fields=("keyset", "is_primary"), + name="one_primary_per_keyset", + ), + ), + ] diff --git a/tink_fields/migrations/__init__.py b/tink_fields/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tink_fields/models.py b/tink_fields/models.py index 730ac18..3d3b911 100644 --- a/tink_fields/models.py +++ b/tink_fields/models.py @@ -21,12 +21,13 @@ class Keyset(models.Model): @classmethod def create(cls, name: str, key_template: tink_pb2.KeyTemplate) -> "Keyset": """Create a keyset with one primary key""" - instance = cls(name=name, type_url=key_template.type_url) - instance.save() + with transaction.atomic(): + instance = cls(name=name, type_url=key_template.type_url) + instance.save() - key = instance.generate_key(key_template) - key.is_primary = True - key.save() + key = instance.generate_key(key_template) + key.is_primary = True + key.save() return instance @@ -61,6 +62,18 @@ def 
primitive(self) -> P: def primitive_set(self) -> PrimitiveSet: return _DatabasePrimitiveKeyset(self, self.key_manager().primitive_class()) + def export_keyset(self) -> tink_pb2.Keyset: + return tink_pb2.Keyset( + primary_key_id=self.key_set.get(is_primary=True).id, + key=[key.key for key in self.key_set.all()], + ) + + def export_keyset_info(self) -> tink_pb2.KeysetInfo: + return tink_pb2.KeysetInfo( + primary_key_id=self.key_set.get(is_primary=True).id, + key_info=[key.key_info for key in self.key_set.all()], + ) + class _DatabasePrimitiveKeyset(PrimitiveSet[P]): _keyset: Keyset @@ -161,6 +174,15 @@ def key(self) -> tink_pb2.Keyset.Key: output_prefix_type=self.output_prefix_type, ) + @cached_property + def key_info(self) -> tink_pb2.KeysetInfo.KeyInfo: + return tink_pb2.KeysetInfo.KeyInfo( + type_url=self.keyset.type_url, + status=self.status, + key_id=self.id, + output_prefix_type=self.output_prefix_type, + ) + @property def key_data_pb(self) -> tink_pb2.KeyData: out = tink_pb2.KeyData() From 285a6af709894e79d7acf06e6dff48c3626ea9a1 Mon Sep 17 00:00:00 2001 From: Manatsawin Hanmongkolchai Date: Sun, 12 Jun 2022 17:30:57 +0700 Subject: [PATCH 05/10] Store KeysetConfig globally --- tink_fields/fields.py | 44 +++++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/tink_fields/fields.py b/tink_fields/fields.py index 17f4a09..23f3e77 100644 --- a/tink_fields/fields.py +++ b/tink_fields/fields.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING, Dict from django.db import models from django.core.exceptions import FieldError, ImproperlyConfigured from django.utils.functional import cached_property @@ -31,6 +31,32 @@ "DeterministicEncryptedIntegerField", ] +_config: Dict[str, KeysetConfig] = {} + + +def _get_config(keyset: str) -> KeysetConfig: + global _config + + if keyset in _config: + return _config[keyset] + + config = getattr(settings, 
"TINK_FIELDS_CONFIG", None) + if config is None: + raise ImproperlyConfigured( + f"Could not find `TINK_FIELDS_CONFIG` attribute in settings" + ) + + if keyset not in config: + raise ImproperlyConfigured( + f"Could not find configuration for keyset `{keyset}` in `TINK_FIELDS_CONFIG`" + ) + + keyset_config = KeysetConfig(**config[keyset]) + keyset_config.validate() + _config[keyset] = keyset_config + + return keyset_config + class BaseEncryptedField(models.Field): _unsupported_properties = ["primary_key", "db_index", "unique"] @@ -54,21 +80,7 @@ def __init__(self, *args, **kwargs): super(BaseEncryptedField, self).__init__(*args, **kwargs) def _get_config(self) -> KeysetConfig: - config = getattr(settings, "TINK_FIELDS_CONFIG", None) - if config is None: - raise ImproperlyConfigured( - f"Could not find `TINK_FIELDS_CONFIG` attribute in settings" - ) - - if self._keyset not in config: - raise ImproperlyConfigured( - f"Could not find configuration for keyset `{self._keyset}` in `TINK_FIELDS_CONFIG`" - ) - - keyset_config = KeysetConfig(**config[self._keyset]) - keyset_config.validate() # TODO: Reuse config - - return keyset_config + return _get_config(self._keyset) def get_internal_type(self) -> str: return self._internal_type From 43d499a703a930f01182cc2bd8d131c1e8dc57f5 Mon Sep 17 00:00:00 2001 From: Manatsawin Hanmongkolchai Date: Sun, 12 Jun 2022 20:28:58 +0700 Subject: [PATCH 06/10] Convert exact match queries into in query --- tink_fields/fields.py | 62 ++++++++++++++++++++++++--------- tink_fields/models.py | 18 ++++++++-- tink_fields/test/test_fields.py | 31 +++++++++++++---- 3 files changed, 85 insertions(+), 26 deletions(-) diff --git a/tink_fields/fields.py b/tink_fields/fields.py index 23f3e77..4cc4359 100644 --- a/tink_fields/fields.py +++ b/tink_fields/fields.py @@ -1,9 +1,9 @@ from typing import Any, Callable, TYPE_CHECKING, Dict from django.db import models +from django.db.models.lookups import In from django.core.exceptions import FieldError, 
ImproperlyConfigured from django.utils.functional import cached_property from tink import ( - KeysetHandle, aead, daead, ) @@ -160,6 +160,31 @@ def from_db_value(self, value, expression, connection, *args): ) ) + def get_db_values_all_keys( + self, value: Any, connection: "BaseDatabaseWrapper", prepared=False + ) -> Any: + """Like get_db_prep_value but return array of values encrypted with every keys in the keyset""" + val = super(DeterministicEncryptedField, self).get_db_prep_value( + value, connection, prepared + ) + if val is None: + return [] + + out = [] + aad = self._aad_callback(self) + # XXX: This would run another query. Is there any way to signal that we want a cached all using + # the same primitive set interface? + for items in self._daead_primitive._primitive_set.all(): + for key in items: + out.append( + connection.Database.Binary( + key.identifier + + key.primitive.encrypt_deterministically(force_bytes(val), aad) + ) + ) + + return out + def get_prep_lookup(self): """Raise errors for unsupported lookups""" @@ -170,29 +195,32 @@ def get_prep_lookup(self): ) -lookup_allowlist = { - (object, "isnull"), - (DeterministicEncryptedField, "exact"), # TODO: Support key rotation -} +class DeterministicEncryptedFieldExactLookup(In): + lookup_name = "exact" + def get_prep_lookup(self): + self.rhs = [self.rhs] + return super().get_prep_lookup() -def is_lookup_allowed(cls, name) -> bool: - for item in lookup_allowlist: - if item[1] == name and issubclass(cls, item[0]): - return True - return False + def get_db_prep_lookup(self, value, connection): + assert len(value) == 1 + return ( + "%s", + self.lhs.output_field.get_db_values_all_keys( + list(value)[0], connection, prepared=True + ), + ) -# Override all lookups except in lookup_allowlist to get_prep_lookup for name, lookup in models.Field.class_lookups.items(): for cls in (EncryptedField, DeterministicEncryptedField): - if is_lookup_allowed(cls, name): - continue + if name != "isnull": + lookup_class = type( + 
cls.__name__ + name, (lookup,), {"get_prep_lookup": get_prep_lookup} + ) + cls.register_lookup(lookup_class) - lookup_class = type( - cls.__name__ + "__" + name, (lookup,), {"get_prep_lookup": get_prep_lookup} - ) - cls.register_lookup(lookup_class) +DeterministicEncryptedField.register_lookup(DeterministicEncryptedFieldExactLookup) class EncryptedTextField(EncryptedField, models.TextField): diff --git a/tink_fields/models.py b/tink_fields/models.py index 3d3b911..a8cb57b 100644 --- a/tink_fields/models.py +++ b/tink_fields/models.py @@ -84,6 +84,10 @@ def __init__(self, keyset: Keyset, primitive_class: Type[P]): del self._primary def primitive_from_identifier(self, identifier: bytes) -> List[Entry]: + # Fast path - non raw keys will have unique identifiers + if len(identifier) > 0 and identifier in self._primitives: + return super().primitive_from_identifier(identifier) + for key in ( self._keyset.key_set.filter(output_prefix=identifier) .exclude(id__in=list(self._all_cached_key_ids())) @@ -93,14 +97,22 @@ def primitive_from_identifier(self, identifier: bytes) -> List[Entry]: return super().primitive_from_identifier(identifier) - def entry_by_id(self, identifier: bytes, key_id: int) -> Optional[Entry]: + def _entry_by_id(self, identifier: bytes, key_id: int) -> Optional[Entry]: + # Fast path - non raw keys will have unique identifiers + if ( + len(identifier) > 0 + and identifier in self._primitives + and len(self._primitives[identifier]) == 1 + ): + return self._primitives[identifier][0] + primitives = self._primitives.get(identifier, []) for item in primitives: if item.key_id == key_id: return item def _add_key_to_cache(self, key: "Key"): - if not self.entry_by_id(key.output_prefix, key.id): + if not self._entry_by_id(key.output_prefix, key.id): entries = self._primitives.setdefault(key.output_prefix, []) entries.append(key.entry) @@ -139,7 +151,7 @@ def set_primary(self, entry: Entry) -> None: def primary(self) -> Entry: key = 
self._keyset.key_set.get(is_primary=True) - entry = self.entry_by_id(key.output_prefix, key.id) + entry = self._entry_by_id(key.output_prefix, key.id) if entry: return entry diff --git a/tink_fields/test/test_fields.py b/tink_fields/test/test_fields.py index 5589d2e..ef61418 100644 --- a/tink_fields/test/test_fields.py +++ b/tink_fields/test/test_fields.py @@ -6,6 +6,7 @@ from tink import aead, daead from . import models +from ..models import Keyset @pytest.fixture(autouse=True) @@ -18,9 +19,6 @@ def configured_db_keyset(db): model.save() -from ..models import Keyset - - @pytest.mark.parametrize( "model,vals", [ @@ -103,6 +101,27 @@ def test_search(self, db, model, vals): assert out[0].value == vals[0] -def test_encrypted_deterministic_nullable(db): - models.DeterministicEncryptedIntNullable(value=None).save() - assert models.DeterministicEncryptedIntNullable.objects.get(value=None) +def test_rotated_deterministic_field(db): + value = 12345678 + + models.DeterministicEncryptedIntEnvelope(value=value).save() + assert ( + models.DeterministicEncryptedIntEnvelope.objects.filter(value=value).count() + == 1 + ) + + keyset = Keyset.objects.get(name="daead") + new_key = keyset.generate_key(daead.deterministic_aead_key_templates.AES256_SIV) + keyset.set_primary_key(new_key) + + value_2 = 23456789 + models.DeterministicEncryptedIntEnvelope(value=value_2).save() + + assert ( + models.DeterministicEncryptedIntEnvelope.objects.filter(value=value).count() + == 1 + ) + assert ( + models.DeterministicEncryptedIntEnvelope.objects.filter(value=value_2).count() + == 1 + ) From e980cb332cfcf1246d5ab58ea5ac15c9d9cbe4a6 Mon Sep 17 00:00:00 2001 From: Manatsawin Hanmongkolchai Date: Sun, 12 Jun 2022 21:04:01 +0700 Subject: [PATCH 07/10] Rename create-key to add-key --- tink_fields/management/commands/tink.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tink_fields/management/commands/tink.py b/tink_fields/management/commands/tink.py index 
adc3b05..613ab58 100644 --- a/tink_fields/management/commands/tink.py +++ b/tink_fields/management/commands/tink.py @@ -28,11 +28,11 @@ def add_arguments(self, parser: CommandParser): help="Key template (see tinkey list-key-templates)", ) - create_key = subparsers.add_parser( - "create-key", help="Create a non-primary key in a keyset" + add_key = subparsers.add_parser( + "add-key", help="Add a non-primary key to a keyset" ) - create_key.add_argument("name", help="Keyset name") - create_key.add_argument( + add_key.add_argument("name", help="Keyset name") + add_key.add_argument( "template", help="Key template (see tinkey list-key-templates)" ) @@ -65,8 +65,8 @@ def add_arguments(self, parser: CommandParser): def handle(self, *args, **options): if options["subcommand"] == "create-keyset": return self.create_keyset(*args, **options) - elif options["subcommand"] == "create-key": - return self.create_key(*args, **options) + elif options["subcommand"] == "add-key": + return self.add_key(*args, **options) elif options["subcommand"] == "promote-key": return self.promote_key(*args, **options) elif options["subcommand"] == "list-keys": @@ -84,7 +84,7 @@ def create_keyset(self, name: str, template: str, *args, **options): keyset = Keyset.create(name, get_key_template_by_name(template)) self.stdout.write(self.style.SUCCESS(f"Created keyset {keyset.id}")) - def create_key(self, name: str, template: str, *args, **options): + def add_key(self, name: str, template: str, *args, **options): try: keyset = Keyset.objects.get(name=name) except Keyset.DoesNotExist: From 7705317b2e1beab03470a3b311fb7cfebb43a62d Mon Sep 17 00:00:00 2001 From: Manatsawin Hanmongkolchai Date: Sun, 12 Jun 2022 21:09:14 +0700 Subject: [PATCH 08/10] Rename list-keys to list-keyset and implement compatible output to tink list-keyset. 
Also change db keyset name to db_keyset --- tink_fields/management/commands/tink.py | 28 +++++++------------------ tink_fields/models.py | 2 +- tink_fields/test/settings/sqlite.py | 4 ++++ 3 files changed, 12 insertions(+), 22 deletions(-) diff --git a/tink_fields/management/commands/tink.py b/tink_fields/management/commands/tink.py index 613ab58..ca77d63 100644 --- a/tink_fields/management/commands/tink.py +++ b/tink_fields/management/commands/tink.py @@ -1,6 +1,6 @@ from django.core.management.base import BaseCommand, CommandError, CommandParser from tink.proto import tink_pb2 -from google.protobuf import json_format +from google.protobuf import json_format, text_format from tink_fields.models import Keyset, Key from tink.aead import aead_key_templates @@ -42,8 +42,8 @@ def add_arguments(self, parser: CommandParser): promote_key.add_argument("name", help="Keyset name") promote_key.add_argument("id", help="Key ID", type=int) - list_keys = subparsers.add_parser("list-keys", help="List keys a keyset") - list_keys.add_argument("name", help="Keyset name") + list_keyset = subparsers.add_parser("list-keyset", help="List keys in a keyset") + list_keyset.add_argument("name", help="Keyset name") delete_keyset = subparsers.add_parser( "delete-keyset", help="Delete keyset and all associated keys" @@ -56,12 +56,6 @@ def add_arguments(self, parser: CommandParser): ) unsafe_export_keyset.add_argument("name", help="Keyset name") - export_keyset_info = subparsers.add_parser( - "export-keyset-info", - help="Export keyset info as JSON", - ) - export_keyset_info.add_argument("name", help="Keyset name") - def handle(self, *args, **options): if options["subcommand"] == "create-keyset": return self.create_keyset(*args, **options) @@ -69,12 +63,10 @@ def handle(self, *args, **options): return self.add_key(*args, **options) elif options["subcommand"] == "promote-key": return self.promote_key(*args, **options) - elif options["subcommand"] == "list-keys": - return self.list_keys(*args, 
**options) + elif options["subcommand"] == "list-keyset": + return self.list_keyset(*args, **options) elif options["subcommand"] == "unsafe-export-keyset": return self.unsafe_export_keyset(*args, **options) - elif options["subcommand"] == "export-keyset-info": - return self.export_keyset_info(*args, **options) elif options["subcommand"] == "delete-keyset": return self.delete_keyset(*args, **options) else: @@ -107,16 +99,10 @@ def promote_key(self, name: str, id: int, *args, **options): keyset.set_primary_key(key) self.stdout.write(f"Key {key.pk} promoted to primary") - def list_keys(self, name: str, *args, **options): + def list_keyset(self, name: str, *args, **options): try: keyset = Keyset.objects.get(name=name) - self.stdout.write(f"Key type: {keyset.type_url}") - self.stdout.write("") - self.stdout.write("ID\tPrimary\tStatus\tPrefix") - for key in keyset.key_set.all(): - self.stdout.write( - f"{key.id}\t{'Y' if key.is_primary else 'N'}\t{tink_pb2.KeyStatusType.Name(key.status)}\t{tink_pb2.OutputPrefixType.Name(key.output_prefix_type)}" - ) + self.stdout.write(text_format.MessageToString(keyset.export_keyset_info())) except Keyset.DoesNotExist: raise CommandError(f'Keyset "{name}" not found') diff --git a/tink_fields/models.py b/tink_fields/models.py index a8cb57b..d49f15f 100644 --- a/tink_fields/models.py +++ b/tink_fields/models.py @@ -168,7 +168,7 @@ class Key(models.Model): is_primary = models.BooleanField() # Serialized KeyData - key_data = EncryptedBinaryField(editable=False) + key_data = EncryptedBinaryField(keyset="db_keyset", editable=False) status = models.PositiveIntegerField(choices=tink_pb2.KeyStatusType.items()) output_prefix_type = models.PositiveIntegerField( choices=tink_pb2.OutputPrefixType.items() diff --git a/tink_fields/test/settings/sqlite.py b/tink_fields/test/settings/sqlite.py index 111ae60..956bd27 100644 --- a/tink_fields/test/settings/sqlite.py +++ b/tink_fields/test/settings/sqlite.py @@ -30,6 +30,10 @@ "cleartext": True, "path": 
os.path.join(HERE, "../test_plaintext_daead_keyset.json"), }, + "db_keyset": { + "cleartext": True, + "path": os.path.join(HERE, "../test_plaintext_keyset.json"), + }, "db_aead": { "db_name": "aead", }, From 876cf780403d77a9c1ddf1c34605bf12f3f4aec4 Mon Sep 17 00:00:00 2001 From: Manatsawin Hanmongkolchai Date: Sun, 12 Jun 2022 21:18:17 +0700 Subject: [PATCH 09/10] Refactor handle() --- tink_fields/management/commands/tink.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/tink_fields/management/commands/tink.py b/tink_fields/management/commands/tink.py index ca77d63..839de0e 100644 --- a/tink_fields/management/commands/tink.py +++ b/tink_fields/management/commands/tink.py @@ -57,20 +57,16 @@ def add_arguments(self, parser: CommandParser): unsafe_export_keyset.add_argument("name", help="Keyset name") def handle(self, *args, **options): - if options["subcommand"] == "create-keyset": - return self.create_keyset(*args, **options) - elif options["subcommand"] == "add-key": - return self.add_key(*args, **options) - elif options["subcommand"] == "promote-key": - return self.promote_key(*args, **options) - elif options["subcommand"] == "list-keyset": - return self.list_keyset(*args, **options) - elif options["subcommand"] == "unsafe-export-keyset": - return self.unsafe_export_keyset(*args, **options) - elif options["subcommand"] == "delete-keyset": - return self.delete_keyset(*args, **options) - else: - raise CommandError("invalid subcommand") + subcommand_map = { + "create-keyset": self.create_keyset, + "add-key": self.add_key, + "promote-key": self.promote_key, + "list-keyset": self.list_keyset, + "unsafe-export-keyset": self.unsafe_export_keyset, + "delete-keyset": self.delete_keyset, + } + + return subcommand_map[options["subcommand"]](*args, **options) def create_keyset(self, name: str, template: str, *args, **options): keyset = Keyset.create(name, get_key_template_by_name(template)) From
8f64bdf0ee1666ff9b360fc6c51513ddef526d79 Mon Sep 17 00:00:00 2001 From: Manatsawin Hanmongkolchai Date: Fri, 8 Jul 2022 23:33:29 +0700 Subject: [PATCH 10/10] Move type_url to individual key instead of keyset. Update primary key to be 1-1 relationship --- tink_fields/config.py | 2 +- tink_fields/migrations/0001_initial.py | 55 ++++++++++---------- tink_fields/models.py | 72 +++++++++++--------------- tink_fields/test/test_models.py | 49 +++++++++++++----- 4 files changed, 94 insertions(+), 84 deletions(-) diff --git a/tink_fields/config.py b/tink_fields/config.py index 3306a97..caab5f6 100644 --- a/tink_fields/config.py +++ b/tink_fields/config.py @@ -44,7 +44,7 @@ def primitive(self, cls: Type[P]) -> P: if self.path: return self._load_from_path.primitive(cls) if self.db_name: - return self._load_from_db.primitive + return self._load_from_db.primitive(cls) @cached_property def _load_from_path(self) -> KeysetHandle: diff --git a/tink_fields/migrations/0001_initial.py b/tink_fields/migrations/0001_initial.py index 2f0b951..fb88991 100644 --- a/tink_fields/migrations/0001_initial.py +++ b/tink_fields/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 3.2.13 on 2022-06-12 09:54 +# Generated by Django 4.1a1 on 2022-07-08 11:06 from django.db import migrations, models import django.db.models.deletion @@ -12,22 +12,6 @@ class Migration(migrations.Migration): dependencies = [] operations = [ - migrations.CreateModel( - name="Keyset", - fields=[ - ( - "id", - models.AutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), - ("name", models.CharField(max_length=100, unique=True)), - ("type_url", models.CharField(max_length=250)), - ], - ), migrations.CreateModel( name="Key", fields=[ @@ -40,7 +24,7 @@ class Migration(migrations.Migration): verbose_name="ID", ), ), - ("is_primary", models.BooleanField()), + ("type_url", models.CharField(max_length=250)), ("key_data", tink_fields.fields.EncryptedBinaryField()), ( 
"status", @@ -66,22 +50,39 @@ class Migration(migrations.Migration): ), ), ("output_prefix", models.BinaryField()), + ], + ), + migrations.CreateModel( + name="Keyset", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=100, unique=True)), ( - "keyset", + "primary_key", models.ForeignKey( - editable=False, - on_delete=django.db.models.deletion.CASCADE, - to="tink_fields.keyset", + null=True, + on_delete=django.db.models.deletion.PROTECT, + related_name="primary_key", + to="tink_fields.key", ), ), ], ), - migrations.AddConstraint( + migrations.AddField( model_name="key", - constraint=models.UniqueConstraint( - condition=models.Q(("is_primary", True)), - fields=("keyset", "is_primary"), - name="one_primary_per_keyset", + name="keyset", + field=models.ForeignKey( + editable=False, + on_delete=django.db.models.deletion.CASCADE, + to="tink_fields.keyset", ), ), ] diff --git a/tink_fields/models.py b/tink_fields/models.py index d49f15f..c99ed70 100644 --- a/tink_fields/models.py +++ b/tink_fields/models.py @@ -16,24 +16,23 @@ class Keyset(models.Model): name = models.CharField(max_length=100, unique=True) - type_url = models.CharField(max_length=250) + primary_key = models.ForeignKey( + "Key", on_delete=models.PROTECT, related_name="primary_key", null=True + ) @classmethod def create(cls, name: str, key_template: tink_pb2.KeyTemplate) -> "Keyset": """Create a keyset with one primary key""" with transaction.atomic(): - instance = cls(name=name, type_url=key_template.type_url) + instance = cls(name=name) instance.save() - - key = instance.generate_key(key_template) - key.is_primary = True - key.save() + instance.primary_key = instance.generate_key(key_template) return instance def generate_key(self, key_template: tink_pb2.KeyTemplate) -> "Key": """Create and save a key""" - key_data = self.key_manager().new_key_data(key_template) + key_data = 
Registry.new_key_data(key_template) key = Key.create_from_keydata(self, key_data, key_template.output_prefix_type) return key @@ -43,34 +42,32 @@ def set_primary_key(self, key: Union["Key", int]): if isinstance(key, Key): key_id = key.pk - with transaction.atomic(): - self.key_set.update(is_primary=False) - self.key_set.filter(pk=key_id).update(is_primary=True) - - def key_manager(self) -> "KeyManager": - return Registry.key_manager(self.type_url) + self.primary_key_id = key_id + self.save(update_fields=["primary_key"]) - @property - def primitive(self) -> P: + def primitive(self, primitive_class: Type[P]) -> P: """Get primitive of the stored type""" + # Hack: assuming that a key would only contain one primitive_class + self._primitive_set._primitive_class = primitive_class return Registry.wrap( - self.primitive_set, - self.key_manager().primitive_class(), + self._primitive_set, + primitive_class, ) @cached_property - def primitive_set(self) -> PrimitiveSet: - return _DatabasePrimitiveKeyset(self, self.key_manager().primitive_class()) + def _primitive_set(self) -> PrimitiveSet: + # XXX: This would return keyset of the concrete type instead of interface type + return _DatabasePrimitiveKeyset(self, self.primary_key.primitive) def export_keyset(self) -> tink_pb2.Keyset: return tink_pb2.Keyset( - primary_key_id=self.key_set.get(is_primary=True).id, + primary_key_id=self.primary_key_id, key=[key.key for key in self.key_set.all()], ) def export_keyset_info(self) -> tink_pb2.KeysetInfo: return tink_pb2.KeysetInfo( - primary_key_id=self.key_set.get(is_primary=True).id, + primary_key_id=self.primary_key_id, key_info=[key.key_info for key in self.key_set.all()], ) @@ -84,7 +81,7 @@ def __init__(self, keyset: Keyset, primitive_class: Type[P]): del self._primary def primitive_from_identifier(self, identifier: bytes) -> List[Entry]: - # Fast path - non raw keys will have unique identifiers + # Fast path - non raw keys should have unique identifiers if len(identifier) > 0 and 
identifier in self._primitives: return super().primitive_from_identifier(identifier) @@ -98,7 +95,7 @@ def primitive_from_identifier(self, identifier: bytes) -> List[Entry]: return super().primitive_from_identifier(identifier) def _entry_by_id(self, identifier: bytes, key_id: int) -> Optional[Entry]: - # Fast path - non raw keys will have unique identifiers + # Fast path - non raw keys should have unique identifiers if ( len(identifier) > 0 and identifier in self._primitives @@ -150,7 +147,7 @@ def set_primary(self, entry: Entry) -> None: self._keyset.set_primary_key(entry.key_id) def primary(self) -> Entry: - key = self._keyset.key_set.get(is_primary=True) + key = self._keyset.primary_key entry = self._entry_by_id(key.output_prefix, key.id) if entry: return entry @@ -162,10 +159,10 @@ def primary(self) -> Entry: class Key(models.Model): """Key instance in a keyset. - It is expected that Key is immutable except for is_primary, status field""" + It is expected that Key is immutable except for the status field""" keyset = models.ForeignKey(Keyset, on_delete=models.CASCADE, editable=False) - is_primary = models.BooleanField() + type_url = models.CharField(max_length=250) # Serialized KeyData key_data = EncryptedBinaryField(keyset="db_keyset", editable=False) @@ -189,12 +186,15 @@ def key(self) -> tink_pb2.Keyset.Key: @cached_property def key_info(self) -> tink_pb2.KeysetInfo.KeyInfo: return tink_pb2.KeysetInfo.KeyInfo( - type_url=self.keyset.type_url, + type_url=self.type_url, status=self.status, key_id=self.id, output_prefix_type=self.output_prefix_type, ) + def key_manager(self) -> "KeyManager": + return Registry.key_manager(self.type_url) + @property def key_data_pb(self) -> tink_pb2.KeyData: out = tink_pb2.KeyData() @@ -203,10 +203,7 @@ def key_data_pb(self) -> tink_pb2.KeyData: @property def primitive(self) -> P: - return Registry.primitive( - self.key_data_pb, - self.keyset.key_manager().primitive_class(), - ) + return 
self.key_manager().primitive(self.key_data_pb) @cached_property def entry(self) -> Entry: @@ -222,8 +219,8 @@ def entry(self) -> Entry: def from_key(cls, keyset: "Keyset", key: tink_pb2.Keyset.Key): return cls( id=key.key_id, - is_primary=False, keyset=keyset, + type_url=key.key_data.type_url, data=key.key_data.SerializeToString(), status=key.status, output_prefix=crypto_format.output_prefix(key), @@ -240,7 +237,7 @@ def create_from_keydata( with transaction.atomic(): out = cls( keyset=keyset, - is_primary=False, + type_url=keydata.type_url, key_data=keydata.SerializeToString(), status=tink_pb2.ENABLED, output_prefix_type=output_prefix_type, @@ -251,12 +248,3 @@ def create_from_keydata( out.save() return out - - class Meta: - constraints = [ - models.UniqueConstraint( - name="one_primary_per_keyset", - fields=("keyset", "is_primary"), - condition=models.Q(is_primary=True), - ), - ] diff --git a/tink_fields/test/test_models.py b/tink_fields/test/test_models.py index aa051e3..3a4b088 100644 --- a/tink_fields/test/test_models.py +++ b/tink_fields/test/test_models.py @@ -8,46 +8,67 @@ def test_aead_encrypt(db): key_template = aead.aead_key_templates.AES256_GCM ks = Keyset.create("aead", key_template) + primitive = ks.primitive(aead.Aead) - ciphertext = ks.primitive.encrypt(TEST_DATA, ASSOC_DATA) + ciphertext = primitive.encrypt(TEST_DATA, ASSOC_DATA) - assert ks.primitive.decrypt(ciphertext, ASSOC_DATA) == TEST_DATA + assert primitive.decrypt(ciphertext, ASSOC_DATA) == TEST_DATA def test_aead_decrypt_non_primary(db): key_template = aead.aead_key_templates.AES256_GCM ks = Keyset.create("aead", key_template) - ciphertext = ks.primitive.encrypt(TEST_DATA, ASSOC_DATA) + ciphertext = ks.primitive(aead.Aead).encrypt(TEST_DATA, ASSOC_DATA) secondary_key = ks.generate_key(key_template) ks.set_primary_key(secondary_key) - secondary_key.refresh_from_db() - assert secondary_key.is_primary + ks.refresh_from_db() + assert ks.primary_key == secondary_key - assert 
ks.primitive.decrypt(ciphertext, ASSOC_DATA) == TEST_DATA + assert ks.primitive(aead.Aead).decrypt(ciphertext, ASSOC_DATA) == TEST_DATA def test_daead_encrypt(db): key_template = daead.deterministic_aead_key_templates.AES256_SIV ks = Keyset.create("daead", key_template) + primitive = ks.primitive(daead.DeterministicAead) - ciphertext = ks.primitive.encrypt_deterministically(TEST_DATA, ASSOC_DATA) + ciphertext = primitive.encrypt_deterministically(TEST_DATA, ASSOC_DATA) - assert ks.primitive.encrypt_deterministically(TEST_DATA, ASSOC_DATA) == ciphertext - assert ks.primitive.decrypt_deterministically(ciphertext, ASSOC_DATA) == TEST_DATA + assert primitive.encrypt_deterministically(TEST_DATA, ASSOC_DATA) == ciphertext + assert primitive.decrypt_deterministically(ciphertext, ASSOC_DATA) == TEST_DATA def test_aead_decrypt_non_primary(db): key_template = daead.deterministic_aead_key_templates.AES256_SIV ks = Keyset.create("daead", key_template) - ciphertext = ks.primitive.encrypt_deterministically(TEST_DATA, ASSOC_DATA) + primitive = ks.primitive(daead.DeterministicAead) + ciphertext = primitive.encrypt_deterministically(TEST_DATA, ASSOC_DATA) secondary_key = ks.generate_key(key_template) ks.set_primary_key(secondary_key) - secondary_key.refresh_from_db() - assert secondary_key.is_primary + ks.refresh_from_db() + assert ks.primary_key == secondary_key - assert ks.primitive.encrypt_deterministically(TEST_DATA, ASSOC_DATA) != ciphertext - assert ks.primitive.decrypt_deterministically(ciphertext, ASSOC_DATA) == TEST_DATA + primitive = ks.primitive(daead.DeterministicAead) + assert primitive.encrypt_deterministically(TEST_DATA, ASSOC_DATA) != ciphertext + assert primitive.decrypt_deterministically(ciphertext, ASSOC_DATA) == TEST_DATA + + +def test_multiple_key_templates(db): + ks = Keyset.create("aead", aead.aead_key_templates.AES256_GCM) + primitive = ks.primitive(aead.Aead) + ciphertext = primitive.encrypt(TEST_DATA, ASSOC_DATA) + + key_2 = 
ks.generate_key(aead.aead_key_templates.AES128_CTR_HMAC_SHA256) + ks.set_primary_key(key_2) + ciphertext2 = primitive.encrypt(TEST_DATA, ASSOC_DATA) + + assert len(ciphertext) != len(ciphertext2) + + ks = Keyset.objects.get(name="aead") + primitive = ks.primitive(aead.Aead) + assert primitive.decrypt(ciphertext, ASSOC_DATA) == TEST_DATA + assert primitive.decrypt(ciphertext2, ASSOC_DATA) == TEST_DATA