diff --git a/README.md b/README.md index d393344..802ceea 100644 --- a/README.md +++ b/README.md @@ -89,6 +89,12 @@ class AnotherModel(models.Model): Supported field types include: `EncryptedCharField`, `EncryptedTextField`, `EncryptedDateField`, `EncryptedDateTimeField`, `EncryptedEmailField`, and `EncryptedIntegerField`. +## Deterministic Encryption + +`DeterministicEncryptedCharField` provides support for [Deterministic AEAD](https://developers.google.com/tink/deterministic-aead), which means values in the field can be queried with exact matches. However, unlike normal AEAD encryption, an attacker can verify that two messages are equal. + +Deterministic encryption requires a key of type `AES-SIV` and supports Associated Data. + ### Associated Data The encrypted fields make use of `Authenticated Encryption With Associated Data (AEAD)` which offers confidentiality and integrity within the same mode of operation. This allows the caller to specify a cleartext fragment named `additional authenticated data (aad)` to the encryption and decryption operations and receive cryptographic guarantees that the ciphertext data has not been tampered with. 
diff --git a/tink_fields/__init__.py b/tink_fields/__init__.py index f4fbd52..9b9dd43 100644 --- a/tink_fields/__init__.py +++ b/tink_fields/__init__.py @@ -1,4 +1,5 @@ from .fields import * # noqa -from tink import aead +from tink import aead, daead aead.register() +daead.register() diff --git a/tink_fields/config.py b/tink_fields/config.py new file mode 100644 index 0000000..caab5f6 --- /dev/null +++ b/tink_fields/config.py @@ -0,0 +1,62 @@ +from dataclasses import dataclass +from os.path import exists +from typing import Optional, TypeVar, Type, TYPE_CHECKING +from django.utils.functional import cached_property + +from django.core.exceptions import ImproperlyConfigured +from tink import ( + aead, + KeysetHandle, + JsonKeysetReader, + cleartext_keyset_handle, + read_keyset_handle, +) + +if TYPE_CHECKING: + from .models import Keyset + +P = TypeVar("P") + + +@dataclass +class KeysetConfig: + path: Optional[str] = None + db_name: Optional[str] = None + master_key_aead: Optional[aead.Aead] = None + cleartext: bool = False + + def validate(self): + if not self.path and not self.db_name: + raise ImproperlyConfigured("Keyset path or db_name must be set") + if self.db_name and self.path: + raise ImproperlyConfigured("Only one of keyset path or db_name must be set") + + if self.path: + if not exists(self.path): + raise ImproperlyConfigured(f"Keyset {self.path} does not exist") + + if not self.cleartext and self.master_key_aead is None: + raise ImproperlyConfigured( + f"Encrypted keysets must specify `master_key_aead`" + ) + + def primitive(self, cls: Type[P]) -> P: + if self.path: + return self._load_from_path.primitive(cls) + if self.db_name: + return self._load_from_db.primitive(cls) + + @cached_property + def _load_from_path(self) -> KeysetHandle: + with open(self.path, "r") as f: + reader = JsonKeysetReader(f.read()) + if self.cleartext: + return cleartext_keyset_handle.read(reader) + return read_keyset_handle(reader, self.master_key_aead) + + @cached_property + def 
_load_from_db(self) -> "Keyset": + from .models import Keyset + + keyset = Keyset.objects.get(name=self.db_name) + return keyset diff --git a/tink_fields/fields.py b/tink_fields/fields.py index 01a181b..4cc4359 100644 --- a/tink_fields/fields.py +++ b/tink_fields/fields.py @@ -1,19 +1,19 @@ -from functools import lru_cache -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, TYPE_CHECKING, Dict from django.db import models +from django.db.models.lookups import In from django.core.exceptions import FieldError, ImproperlyConfigured +from django.utils.functional import cached_property from tink import ( - KeysetHandle, - cleartext_keyset_handle, - read_keyset_handle, - JsonKeysetReader, aead, + daead, ) from django.conf import settings -from dataclasses import dataclass -from os.path import exists from django.utils.encoding import force_bytes, force_str -from django.db.backends.base.base import BaseDatabaseWrapper + +from tink_fields.config import KeysetConfig + +if TYPE_CHECKING: + from django.db.backends.base.base import BaseDatabaseWrapper __all__ = [ @@ -24,34 +24,46 @@ "EncryptedIntegerField", "EncryptedDateField", "EncryptedDateTimeField", + "EncryptedBinaryField", + "DeterministicEncryptedField", + "DeterministicEncryptedCharField", + "DeterministicEncryptedEmailField", + "DeterministicEncryptedIntegerField", ] +_config: Dict[str, KeysetConfig] = {} -@dataclass -class KeysetConfig: - path: str - master_key_aead: Optional[aead.Aead] = None - cleartext: bool = False - def validate(self): - if not self.path: - raise ImproperlyConfigured("Keyset path cannot be None or empty") +def _get_config(keyset: str) -> KeysetConfig: + global _config - if not exists(self.path): - raise ImproperlyConfigured(f"Keyset {self.path} does not exist") - - if not self.cleartext and self.master_key_aead is None: - raise ImproperlyConfigured(f"Encrypted keysets must specify `master_key_aead`") - + if keyset in _config: + return _config[keyset] -class 
EncryptedField(models.Field): - """A field that uses Tink primitives to protect the confidentiality and integrity of data""" + config = getattr(settings, "TINK_FIELDS_CONFIG", None) + if config is None: + raise ImproperlyConfigured( + f"Could not find `TINK_FIELDS_CONFIG` attribute in settings" + ) + + if keyset not in config: + raise ImproperlyConfigured( + f"Could not find configuration for keyset `{keyset}` in `TINK_FIELDS_CONFIG`" + ) + + keyset_config = KeysetConfig(**config[keyset]) + keyset_config.validate() + _config[keyset] = keyset_config + + return keyset_config + +class BaseEncryptedField(models.Field): _unsupported_properties = ["primary_key", "db_index", "unique"] _internal_type = "BinaryField" _keyset: str - _keyset_handle: KeysetHandle + _keyset_config: KeysetConfig _aad_callback: Callable[[models.Field], bytes] def __init__(self, *args, **kwargs): @@ -62,49 +74,78 @@ def __init__(self, *args, **kwargs): ) self._keyset = kwargs.pop("keyset", "default") - self._keyset_handle = self._get_tink_keyset_handle() + self._keyset_config = self._get_config() self._aad_callback = kwargs.pop("aad_callback", lambda x: b"") - super(EncryptedField, self).__init__(*args, **kwargs) + super(BaseEncryptedField, self).__init__(*args, **kwargs) - def _get_config(self) -> Dict[str, Any]: - config = getattr(settings, "TINK_FIELDS_CONFIG", None) - if config is None: - raise ImproperlyConfigured( - f"Could not find `TINK_FIELDS_CONFIG` attribute in settings" - ) - return config + def _get_config(self) -> KeysetConfig: + return _get_config(self._keyset) - def _get_tink_keyset_handle(self) -> KeysetHandle: - """Read the configuration for the requested keyset and return a respective keyset handle""" - config = self._get_config() + def get_internal_type(self) -> str: + return self._internal_type - if self._keyset not in config: - raise ImproperlyConfigured( - f"Could not find configuration for keyset `{self._keyset}` in `TINK_FIELDS_CONFIG`" - ) + @cached_property + def 
validators(self): + # Temporarily pretend to be whatever type of field we're masquerading + # as, for purposes of constructing validators (needed for + # IntegerField and subclasses). + self.__dict__["_internal_type"] = super( + BaseEncryptedField, self + ).get_internal_type() + try: + return super(BaseEncryptedField, self).validators + finally: + del self.__dict__["_internal_type"] - keyset_config = KeysetConfig(**config[self._keyset]) - keyset_config.validate() + def to_python_prepare(self, value: bytes) -> Any: + if isinstance(self, models.BinaryField): + return value - with open(keyset_config.path, "r") as f: - reader = JsonKeysetReader(f.read()) - if keyset_config.cleartext: - return cleartext_keyset_handle.read(reader) - return read_keyset_handle(reader, keyset_config.master_key_aead) + return force_str(value) - @lru_cache(maxsize=None) - def _get_aead_primitive(self) -> aead.Aead: - return self._keyset_handle.primitive(aead.Aead) - def get_internal_type(self) -> str: - return self._internal_type +class EncryptedField(BaseEncryptedField): + """A field that uses Tink primitives to protect the confidentiality and integrity of data""" + + @cached_property + def _aead_primitive(self) -> aead.Aead: + return self._keyset_config.primitive(aead.Aead) - def get_db_prep_save(self, value: Any, connection: BaseDatabaseWrapper) -> Any: + def get_db_prep_save(self, value: Any, connection: "BaseDatabaseWrapper") -> Any: val = super(EncryptedField, self).get_db_prep_save(value, connection) if val is not None: return connection.Database.Binary( - self._get_aead_primitive().encrypt( + self._aead_primitive.encrypt(force_bytes(val), self._aad_callback(self)) + ) + + def from_db_value(self, value, expression, connection, *args): + if value is not None: + return self.to_python( + self.to_python_prepare( + self._aead_primitive.decrypt(bytes(value), self._aad_callback(self)) + ) + ) + + +class DeterministicEncryptedField(BaseEncryptedField): + """Field that is similar to 
EncryptedField, but support exact match lookups""" + + _unsupported_properties = [] + + @cached_property + def _daead_primitive(self) -> daead.DeterministicAead: + return self._keyset_config.primitive(daead.DeterministicAead) + + def get_db_prep_value( + self, value: Any, connection: "BaseDatabaseWrapper", prepared=False + ) -> Any: + val = super(DeterministicEncryptedField, self).get_db_prep_value( + value, connection, prepared + ) + if val is not None: + return connection.Database.Binary( + self._daead_primitive.encrypt_deterministically( force_bytes(val), self._aad_callback(self) ) ) @@ -112,26 +153,37 @@ def get_db_prep_save(self, value: Any, connection: BaseDatabaseWrapper) -> Any: def from_db_value(self, value, expression, connection, *args): if value is not None: return self.to_python( - force_str( - self._get_aead_primitive().decrypt( + self.to_python_prepare( + self._daead_primitive.decrypt_deterministically( bytes(value), self._aad_callback(self) ) ) ) - @property - @lru_cache(maxsize=None) - def validators(self): - # Temporarily pretend to be whatever type of field we're masquerading - # as, for purposes of constructing validators (needed for - # IntegerField and subclasses). - self.__dict__["_internal_type"] = super( - EncryptedField, self - ).get_internal_type() - try: - return super(EncryptedField, self).validators - finally: - del self.__dict__["_internal_type"] + def get_db_values_all_keys( + self, value: Any, connection: "BaseDatabaseWrapper", prepared=False + ) -> Any: + """Like get_db_prep_value but return array of values encrypted with every keys in the keyset""" + val = super(DeterministicEncryptedField, self).get_db_prep_value( + value, connection, prepared + ) + if val is None: + return [] + + out = [] + aad = self._aad_callback(self) + # XXX: This would run another query. Is there any way to signal that we want a cached all using + # the same primitive set interface? 
+ for items in self._daead_primitive._primitive_set.all(): + for key in items: + out.append( + connection.Database.Binary( + key.identifier + + key.primitive.encrypt_deterministically(force_bytes(val), aad) + ) + ) + + return out def get_prep_lookup(self): @@ -143,12 +195,32 @@ def get_prep_lookup(self): ) -for name, lookup in models.Field.class_lookups.items(): - if name != "isnull": - lookup_class = type( - "EncryptedField" + name, (lookup,), {"get_prep_lookup": get_prep_lookup} +class DeterministicEncryptedFieldExactLookup(In): + lookup_name = "exact" + + def get_prep_lookup(self): + self.rhs = [self.rhs] + return super().get_prep_lookup() + + def get_db_prep_lookup(self, value, connection): + assert len(value) == 1 + return ( + "%s", + self.lhs.output_field.get_db_values_all_keys( + list(value)[0], connection, prepared=True + ), ) - EncryptedField.register_lookup(lookup_class) + + +for name, lookup in models.Field.class_lookups.items(): + for cls in (EncryptedField, DeterministicEncryptedField): + if name != "isnull": + lookup_class = type( + cls.__name__ + name, (lookup,), {"get_prep_lookup": get_prep_lookup} + ) + cls.register_lookup(lookup_class) + +DeterministicEncryptedField.register_lookup(DeterministicEncryptedFieldExactLookup) class EncryptedTextField(EncryptedField, models.TextField): @@ -173,3 +245,21 @@ class EncryptedDateField(EncryptedField, models.DateField): class EncryptedDateTimeField(EncryptedField, models.DateTimeField): pass + + +class EncryptedBinaryField(EncryptedField, models.BinaryField): + """Encrypted raw binary data, must be under 2^32 bytes (4.295GB)""" + + +class DeterministicEncryptedCharField(DeterministicEncryptedField, models.CharField): + pass + + +class DeterministicEncryptedEmailField(DeterministicEncryptedField, models.EmailField): + pass + + +class DeterministicEncryptedIntegerField( + DeterministicEncryptedField, models.IntegerField +): + pass diff --git a/tink_fields/management/__init__.py 
b/tink_fields/management/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tink_fields/management/commands/__init__.py b/tink_fields/management/commands/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tink_fields/management/commands/tink.py b/tink_fields/management/commands/tink.py new file mode 100644 index 0000000..839de0e --- /dev/null +++ b/tink_fields/management/commands/tink.py @@ -0,0 +1,125 @@ +from django.core.management.base import BaseCommand, CommandError, CommandParser +from tink.proto import tink_pb2 +from google.protobuf import json_format, text_format + +from tink_fields.models import Keyset, Key +from tink.aead import aead_key_templates +from tink.daead import deterministic_aead_key_templates + + +def get_key_template_by_name(name: str) -> tink_pb2.KeyTemplate: + template = getattr(aead_key_templates, name, None) + if template: + return template + + return getattr(deterministic_aead_key_templates, name) + + +class Command(BaseCommand): + help = "Tink key management" + + def add_arguments(self, parser: CommandParser): + subparsers = parser.add_subparsers(dest="subcommand", required=True) + + create_keyset = subparsers.add_parser("create-keyset", help="Create new keyset") + create_keyset.add_argument("name", help="Key name") + create_keyset.add_argument( + "template", + help="Key template (see tinkey list-key-templates)", + ) + + add_key = subparsers.add_parser( + "add-key", help="Add a non-primary key to a keyset" + ) + add_key.add_argument("name", help="Keyset name") + add_key.add_argument( + "template", help="Key template (see tinkey list-key-templates)" + ) + + promote_key = subparsers.add_parser( + "promote-key", help="Promote key to primary in a keyset" + ) + promote_key.add_argument("name", help="Keyset name") + promote_key.add_argument("id", help="Key ID", type=int) + + list_keyset = subparsers.add_parser("list-keyset", help="List keys a keyset") + list_keyset.add_argument("name", help="Keyset name") + 
+ delete_keyset = subparsers.add_parser( + "delete-keyset", help="Delete keyset and all associated keys" + ) + delete_keyset.add_argument("name", help="Keyset name") + + unsafe_export_keyset = subparsers.add_parser( + "unsafe-export-keyset", + help="Export keyset (INCLUDING KEY MATERIALS) as JSON", + ) + unsafe_export_keyset.add_argument("name", help="Keyset name") + + def handle(self, *args, **options): + subcommand_map = { + "create-keyset": self.create_keyset, + "add-key": self.add_key, + "promote-key": self.promote_key, + "list-keyset": self.list_keyset, + "unsafe-export-keyset": self.unsafe_export_keyset, + "delete-keyset": self.delete_keyset, + } + + return subcommand_map[options["subcommand"]](*args, **options) + + def create_keyset(self, name: str, template: str, *args, **options): + keyset = Keyset.create(name, get_key_template_by_name(template)) + self.stdout.write(self.style.SUCCESS(f"Created keyset {keyset.id}")) + + def add_key(self, name: str, template: str, *args, **options): + try: + keyset = Keyset.objects.get(name=name) + except Keyset.DoesNotExist: + raise CommandError(f'Keyset "{name}" not found') + + key = keyset.generate_key(get_key_template_by_name(template)) + self.stdout.write(self.style.SUCCESS(f"Created key {key.id}")) + + def promote_key(self, name: str, id: int, *args, **options): + try: + keyset = Keyset.objects.get(name=name) + except Keyset.DoesNotExist: + raise CommandError(f'Keyset "{name}" not found') + + try: + key = keyset.key_set.get(pk=id) + except Key.DoesNotExist: + raise CommandError(f"Key ID {id} not found in keyset") + + keyset.set_primary_key(key) + self.stdout.write(f"Key {key.pk} promoted to primary") + + def list_keyset(self, name: str, *args, **options): + try: + keyset = Keyset.objects.get(name=name) + self.stdout.write(text_format.MessageToString(keyset.export_keyset_info())) + except Keyset.DoesNotExist: + raise CommandError(f'Keyset "{name}" not found') + + def delete_keyset(self, name: str, *args, **options): + 
try: + keyset = Keyset.objects.get(name=name) + keyset.delete() + self.stdout.write(self.style.SUCCESS(f'Deleted keyset "{name}"')) + except Keyset.DoesNotExist: + raise CommandError(f'Keyset "{name}" not found') + + def unsafe_export_keyset(self, name: str, *args, **options): + try: + keyset = Keyset.objects.get(name=name) + self.stdout.write(json_format.MessageToJson(keyset.export_keyset())) + except Keyset.DoesNotExist: + raise CommandError(f'Keyset "{name}" not found') + + def export_keyset_info(self, name: str, *args, **options): + try: + keyset = Keyset.objects.get(name=name) + self.stdout.write(json_format.MessageToJson(keyset.export_keyset_info())) + except Keyset.DoesNotExist: + raise CommandError(f'Keyset "{name}" not found') diff --git a/tink_fields/migrations/0001_initial.py b/tink_fields/migrations/0001_initial.py new file mode 100644 index 0000000..fb88991 --- /dev/null +++ b/tink_fields/migrations/0001_initial.py @@ -0,0 +1,88 @@ +# Generated by Django 4.1a1 on 2022-07-08 11:06 + +from django.db import migrations, models +import django.db.models.deletion +import tink_fields.fields + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [] + + operations = [ + migrations.CreateModel( + name="Key", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("type_url", models.CharField(max_length=250)), + ("key_data", tink_fields.fields.EncryptedBinaryField()), + ( + "status", + models.PositiveIntegerField( + choices=[ + ("UNKNOWN_STATUS", 0), + ("ENABLED", 1), + ("DISABLED", 2), + ("DESTROYED", 3), + ] + ), + ), + ( + "output_prefix_type", + models.PositiveIntegerField( + choices=[ + ("UNKNOWN_PREFIX", 0), + ("TINK", 1), + ("LEGACY", 2), + ("RAW", 3), + ("CRUNCHY", 4), + ] + ), + ), + ("output_prefix", models.BinaryField()), + ], + ), + migrations.CreateModel( + name="Keyset", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + 
primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("name", models.CharField(max_length=100, unique=True)), + ( + "primary_key", + models.ForeignKey( + null=True, + on_delete=django.db.models.deletion.PROTECT, + related_name="primary_key", + to="tink_fields.key", + ), + ), + ], + ), + migrations.AddField( + model_name="key", + name="keyset", + field=models.ForeignKey( + editable=False, + on_delete=django.db.models.deletion.CASCADE, + to="tink_fields.keyset", + ), + ), + ] diff --git a/tink_fields/migrations/__init__.py b/tink_fields/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tink_fields/models.py b/tink_fields/models.py new file mode 100644 index 0000000..c99ed70 --- /dev/null +++ b/tink_fields/models.py @@ -0,0 +1,250 @@ +from typing import List, TypeVar, Type, Optional, Set, Union + +from django.db import models +from django.db import transaction +from django.utils.functional import cached_property +from tink.proto import tink_pb2 + +from .fields import EncryptedBinaryField +from tink.core import PrimitiveSet, Registry, KeyManager +from tink.core._primitive_set import Entry +from tink.core import crypto_format + +__all__ = ["Keyset", "Key"] +P = TypeVar("P") + + +class Keyset(models.Model): + name = models.CharField(max_length=100, unique=True) + primary_key = models.ForeignKey( + "Key", on_delete=models.PROTECT, related_name="primary_key", null=True + ) + + @classmethod + def create(cls, name: str, key_template: tink_pb2.KeyTemplate) -> "Keyset": + """Create a keyset with one primary key""" + with transaction.atomic(): + instance = cls(name=name) + instance.save() + instance.primary_key = instance.generate_key(key_template) + + return instance + + def generate_key(self, key_template: tink_pb2.KeyTemplate) -> "Key": + """Create and save a key""" + key_data = Registry.new_key_data(key_template) + key = Key.create_from_keydata(self, key_data, key_template.output_prefix_type) + + return key + + def 
set_primary_key(self, key: Union["Key", int]): + key_id = key + if isinstance(key, Key): + key_id = key.pk + + self.primary_key_id = key_id + self.save(update_fields=["primary_key"]) + + def primitive(self, primitive_class: Type[P]) -> P: + """Get primitive of the stored type""" + # Hack: assuming that a key would only contain one primitive_class + self._primitive_set._primitive_class = primitive_class + return Registry.wrap( + self._primitive_set, + primitive_class, + ) + + @cached_property + def _primitive_set(self) -> PrimitiveSet: + # XXX: This would return keyset of the concrete type instead of interface type + return _DatabasePrimitiveKeyset(self, self.primary_key.primitive) + + def export_keyset(self) -> tink_pb2.Keyset: + return tink_pb2.Keyset( + primary_key_id=self.primary_key_id, + key=[key.key for key in self.key_set.all()], + ) + + def export_keyset_info(self) -> tink_pb2.KeysetInfo: + return tink_pb2.KeysetInfo( + primary_key_id=self.primary_key_id, + key_info=[key.key_info for key in self.key_set.all()], + ) + + +class _DatabasePrimitiveKeyset(PrimitiveSet[P]): + _keyset: Keyset + + def __init__(self, keyset: Keyset, primitive_class: Type[P]): + super().__init__(primitive_class) + self._keyset = keyset + del self._primary + + def primitive_from_identifier(self, identifier: bytes) -> List[Entry]: + # Fast path - non raw keys should have unique identifiers + if len(identifier) > 0 and identifier in self._primitives: + return super().primitive_from_identifier(identifier) + + for key in ( + self._keyset.key_set.filter(output_prefix=identifier) + .exclude(id__in=list(self._all_cached_key_ids())) + .all() + ): + self._add_key_to_cache(key) + + return super().primitive_from_identifier(identifier) + + def _entry_by_id(self, identifier: bytes, key_id: int) -> Optional[Entry]: + # Fast path - non raw keys should have unique identifiers + if ( + len(identifier) > 0 + and identifier in self._primitives + and len(self._primitives[identifier]) == 1 + ): + return 
self._primitives[identifier][0] + + primitives = self._primitives.get(identifier, []) + for item in primitives: + if item.key_id == key_id: + return item + + def _add_key_to_cache(self, key: "Key"): + if not self._entry_by_id(key.output_prefix, key.id): + entries = self._primitives.setdefault(key.output_prefix, []) + entries.append(key.entry) + + def _add_entry_to_cache(self, entry: Entry): + entries = self._primitives.setdefault(entry.identifier, []) + entries.append(entry) # XXX: This does not check for dupes + + def _all_cached_key_ids(self) -> Set[int]: + out = set() + for keys in self._primitives.values(): + for key in keys: + out.add(key.key_id) + + return out + + def all(self) -> List[List[Entry]]: + for key in self._keyset.key_set.exclude( + id__in=list(self._all_cached_key_ids()) + ).all(): + self._add_key_to_cache(key) + + return super().all() + + def add_primitive(self, primitive: P, key: tink_pb2.Keyset.Key) -> Entry: + assert isinstance(primitive, self._primitive_class) + + key = Key.from_key(self._keyset, key) + key.save() + + self._add_key_to_cache(key) + + return key.entry + + def set_primary(self, entry: Entry) -> None: + self._keyset.set_primary_key(entry.key_id) + + def primary(self) -> Entry: + key = self._keyset.primary_key + entry = self._entry_by_id(key.output_prefix, key.id) + if entry: + return entry + + self._add_key_to_cache(key) + return key.entry + + +class Key(models.Model): + """Key instance in a keyset. 
+ + It is expected that Key is immutable except for the status field""" + + keyset = models.ForeignKey(Keyset, on_delete=models.CASCADE, editable=False) + type_url = models.CharField(max_length=250) + + # Serialized KeyData + key_data = EncryptedBinaryField(keyset="db_keyset", editable=False) + status = models.PositiveIntegerField(choices=tink_pb2.KeyStatusType.items()) + output_prefix_type = models.PositiveIntegerField( + choices=tink_pb2.OutputPrefixType.items() + ) + # Output prefix can be derived from output_prefix_type + id, however + # we store it here to be able to lookup without parsing Tink format ourselves + output_prefix = models.BinaryField(editable=False) + + @cached_property + def key(self) -> tink_pb2.Keyset.Key: + return tink_pb2.Keyset.Key( + key_data=self.key_data_pb, + status=self.status, + key_id=self.id, + output_prefix_type=self.output_prefix_type, + ) + + @cached_property + def key_info(self) -> tink_pb2.KeysetInfo.KeyInfo: + return tink_pb2.KeysetInfo.KeyInfo( + type_url=self.type_url, + status=self.status, + key_id=self.id, + output_prefix_type=self.output_prefix_type, + ) + + def key_manager(self) -> "KeyManager": + return Registry.key_manager(self.type_url) + + @property + def key_data_pb(self) -> tink_pb2.KeyData: + out = tink_pb2.KeyData() + out.ParseFromString(self.key_data) + return out + + @property + def primitive(self) -> P: + return self.key_manager().primitive(self.key_data_pb) + + @cached_property + def entry(self) -> Entry: + return Entry( + primitive=self.primitive, + identifier=self.output_prefix, + status=self.status, + output_prefix_type=self.output_prefix_type, + key_id=self.id, + ) + + @classmethod + def from_key(cls, keyset: "Keyset", key: tink_pb2.Keyset.Key): + return cls( + id=key.key_id, + keyset=keyset, + type_url=key.key_data.type_url, + key_data=key.key_data.SerializeToString(), + status=key.status, + output_prefix=crypto_format.output_prefix(key), + output_prefix_type=key.output_prefix_type, + ) + + @classmethod + 
def create_from_keydata( + cls, + keyset: "Keyset", + keydata: tink_pb2.KeyData, + output_prefix_type: tink_pb2.OutputPrefixType, + ): + with transaction.atomic(): + out = cls( + keyset=keyset, + type_url=keydata.type_url, + key_data=keydata.SerializeToString(), + status=tink_pb2.ENABLED, + output_prefix_type=output_prefix_type, + ) + out.save() + out.refresh_from_db() + out.output_prefix = crypto_format.output_prefix(out.key) + out.save() + + return out diff --git a/tink_fields/test/models.py b/tink_fields/test/models.py index 61e917b..10c22d7 100644 --- a/tink_fields/test/models.py +++ b/tink_fields/test/models.py @@ -27,10 +27,18 @@ class EncryptedDateTime(models.Model): value = fields.EncryptedDateTimeField() +class EncryptedBinary(models.Model): + value = fields.EncryptedBinaryField() + + class EncryptedNullable(models.Model): value = fields.EncryptedIntegerField(null=True) +class EncryptedIntEnvelope(models.Model): + value = fields.EncryptedIntegerField(keyset="db_aead") + + def sample_aad_provider(instance) -> bytes: return force_bytes(instance.__class__.__name__) @@ -41,3 +49,23 @@ class EncryptedCharWithFixedAad(models.Model): class EncryptedCharWithAlternateKeyset(models.Model): value = fields.EncryptedCharField(max_length=25, keyset="alternate") + + +class DeterministicEncryptedChar(models.Model): + value = fields.DeterministicEncryptedCharField(max_length=25, keyset="daead") + + +class DeterministicEncryptedEmail(models.Model): + value = fields.DeterministicEncryptedEmailField(keyset="daead") + + +class DeterministicEncryptedInt(models.Model): + value = fields.DeterministicEncryptedIntegerField(keyset="daead") + + +class DeterministicEncryptedIntNullable(models.Model): + value = fields.DeterministicEncryptedIntegerField(keyset="daead", null=True) + + +class DeterministicEncryptedIntEnvelope(models.Model): + value = fields.DeterministicEncryptedIntegerField(keyset="db_daead") diff --git a/tink_fields/test/settings/base.py 
b/tink_fields/test/settings/base.py index bc23f2c..7e5db41 100644 --- a/tink_fields/test/settings/base.py +++ b/tink_fields/test/settings/base.py @@ -1,4 +1,5 @@ INSTALLED_APPS = [ + "tink_fields", "tink_fields.test", ] diff --git a/tink_fields/test/settings/sqlite.py b/tink_fields/test/settings/sqlite.py index 5ca7642..956bd27 100644 --- a/tink_fields/test/settings/sqlite.py +++ b/tink_fields/test/settings/sqlite.py @@ -26,4 +26,18 @@ "cleartext": True, "path": os.path.join(HERE, "../test_plaintext_keyset.json"), }, + "daead": { + "cleartext": True, + "path": os.path.join(HERE, "../test_plaintext_daead_keyset.json"), + }, + "db_keyset": { + "cleartext": True, + "path": os.path.join(HERE, "../test_plaintext_keyset.json"), + }, + "db_aead": { + "db_name": "aead", + }, + "db_daead": { + "db_name": "daead", + }, } diff --git a/tink_fields/test/test_fields.py b/tink_fields/test/test_fields.py index db2c159..ef61418 100644 --- a/tink_fields/test/test_fields.py +++ b/tink_fields/test/test_fields.py @@ -1,10 +1,22 @@ from datetime import date, datetime -from django.db import connection, models as dj_models -from django.utils.encoding import force_bytes, force_str +from django.db import connection +from django.utils.encoding import force_bytes import pytest +from tink import aead, daead from . 
import models +from ..models import Keyset + + +@pytest.fixture(autouse=True) +def configured_db_keyset(db): + for name, template in { + "aead": aead.aead_key_templates.AES256_GCM, + "daead": daead.deterministic_aead_key_templates.AES256_SIV, + }.items(): + model = Keyset.create(name=name, key_template=template) + model.save() @pytest.mark.parametrize( @@ -21,6 +33,8 @@ [datetime(2015, 2, 5, 15), datetime(2015, 2, 8, 16)], ), (models.EncryptedCharWithAlternateKeyset, ["foo", "bar"]), + (models.EncryptedIntEnvelope, [1, 2]), + (models.EncryptedBinary, [b"1234", b"asdf"]), ], ) class TestEncryptedFieldQueries(object): @@ -32,8 +46,45 @@ def test_insert(self, db, model, vals): with connection.cursor() as cur: cur.execute("SELECT value FROM %s" % model._meta.db_table) data = [ - force_str( - field._get_aead_primitive().decrypt( + field.to_python_prepare( + field._aead_primitive.decrypt( + force_bytes(r[0]), aad_callback(field) + ) + ) + for r in cur.fetchall() + ] + + if model is models.EncryptedBinary: + assert list([bytes(field.to_python(item)) for item in data]) == [vals[0]] + else: + assert list(map(field.to_python, data)) == [vals[0]] + + +def test_encrypted_nullable(db): + models.EncryptedNullable(value=None).save() + assert models.EncryptedNullable.objects.get(value__isnull=True) + + +@pytest.mark.parametrize( + "model,vals", + [ + (models.DeterministicEncryptedChar, ["one", "two"]), + (models.DeterministicEncryptedEmail, ["a@example.com", "b@example.com"]), + (models.DeterministicEncryptedInt, [1, 2]), + (models.DeterministicEncryptedIntEnvelope, [1, 2]), + ], +) +class TestDeterministicEncryptedFieldQueries(object): + def test_insert(self, db, model, vals): + """Data stored in DB is actually encrypted.""" + field = model._meta.get_field("value") + aad_callback = getattr(field, "_aad_callback") + model.objects.create(value=vals[0]) + with connection.cursor() as cur: + cur.execute("SELECT value FROM %s" % model._meta.db_table) + data = [ + 
field.to_python_prepare( + field._daead_primitive.decrypt_deterministically( force_bytes(r[0]), aad_callback(field) ) ) @@ -41,3 +92,36 @@ def test_insert(self, db, model, vals): ] assert list(map(field.to_python, data)) == [vals[0]] + + def test_search(self, db, model, vals): + model.objects.create(value=vals[0]) + model.objects.create(value=vals[1]) + out = model.objects.filter(value=vals[0]) + assert len(out) == 1 + assert out[0].value == vals[0] + + +def test_rotated_deterministic_field(db): + value = 12345678 + + models.DeterministicEncryptedIntEnvelope(value=value).save() + assert ( + models.DeterministicEncryptedIntEnvelope.objects.filter(value=value).count() + == 1 + ) + + keyset = Keyset.objects.get(name="daead") + new_key = keyset.generate_key(daead.deterministic_aead_key_templates.AES256_SIV) + keyset.set_primary_key(new_key) + + value_2 = 23456789 + models.DeterministicEncryptedIntEnvelope(value=value_2).save() + + assert ( + models.DeterministicEncryptedIntEnvelope.objects.filter(value=value).count() + == 1 + ) + assert ( + models.DeterministicEncryptedIntEnvelope.objects.filter(value=value_2).count() + == 1 + ) diff --git a/tink_fields/test/test_models.py b/tink_fields/test/test_models.py new file mode 100644 index 0000000..3a4b088 --- /dev/null +++ b/tink_fields/test/test_models.py @@ -0,0 +1,74 @@ +from tink import aead, daead +from tink_fields.models import Keyset + +TEST_DATA = b"hello world" +ASSOC_DATA = b"ad" + + +def test_aead_encrypt(db): + key_template = aead.aead_key_templates.AES256_GCM + ks = Keyset.create("aead", key_template) + primitive = ks.primitive(aead.Aead) + + ciphertext = primitive.encrypt(TEST_DATA, ASSOC_DATA) + + assert primitive.decrypt(ciphertext, ASSOC_DATA) == TEST_DATA + + +def test_aead_decrypt_non_primary(db): + key_template = aead.aead_key_templates.AES256_GCM + ks = Keyset.create("aead", key_template) + + ciphertext = ks.primitive(aead.Aead).encrypt(TEST_DATA, ASSOC_DATA) + + secondary_key = 
ks.generate_key(key_template) + ks.set_primary_key(secondary_key) + ks.refresh_from_db() + assert ks.primary_key == secondary_key + + assert ks.primitive(aead.Aead).decrypt(ciphertext, ASSOC_DATA) == TEST_DATA + + +def test_daead_encrypt(db): + key_template = daead.deterministic_aead_key_templates.AES256_SIV + ks = Keyset.create("daead", key_template) + primitive = ks.primitive(daead.DeterministicAead) + + ciphertext = primitive.encrypt_deterministically(TEST_DATA, ASSOC_DATA) + + assert primitive.encrypt_deterministically(TEST_DATA, ASSOC_DATA) == ciphertext + assert primitive.decrypt_deterministically(ciphertext, ASSOC_DATA) == TEST_DATA + + +def test_daead_decrypt_non_primary(db): + key_template = daead.deterministic_aead_key_templates.AES256_SIV + ks = Keyset.create("daead", key_template) + + primitive = ks.primitive(daead.DeterministicAead) + ciphertext = primitive.encrypt_deterministically(TEST_DATA, ASSOC_DATA) + + secondary_key = ks.generate_key(key_template) + ks.set_primary_key(secondary_key) + ks.refresh_from_db() + assert ks.primary_key == secondary_key + + primitive = ks.primitive(daead.DeterministicAead) + assert primitive.encrypt_deterministically(TEST_DATA, ASSOC_DATA) != ciphertext + assert primitive.decrypt_deterministically(ciphertext, ASSOC_DATA) == TEST_DATA + + +def test_multiple_key_templates(db): + ks = Keyset.create("aead", aead.aead_key_templates.AES256_GCM) + primitive = ks.primitive(aead.Aead) + ciphertext = primitive.encrypt(TEST_DATA, ASSOC_DATA) + + key_2 = ks.generate_key(aead.aead_key_templates.AES128_CTR_HMAC_SHA256) + ks.set_primary_key(key_2) + ciphertext2 = primitive.encrypt(TEST_DATA, ASSOC_DATA) + + assert len(ciphertext) != len(ciphertext2) + + ks = Keyset.objects.get(name="aead") + primitive = ks.primitive(aead.Aead) + assert primitive.decrypt(ciphertext, ASSOC_DATA) == TEST_DATA + assert primitive.decrypt(ciphertext2, ASSOC_DATA) == TEST_DATA diff --git a/tink_fields/test/test_plaintext_daead_keyset.json
b/tink_fields/test/test_plaintext_daead_keyset.json new file mode 100644 index 0000000..d214367 --- /dev/null +++ b/tink_fields/test/test_plaintext_daead_keyset.json @@ -0,0 +1 @@ +{"primaryKeyId":508056547,"key":[{"keyData":{"typeUrl":"type.googleapis.com/google.crypto.tink.AesSivKey","value":"EkDc2ZTmEZO2wrwmfEBWTEwoRd2WrDqPikE8rseHs3Nx/exobkxiQEZtPwTM37iNdwVvSouyDLGWUjO3T3D8v0LC","keyMaterialType":"SYMMETRIC"},"status":"ENABLED","keyId":508056547,"outputPrefixType":"TINK"}]}