diff --git a/keymanager/key_protection_service/key_custody_core/cbindgen.toml b/keymanager/key_protection_service/key_custody_core/cbindgen.toml index 63241ab35..00d028f82 100644 --- a/keymanager/key_protection_service/key_custody_core/cbindgen.toml +++ b/keymanager/key_protection_service/key_custody_core/cbindgen.toml @@ -4,7 +4,7 @@ pragma_once = false autogen_warning = "/* Auto-generated by cbindgen. DO NOT EDIT. */" cpp_compat = true no_includes = true -style = "tag" +style = "type" documentation = false usize_is_size_t = true sys_includes = ["stddef.h", "stdint.h"] @@ -15,7 +15,7 @@ parse_deps = false clean = true [export] -item_types = ["functions"] -include = ["key_manager_generate_kem_keypair"] +item_types = ["functions", "structs"] +include = ["key_manager_generate_kem_keypair", "key_manager_destroy_kem_key", "key_manager_enumerate_kem_keys", "KpsKeyInfo"] diff --git a/keymanager/key_protection_service/key_custody_core/include/kps_key_custody_core.h b/keymanager/key_protection_service/key_custody_core/include/kps_key_custody_core.h index e5e573a3a..14ad64511 100644 --- a/keymanager/key_protection_service/key_custody_core/include/kps_key_custody_core.h +++ b/keymanager/key_protection_service/key_custody_core/include/kps_key_custody_core.h @@ -6,6 +6,15 @@ #include #include +typedef struct { + uint8_t uuid[16]; + uint8_t algorithm[128]; + size_t algorithm_len; + uint8_t kem_pub_key[2048]; + size_t kem_pub_key_len; + uint64_t remaining_lifespan_secs; +} KpsKeyInfo; + #ifdef __cplusplus extern "C" { #endif // __cplusplus @@ -21,6 +30,11 @@ int32_t key_manager_generate_kem_keypair(const uint8_t *algo_ptr, int32_t key_manager_destroy_kem_key(const uint8_t *uuid_bytes); +int32_t key_manager_enumerate_kem_keys(KpsKeyInfo *out_entries, + size_t max_entries, + size_t offset, + size_t *out_count); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/keymanager/key_protection_service/key_custody_core/kps_key_custody_core_cgo.go b/keymanager/key_protection_service/key_custody_core/kps_key_custody_core_cgo.go index a52909357..b44bc4dfe 100644 --- a/keymanager/key_protection_service/key_custody_core/kps_key_custody_core_cgo.go +++ b/keymanager/key_protection_service/key_custody_core/kps_key_custody_core_cgo.go @@ -60,3 +60,56 @@ func GenerateKEMKeypair(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, li copy(pubkey, pubkeyBuf[:pubkeyLen]) return id, pubkey, nil } + +// EnumerateKEMKeys retrieves active KEM key entries from the Rust KCC registry with pagination. +func EnumerateKEMKeys(limit, offset int) ([]KEMKeyInfo, error) { + if limit <= 0 { + return nil, fmt.Errorf("limit must be positive") + } + if offset < 0 { + return nil, fmt.Errorf("offset must be non-negative") + } + + // Dynamic allocation might be better, but for now using a slice on the heap is safer than large stack usage. + // C.KpsKeyInfo is large (~2KB+), so even 256 entries is 500KB. + entries := make([]C.KpsKeyInfo, limit) + var count C.size_t + + rc := C.key_manager_enumerate_kem_keys( + &entries[0], + C.size_t(limit), + C.size_t(offset), + &count, + ) + if rc != 0 { + return nil, fmt.Errorf("key_manager_enumerate_kem_keys failed with code %d", rc) + } + + result := make([]KEMKeyInfo, count) + for i := C.size_t(0); i < count; i++ { + e := entries[i] + + id, err := uuid.FromBytes(C.GoBytes(unsafe.Pointer(&e.uuid[0]), 16)) + if err != nil { + return nil, fmt.Errorf("invalid UUID at index %d: %w", i, err) + } + + kemPubKey := make([]byte, e.kem_pub_key_len) + copy(kemPubKey, C.GoBytes(unsafe.Pointer(&e.kem_pub_key[0]), C.int(e.kem_pub_key_len))) + + algoBytes := C.GoBytes(unsafe.Pointer(&e.algorithm[0]), C.int(e.algorithm_len)) + algo := &algorithms.HpkeAlgorithm{} + if err := proto.Unmarshal(algoBytes, algo); err != nil { + return nil, fmt.Errorf("failed to unmarshal algorithm for key %d: %w", i, err) + } + + result[i] = KEMKeyInfo{ + ID: id, + Algorithm: algo, + KEMPubKey: kemPubKey, + RemainingLifespanSecs: uint64(e.remaining_lifespan_secs), + } + } + + return result, nil +} diff --git a/keymanager/key_protection_service/key_custody_core/src/lib.rs b/keymanager/key_protection_service/key_custody_core/src/lib.rs index 876601126..35f3b4cfa 100644 --- a/keymanager/key_protection_service/key_custody_core/src/lib.rs +++ b/keymanager/key_protection_service/key_custody_core/src/lib.rs @@ -6,7 +6,7 @@ use std::slice; use std::sync::atomic::AtomicBool; use std::sync::Arc; use std::sync::LazyLock; -use std::time::Duration; +use std::time::{Duration, Instant}; use uuid::Uuid; static KEY_REGISTRY: LazyLock = LazyLock::new(|| { @@ -145,6 +145,77 @@ pub unsafe extern "C" fn key_manager_destroy_kem_key(uuid_bytes: *const u8) -> i } } +#[repr(C)] +pub struct KpsKeyInfo { + pub uuid: [u8; 16], + pub algorithm: [u8; 128], + pub algorithm_len: usize, + pub kem_pub_key: [u8; 2048], + pub kem_pub_key_len: usize, + pub remaining_lifespan_secs: u64, +} + +#[unsafe(no_mangle)] +pub unsafe extern "C" fn key_manager_enumerate_kem_keys( + out_entries: *mut KpsKeyInfo, + max_entries: usize, + offset: usize, + out_count: *mut usize, +) -> i32 { + if out_entries.is_null() || out_count.is_null() { + return -1; + } + + let metas = KEY_REGISTRY.list_kem_keys(offset, max_entries); + let now = Instant::now(); + let mut count = 0usize; + + for meta in &metas { + if let KeySpec::KemWithBindingPub { + algo, + kem_public_key, + binding_public_key: _, + } = &meta.spec + { + let algo_bytes = algo.encode_to_vec(); + if kem_public_key.as_bytes().len() > 2048 || algo_bytes.len() > 128 { + eprintln!( + "Skipping key {}: size exceeds buffer limits (algo={}, kem={})", + meta.id, + algo_bytes.len(), + kem_public_key.as_bytes().len() + ); + } else { + continue; + } + + let remaining = meta.delete_after.saturating_duration_since(now).as_secs(); + + let entry = unsafe { &mut *out_entries.add(count) }; + entry.uuid.copy_from_slice(meta.id.as_bytes()); + + + entry.algorithm = [0u8; 128]; + entry.algorithm[..algo_bytes.len()].copy_from_slice(&algo_bytes); + entry.algorithm_len = algo_bytes.len(); + + entry.kem_pub_key = [0u8; 2048]; + entry.kem_pub_key[..kem_public_key.as_bytes().len()] + .copy_from_slice(kem_public_key.as_bytes()); + entry.kem_pub_key_len = kem_public_key.as_bytes().len(); + + entry.remaining_lifespan_secs = remaining; + + count += 1; + } + } + + unsafe { + *out_count = count; + } + 0 +} + #[cfg(test)] mod tests { use super::*; @@ -326,10 +397,12 @@ mod tests { kdf: KdfAlgorithm::HkdfSha256 as i32, aead: AeadAlgorithm::Aes256Gcm as i32, }; + let algo_bytes = algo.encode_to_vec(); unsafe { let res = key_manager_generate_kem_keypair( - algo, + algo_bytes.as_ptr(), + algo_bytes.len(), binding_pubkey.as_ptr(), binding_pubkey.len(), 3600, @@ -360,4 +433,79 @@ mod tests { let result = unsafe { key_manager_destroy_kem_key(std::ptr::null()) }; assert_eq!(result, -1); } + + #[test] + fn test_enumerate_kem_keys_null_pointers() { + let result = unsafe { + key_manager_enumerate_kem_keys(std::ptr::null_mut(), 10, 0, std::ptr::null_mut()) + }; + assert_eq!(result, -1); + } + + #[test] + fn test_enumerate_kem_keys_after_generate() { + let binding_pubkey = [7u8; 32]; + let mut uuid_bytes = [0u8; 16]; + let mut pubkey_bytes = [0u8; 32]; + let pubkey_len: usize = 32; + let algo = HpkeAlgorithm { + kem: KemAlgorithm::DhkemX25519HkdfSha256 as i32, + kdf: KdfAlgorithm::HkdfSha256 as i32, + aead: AeadAlgorithm::Aes256Gcm as i32, + }; + // MUST encode to bytes + let algo_bytes = algo.encode_to_vec(); + + // Generate a key first. + let rc = unsafe { + key_manager_generate_kem_keypair( + algo_bytes.as_ptr(), + algo_bytes.len(), + binding_pubkey.as_ptr(), + binding_pubkey.len(), + 3600, + uuid_bytes.as_mut_ptr(), + pubkey_bytes.as_mut_ptr(), + pubkey_len, + ) + }; + assert_eq!(rc, 0); + + // Enumerate. + let mut entries: Vec = Vec::with_capacity(256); + // Initialize with default/zero values. Note: Arrays are larger now. + entries.resize_with(100, || KpsKeyInfo { + uuid: [0; 16], + algorithm: [0; 128], + algorithm_len: 0, + kem_pub_key: [0; 2048], + kem_pub_key_len: 0, + remaining_lifespan_secs: 0, + }); + let mut count: usize = 0; + + let rc = unsafe { + // max_entries=1, offset=0 + key_manager_enumerate_kem_keys(entries.as_mut_ptr(), entries.len(), 0, &mut count) + }; + assert_eq!(rc, 0); + // At least 1 key should be enumerated (the one we just generated). + assert!(count >= 1); + + // Find our key in the results. + let mut found = false; + for i in 0..count { + if entries[i].uuid == uuid_bytes { + found = true; + let encoded_algo = &entries[i].algorithm[..entries[i].algorithm_len]; + let decoded_algo = HpkeAlgorithm::decode(encoded_algo).unwrap(); + assert_eq!(decoded_algo.kem, KemAlgorithm::DhkemX25519HkdfSha256 as i32); + assert_eq!(entries[i].kem_pub_key_len, 32); + // binding_pub_key checks removed + assert!(entries[i].remaining_lifespan_secs > 0); + break; + } + } + assert!(found, "generated key not found in enumerate results"); + } } diff --git a/keymanager/key_protection_service/key_custody_core/types.go b/keymanager/key_protection_service/key_custody_core/types.go new file mode 100644 index 000000000..c981eb2b7 --- /dev/null +++ b/keymanager/key_protection_service/key_custody_core/types.go @@ -0,0 +1,14 @@ +package kpskcc + +import ( + algorithms "github.com/google/go-tpm-tools/keymanager/km_common/proto" + "github.com/google/uuid" +) + +// KEMKeyInfo holds metadata for a single KEM key returned by EnumerateKEMKeys. +type KEMKeyInfo struct { + ID uuid.UUID + Algorithm *algorithms.HpkeAlgorithm + KEMPubKey []byte + RemainingLifespanSecs uint64 +} diff --git a/keymanager/key_protection_service/service.go b/keymanager/key_protection_service/service.go index c9319bfaa..6ff5d7c8b 100644 --- a/keymanager/key_protection_service/service.go +++ b/keymanager/key_protection_service/service.go @@ -1,9 +1,10 @@ // Package key_protection_service implements the Key Orchestration Layer (KOL) // for the Key Protection Service. It wraps the KPS Key Custody Core (KCC) FFI -// to provide a Go-native interface for KEM key generation. +// to provide a Go-native interface for KEM key generation and enumeration. package key_protection_service import ( + kpskcc "github.com/google/go-tpm-tools/keymanager/key_protection_service/key_custody_core" "github.com/google/uuid" algorithms "github.com/google/go-tpm-tools/keymanager/km_common/proto" @@ -14,14 +15,26 @@ type KEMKeyGenerator interface { GenerateKEMKeypair(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) } -// Service implements KEMKeyGenerator by delegating to the KPS KCC FFI. +// KEMKeyEnumerator enumerates active KEM keys in the KPS registry. +type KEMKeyEnumerator interface { + EnumerateKEMKeys(limit, offset int) ([]kpskcc.KEMKeyInfo, error) +} + +// Service implements KEMKeyGenerator and KEMKeyEnumerator by delegating to the KPS KCC FFI. type Service struct { generateKEMKeypairFn func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) + enumerateKEMKeysFn func(limit, offset int) ([]kpskcc.KEMKeyInfo, error) } -// NewService creates a new KPS KOL service with the given KCC function. -func NewService(generateKEMKeypairFn func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error)) *Service { - return &Service{generateKEMKeypairFn: generateKEMKeypairFn} +// NewService creates a new KPS KOL service with the given KCC functions. +func NewService( + generateKEMKeypairFn func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error), + enumerateKEMKeysFn func(limit, offset int) ([]kpskcc.KEMKeyInfo, error), +) *Service { + return &Service{ + generateKEMKeypairFn: generateKEMKeypairFn, + enumerateKEMKeysFn: enumerateKEMKeysFn, + } } // GenerateKEMKeypair generates a KEM keypair linked to the provided binding @@ -29,3 +42,8 @@ func NewService(generateKEMKeypairFn func(algo *algorithms.HpkeAlgorithm, bindin func (s *Service) GenerateKEMKeypair(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) { return s.generateKEMKeypairFn(algo, bindingPubKey, lifespanSecs) } + +// EnumerateKEMKeys retrieves all active KEM key entries from the KPS KCC registry. +func (s *Service) EnumerateKEMKeys(limit, offset int) ([]kpskcc.KEMKeyInfo, error) { + return s.enumerateKEMKeysFn(limit, offset) +} diff --git a/keymanager/key_protection_service/service_test.go b/keymanager/key_protection_service/service_test.go index e4e40c6fc..5c33beae0 100644 --- a/keymanager/key_protection_service/service_test.go +++ b/keymanager/key_protection_service/service_test.go @@ -4,6 +4,7 @@ import ( "fmt" "testing" + kpskcc "github.com/google/go-tpm-tools/keymanager/key_protection_service/key_custody_core" "github.com/google/uuid" algorithms "github.com/google/go-tpm-tools/keymanager/km_common/proto" @@ -16,15 +17,20 @@ func TestServiceGenerateKEMKeypairSuccess(t *testing.T) { expectedPubKey[i] = byte(i + 10) } - svc := NewService(func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) { - if len(bindingPubKey) != 32 { - t.Fatalf("expected 32-byte binding public key, got %d", len(bindingPubKey)) - } - if lifespanSecs != 7200 { - t.Fatalf("expected lifespanSecs 7200, got %d", lifespanSecs) - } - return expectedUUID, expectedPubKey, nil - }) + svc := NewService( + func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) { + if len(bindingPubKey) != 32 { + t.Fatalf("expected 32-byte binding public key, got %d", len(bindingPubKey)) + } + if lifespanSecs != 7200 { + t.Fatalf("expected lifespanSecs 7200, got %d", lifespanSecs) + } + return expectedUUID, expectedPubKey, nil + }, + func(limit, offset int) ([]kpskcc.KEMKeyInfo, error) { + return nil, nil + }, + ) id, pubKey, err := svc.GenerateKEMKeypair(&algorithms.HpkeAlgorithm{}, make([]byte, 32), 7200) if err != nil { @@ -39,12 +45,71 @@ func TestServiceGenerateKEMKeypairSuccess(t *testing.T) { } func TestServiceGenerateKEMKeypairError(t *testing.T) { - svc := NewService(func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) { - return uuid.Nil, nil, fmt.Errorf("FFI error") - }) + svc := NewService( + func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) { + return uuid.Nil, nil, fmt.Errorf("FFI error") + }, + func(limit, offset int) ([]kpskcc.KEMKeyInfo, error) { + return nil, nil + }, + ) _, _, err := svc.GenerateKEMKeypair(&algorithms.HpkeAlgorithm{}, make([]byte, 32), 3600) if err == nil { t.Fatal("expected error, got nil") } } + +func TestServiceEnumerateKEMKeysSuccess(t *testing.T) { + expectedKeys := []kpskcc.KEMKeyInfo{ + { + ID: uuid.New(), + Algorithm: &algorithms.HpkeAlgorithm{ + Kem: algorithms.KemAlgorithm_KEM_ALGORITHM_DHKEM_X25519_HKDF_SHA256, + Kdf: algorithms.KdfAlgorithm_KDF_ALGORITHM_HKDF_SHA256, + Aead: algorithms.AeadAlgorithm_AEAD_ALGORITHM_AES_256_GCM, + }, + KEMPubKey: make([]byte, 32), + RemainingLifespanSecs: 3500, + }, + } + + svc := NewService( + func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) { + return uuid.Nil, nil, nil + }, + func(limit, offset int) ([]kpskcc.KEMKeyInfo, error) { + if limit != 100 || offset != 0 { + return nil, fmt.Errorf("unexpected limit/offset: %d/%d", limit, offset) + } + return expectedKeys, nil + }, + ) + + keys, err := svc.EnumerateKEMKeys(100, 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(keys) != 1 { + t.Fatalf("expected 1 key, got %d", len(keys)) + } + if keys[0].ID != expectedKeys[0].ID { + t.Fatalf("expected ID %s, got %s", expectedKeys[0].ID, keys[0].ID) + } +} + +func TestServiceEnumerateKEMKeysError(t *testing.T) { + svc := NewService( + func(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) { + return uuid.Nil, nil, nil + }, + func(limit, offset int) ([]kpskcc.KEMKeyInfo, error) { + return nil, fmt.Errorf("enumerate error") + }, + ) + + _, err := svc.EnumerateKEMKeys(100, 0) + if err == nil { + t.Fatal("expected error, got nil") + } +} diff --git a/keymanager/km_common/build.rs b/keymanager/km_common/build.rs index 5d08cbae2..a3e8ceb42 100644 --- a/keymanager/km_common/build.rs +++ b/keymanager/km_common/build.rs @@ -7,7 +7,7 @@ fn main() -> Result<()> { } let mut config = prost_build::Config::new(); - config.type_attribute("HpkeAlgorithm", "#[repr(C)]"); + config.compile_protos(&["proto/algorithms.proto"], &["proto/"])?; Ok(()) diff --git a/keymanager/km_common/src/key_types.rs b/keymanager/km_common/src/key_types.rs index dcc743c15..74bc257aa 100644 --- a/keymanager/km_common/src/key_types.rs +++ b/keymanager/km_common/src/key_types.rs @@ -1,6 +1,6 @@ use crate::algorithms::{AeadAlgorithm, HpkeAlgorithm, KdfAlgorithm, KemAlgorithm}; use crate::crypto; -use crate::crypto::{PublicKey, secret_box}; +use crate::crypto::{secret_box, PublicKey}; use crate::protected_mem::Vault; use std::collections::HashMap; use std::sync::{Arc, RwLock}; @@ -86,19 +86,40 @@ impl KeyRegistry { stop_signal: Arc, ) -> std::thread::JoinHandle<()> { let keys_clone = Arc::clone(&self.keys); - std::thread::spawn(move || { - loop { - std::thread::sleep(Duration::from_secs(REAPER_INTERVAL_SECS)); - if stop_signal.load(std::sync::atomic::Ordering::Relaxed) { - break; - } - let now = Instant::now(); - if let Ok(mut keys) = keys_clone.write() { - keys.retain(|_, key| key.meta.delete_after > now); - } + std::thread::spawn(move || loop { + std::thread::sleep(Duration::from_secs(REAPER_INTERVAL_SECS)); + if stop_signal.load(std::sync::atomic::Ordering::Relaxed) { + break; + } + let now = Instant::now(); + if let Ok(mut keys) = keys_clone.write() { + keys.retain(|_, key| key.meta.delete_after > now); } }) } + + /// Lists only KEM keys with pagination support. + pub fn list_kem_keys(&self, offset: usize, limit: usize) -> Vec { + let keys = self.keys.read().unwrap(); + let mut refs: Vec<&Arc> = keys + .values() + .filter(|k| matches!(k.meta.spec, KeySpec::KemWithBindingPub { .. })) + .collect(); + + // Sort for stable pagination: created_at, then id + refs.sort_by(|a, b| { + a.meta + .created_at + .cmp(&b.meta.created_at) + .then(a.meta.id.cmp(&b.meta.id)) + }); + + refs.into_iter() + .skip(offset) + .take(limit) + .map(|r| r.meta.clone()) + .collect() + } } impl KeyRecord { @@ -186,6 +207,33 @@ mod tests { use super::*; use crate::algorithms::{AeadAlgorithm, KdfAlgorithm, KemAlgorithm}; + fn create_key_record( + algo: HpkeAlgorithm, + expiry_secs: u64, + spec_builder: F, + ) -> Result + where + F: FnOnce(HpkeAlgorithm, PublicKey) -> KeySpec, + { + let (pub_key, priv_key) = crypto::generate_keypair(KemAlgorithm::DhkemX25519HkdfSha256)?; + let id = Uuid::new_v4(); + let vault = Vault::new(secret_box::SecretBox::from(priv_key)) + .map_err(|_| crypto::Error::CryptoError)?; + let now = Instant::now(); + let delete_after = now + .checked_add(Duration::from_secs(expiry_secs)) + .ok_or(crypto::Error::UnsupportedAlgorithm)?; + Ok(KeyRecord { + meta: KeyMetadata { + id, + created_at: now, + delete_after, + spec: spec_builder(algo, pub_key), + }, + private_key: vault, + }) + } + #[test] fn test_create_binding_key_success() { let algo = HpkeAlgorithm { @@ -246,6 +294,63 @@ mod tests { } } + #[test] + fn test_list_kem_keys_empty() { + let registry = KeyRegistry::default(); + let keys = registry.list_kem_keys(0, 100); + assert!(keys.is_empty()); + } + + #[test] + fn test_list_kem_keys() { + let registry = KeyRegistry::default(); + let algo = HpkeAlgorithm { + kem: KemAlgorithm::DhkemX25519HkdfSha256 as i32, + kdf: KdfAlgorithm::HkdfSha256 as i32, + aead: AeadAlgorithm::Aes256Gcm as i32, + }; + + let binding_pubkey = [42u8; 32]; + + let record1 = create_key_record(algo, 3600, |a, pk| KeySpec::KemWithBindingPub { + algo: a, + kem_public_key: pk, + binding_public_key: PublicKey::try_from(binding_pubkey.to_vec()).unwrap(), + }) + .expect("failed to create key 1"); + let id1 = record1.meta.id; + + let record2 = create_key_record(algo, 7200, |a, pk| KeySpec::KemWithBindingPub { + algo: a, + kem_public_key: pk, + binding_public_key: PublicKey::try_from(binding_pubkey.to_vec()).unwrap(), + }) + .expect("failed to create key 2"); + let id2 = record2.meta.id; + + registry.add_key(record1); + registry.add_key(record2); + + let metas = registry.list_kem_keys(0, 100); + assert_eq!(metas.len(), 2); + + let ids: Vec = metas.iter().map(|m| m.id).collect(); + assert!(ids.contains(&id1)); + assert!(ids.contains(&id2)); + + // Test pagination + let page1 = registry.list_kem_keys(0, 1); + assert_eq!(page1.len(), 1); + let page2 = registry.list_kem_keys(1, 1); + assert_eq!(page2.len(), 1); + assert_ne!(page1[0].id, page2[0].id); + + // Verify order (record1 created before record2) + // Note: created_at might be identical if executed very fast, so sort checks ID too. + // But record1 and record2 are created with explicit calls, so created_at likely differs slightly. + // Either way, stability is guaranteed. + } + #[test] fn test_add_key() { let registry = KeyRegistry::default(); diff --git a/keymanager/workload_service/key_custody_core/src/lib.rs b/keymanager/workload_service/key_custody_core/src/lib.rs index ea5e1501e..9c8cc5b4a 100644 --- a/keymanager/workload_service/key_custody_core/src/lib.rs +++ b/keymanager/workload_service/key_custody_core/src/lib.rs @@ -264,16 +264,18 @@ mod tests { aead: AeadAlgorithm::Aes256Gcm as i32, }; - unsafe { - let res = key_manager_generate_binding_keypair( - algo, + let algo_bytes = algo.encode_to_vec(); + let res = unsafe { + key_manager_generate_binding_keypair( + algo_bytes.as_ptr(), + algo_bytes.len(), 3600, uuid_bytes.as_mut_ptr(), pubkey_bytes.as_mut_ptr(), pubkey_len, - ); - assert_eq!(res, 0); + ) }; + assert_eq!(res, 0); let result = unsafe { key_manager_destroy_binding_key(uuid_bytes.as_ptr()) }; assert_eq!(result, 0); diff --git a/keymanager/workload_service/proto_enums.go b/keymanager/workload_service/proto_enums.go index 2b2c2b748..f8ca4cf64 100644 --- a/keymanager/workload_service/proto_enums.go +++ b/keymanager/workload_service/proto_enums.go @@ -133,3 +133,88 @@ func (k KemAlgorithm) ToHpkeAlgorithm() (*algorithms.HpkeAlgorithm, error) { return nil, fmt.Errorf("unsupported algorithm: %s", k) } } + +// KdfAlgorithm represents the requested KDF algorithm. +type KdfAlgorithm int32 + +const ( + KdfAlgorithmUnspecified KdfAlgorithm = 0 + // Corrected from HKDF_SHA384 to HKDF_SHA256 based on ToHpkeAlgorithm usage which maps to HKDF_SHA256 (val 1) + KdfAlgorithmHKDFSHA256 KdfAlgorithm = 1 +) + +var ( + kdfAlgorithmToString = map[KdfAlgorithm]string{ + KdfAlgorithmUnspecified: "KDF_ALGORITHM_UNSPECIFIED", + KdfAlgorithmHKDFSHA256: "HKDF_SHA256", + } + stringToKdfAlgorithm = map[string]KdfAlgorithm{ + "KDF_ALGORITHM_UNSPECIFIED": KdfAlgorithmUnspecified, + "HKDF_SHA256": KdfAlgorithmHKDFSHA256, + } +) + +func (k KdfAlgorithm) String() string { + if s, ok := kdfAlgorithmToString[k]; ok { + return s + } + return fmt.Sprintf("KDF_ALGORITHM_UNKNOWN(%d)", k) +} + +func (k KdfAlgorithm) MarshalJSON() ([]byte, error) { + return json.Marshal(k.String()) +} + +func (k *KdfAlgorithm) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return fmt.Errorf("KdfAlgorithm must be a string") + } + if v, ok := stringToKdfAlgorithm[s]; ok { + *k = v + return nil + } + return fmt.Errorf("unknown KdfAlgorithm: %q", s) +} + +// AeadAlgorithm represents the requested AEAD algorithm. +type AeadAlgorithm int32 + +const ( + AeadAlgorithmUnspecified AeadAlgorithm = 0 + AeadAlgorithmAES256GCM AeadAlgorithm = 1 +) + +var ( + aeadAlgorithmToString = map[AeadAlgorithm]string{ + AeadAlgorithmUnspecified: "AEAD_ALGORITHM_UNSPECIFIED", + AeadAlgorithmAES256GCM: "AES_256_GCM", + } + stringToAeadAlgorithm = map[string]AeadAlgorithm{ + "AEAD_ALGORITHM_UNSPECIFIED": AeadAlgorithmUnspecified, + "AES_256_GCM": AeadAlgorithmAES256GCM, + } +) + +func (k AeadAlgorithm) String() string { + if s, ok := aeadAlgorithmToString[k]; ok { + return s + } + return fmt.Sprintf("AEAD_ALGORITHM_UNKNOWN(%d)", k) +} + +func (k AeadAlgorithm) MarshalJSON() ([]byte, error) { + return json.Marshal(k.String()) +} + +func (k *AeadAlgorithm) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return fmt.Errorf("AeadAlgorithm must be a string") + } + if v, ok := stringToAeadAlgorithm[s]; ok { + *k = v + return nil + } + return fmt.Errorf("unknown AeadAlgorithm: %q", s) +} diff --git a/keymanager/workload_service/server.go b/keymanager/workload_service/server.go index 41ec27ab3..437ccecc4 100644 --- a/keymanager/workload_service/server.go +++ b/keymanager/workload_service/server.go @@ -5,6 +5,7 @@ package workload_service import ( "context" + "encoding/base64" "encoding/json" "fmt" "math" @@ -13,6 +14,7 @@ import ( "os" "sync" + kpskcc "github.com/google/go-tpm-tools/keymanager/key_protection_service/key_custody_core" "github.com/google/uuid" algorithms "github.com/google/go-tpm-tools/keymanager/km_common/proto" @@ -28,6 +30,12 @@ type KEMKeyGenerator interface { GenerateKEMKeypair(algo *algorithms.HpkeAlgorithm, bindingPubKey []byte, lifespanSecs uint64) (uuid.UUID, []byte, error) } +// KEMKeyEnumerator enumerates active KEM keys from the KPS registry. +// KEMKeyEnumerator enumerates active KEM keys from the KPS registry. +type KEMKeyEnumerator interface { + EnumerateKEMKeys(limit, offset int) ([]kpskcc.KEMKeyInfo, error) +} + // KeyHandle represents a key handle returned from the API. type KeyHandle struct { Handle string `json:"handle"` @@ -68,10 +76,48 @@ type GenerateKemResponse struct { KeyHandle KeyHandle `json:"key_handle"` } +// KemPublicKey represents a KEM public key with its algorithm identifier. +type KemPublicKey struct { + Algorithm KemAlgorithm `json:"algorithm"` + PublicKey string `json:"public_key"` +} + +// HpkeAlgorithm identifies the HPKE algorithm suite (KEM, KDF, AEAD). +type HpkeAlgorithm struct { + Kem KemAlgorithm `json:"kem"` + Kdf KdfAlgorithm `json:"kdf"` + Aead AeadAlgorithm `json:"aead"` +} + +// HpkePublicKey represents an HPKE public key with its full algorithm suite. +type HpkePublicKey struct { + Algorithm HpkeAlgorithm `json:"algorithm"` + PublicKey string `json:"public_key"` +} + +// BoundKEMInfo holds the full metadata for a bound KEM key. +type BoundKEMInfo struct { + KeyHandle KeyHandle `json:"key_handle"` + KemPubKey KemPublicKey `json:"kem_pub_key"` + // BindingPubKey removed as it is no longer returned by KCC FFI + RemainingLifespan ProtoDuration `json:"remaining_lifespan"` +} + +// KeyInfo wraps a single key entry in the enumerate response. +type KeyInfo struct { + BoundKemInfo *BoundKEMInfo `json:"bound_kem_info,omitempty"` +} + +// EnumerateKeysResponse is returned by GET /v1/keys. +type EnumerateKeysResponse struct { + KeyInfos []KeyInfo `json:"key_infos"` +} + // Server is the WSD HTTP server. type Server struct { bindingGen BindingKeyGenerator kemGen KEMKeyGenerator + kemEnum KEMKeyEnumerator mu sync.RWMutex kemToBindingMap map[uuid.UUID]uuid.UUID @@ -81,15 +127,17 @@ type Server struct { } // NewServer creates a new WSD server with the given dependencies. -func NewServer(bindingGen BindingKeyGenerator, kemGen KEMKeyGenerator) *Server { +func NewServer(bindingGen BindingKeyGenerator, kemGen KEMKeyGenerator, kemEnum KEMKeyEnumerator) *Server { s := &Server{ bindingGen: bindingGen, kemGen: kemGen, + kemEnum: kemEnum, kemToBindingMap: make(map[uuid.UUID]uuid.UUID), } mux := http.NewServeMux() mux.HandleFunc("POST /v1/keys:generate_kem", s.handleGenerateKem) + mux.HandleFunc("GET /v1/keys", s.handleEnumerateKeys) s.httpServer = &http.Server{Handler: mux} return s @@ -124,6 +172,47 @@ func (s *Server) LookupBindingUUID(kemUUID uuid.UUID) (uuid.UUID, bool) { return id, ok } +func (s *Server) handleEnumerateKeys(w http.ResponseWriter, r *http.Request) { + // Check method again implicitly via mux, but extra check doesn't hurt if mux misconfigures. + // Actually mux handles it if using "GET /v1/keys". + if r.Method != http.MethodGet { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + // Default limit/offset for now. + // TODO: Parse from query params? + const ( + defaultLimit = 100 + defaultOffset = 0 + ) + + keys, err := s.kemEnum.EnumerateKEMKeys(defaultLimit, defaultOffset) + if err != nil { + http.Error(w, fmt.Sprintf("failed to enumerate keys: %v", err), http.StatusInternalServerError) + return + } + + keyInfos := make([]KeyInfo, 0, len(keys)) + for _, k := range keys { + info := KeyInfo{ + BoundKemInfo: &BoundKEMInfo{ + KeyHandle: KeyHandle{Handle: k.ID.String()}, + KemPubKey: KemPublicKey{ + Algorithm: KemAlgorithm(k.Algorithm.Kem), + PublicKey: base64.StdEncoding.EncodeToString(k.KEMPubKey), + }, + RemainingLifespan: ProtoDuration{Seconds: k.RemainingLifespanSecs}, + }, + } + keyInfos = append(keyInfos, info) + } + + resp := EnumerateKeysResponse{KeyInfos: keyInfos} + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + func (s *Server) handleGenerateKem(w http.ResponseWriter, r *http.Request) { var req GenerateKemRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { @@ -191,5 +280,3 @@ func writeError(w http.ResponseWriter, message string, code int) { w.WriteHeader(code) json.NewEncoder(w).Encode(map[string]string{"error": message}) } - - diff --git a/keymanager/workload_service/server_test.go b/keymanager/workload_service/server_test.go index fc9e35c9d..d8787318b 100644 --- a/keymanager/workload_service/server_test.go +++ b/keymanager/workload_service/server_test.go @@ -2,6 +2,7 @@ package workload_service import ( "bytes" + "encoding/base64" "encoding/json" "fmt" "net/http" @@ -9,6 +10,7 @@ import ( "strings" "testing" + kpskcc "github.com/google/go-tpm-tools/keymanager/key_protection_service/key_custody_core" "github.com/google/uuid" algorithms "github.com/google/go-tpm-tools/keymanager/km_common/proto" @@ -40,6 +42,16 @@ func (m *mockKEMKeyGen) GenerateKEMKeypair(algo *algorithms.HpkeAlgorithm, bindi return m.uuid, m.pubKey, m.err } +// mockKEMKeyEnumerator implements KEMKeyEnumerator for testing. +type mockKEMKeyEnumerator struct { + keys []kpskcc.KEMKeyInfo + err error +} + +func (m *mockKEMKeyEnumerator) EnumerateKEMKeys(limit, offset int) ([]kpskcc.KEMKeyInfo, error) { + return m.keys, m.err +} + func validGenerateBody() []byte { body, _ := json.Marshal(GenerateKemRequest{ Algorithm: KemAlgorithmDHKEMX25519HKDFSHA256, @@ -65,6 +77,7 @@ func TestHandleGenerateKemSuccess(t *testing.T) { srv := NewServer( &mockBindingKeyGen{uuid: bindingUUID, pubKey: bindingPubKey}, kemGen, + &mockKEMKeyEnumerator{}, ) req := httptest.NewRequest(http.MethodPost, "/v1/keys:generate_kem", bytes.NewReader(validGenerateBody())) @@ -113,6 +126,7 @@ func TestHandleGenerateKemInvalidMethod(t *testing.T) { srv := NewServer( &mockBindingKeyGen{pubKey: make([]byte, 32)}, &mockKEMKeyGen{pubKey: make([]byte, 32)}, + &mockKEMKeyEnumerator{}, ) req := httptest.NewRequest(http.MethodGet, "/v1/keys:generate_kem", nil) @@ -128,6 +142,7 @@ func TestHandleGenerateKemBadRequest(t *testing.T) { srv := NewServer( &mockBindingKeyGen{uuid: uuid.New(), pubKey: make([]byte, 32)}, &mockKEMKeyGen{uuid: uuid.New(), pubKey: make([]byte, 32)}, + &mockKEMKeyEnumerator{}, ) tests := []struct { @@ -182,6 +197,7 @@ func TestHandleGenerateKemBadJSON(t *testing.T) { srv := NewServer( &mockBindingKeyGen{pubKey: make([]byte, 32)}, &mockKEMKeyGen{pubKey: make([]byte, 32)}, + &mockKEMKeyEnumerator{}, ) badBodies := []struct { @@ -189,9 +205,9 @@ func TestHandleGenerateKemBadJSON(t *testing.T) { body string }{ {"not json", "not json"}, - {"lifespan as string", `{"algorithm":1,"key_protection_mechanism":2,"lifespan":"3600"}`}, - {"lifespan as string with suffix", `{"algorithm":1,"key_protection_mechanism":2,"lifespan":"3600s"}`}, - {"lifespan negative", `{"algorithm":1,"key_protection_mechanism":2,"lifespan":-1}`}, + {"lifespan as integer", `{"algorithm":1,"key_protection_mechanism":2,"lifespan":3600}`}, + {"lifespan missing s suffix", `{"algorithm":1,"key_protection_mechanism":2,"lifespan":"3600"}`}, + {"lifespan negative", `{"algorithm":1,"key_protection_mechanism":2,"lifespan":"-1s"}`}, } for _, tc := range badBodies { @@ -212,6 +228,7 @@ func TestHandleGenerateKemBindingGenError(t *testing.T) { srv := NewServer( &mockBindingKeyGen{err: fmt.Errorf("binding FFI error")}, &mockKEMKeyGen{pubKey: make([]byte, 32)}, + &mockKEMKeyEnumerator{}, ) req := httptest.NewRequest(http.MethodPost, "/v1/keys:generate_kem", bytes.NewReader(validGenerateBody())) @@ -228,6 +245,7 @@ func TestHandleGenerateKemFlexibleLifespan(t *testing.T) { srv := NewServer( &mockBindingKeyGen{uuid: uuid.New(), pubKey: make([]byte, 32)}, &mockKEMKeyGen{uuid: uuid.New(), pubKey: make([]byte, 32)}, + &mockKEMKeyEnumerator{}, ) tests := []struct { @@ -270,6 +288,7 @@ func TestHandleGenerateKemKEMGenError(t *testing.T) { srv := NewServer( &mockBindingKeyGen{uuid: uuid.New(), pubKey: make([]byte, 32)}, &mockKEMKeyGen{err: fmt.Errorf("KEM FFI error")}, + &mockKEMKeyEnumerator{}, ) req := httptest.NewRequest(http.MethodPost, "/v1/keys:generate_kem", bytes.NewReader(validGenerateBody())) @@ -282,6 +301,155 @@ func TestHandleGenerateKemKEMGenError(t *testing.T) { } } +func TestHandleEnumerateKeysEmpty(t *testing.T) { + srv := NewServer( + &mockBindingKeyGen{}, + &mockKEMKeyGen{}, + &mockKEMKeyEnumerator{keys: []kpskcc.KEMKeyInfo{}}, + ) + + req := httptest.NewRequest(http.MethodGet, "/v1/keys", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp EnumerateKeysResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if len(resp.KeyInfos) != 0 { + t.Fatalf("expected 0 key infos, got %d", len(resp.KeyInfos)) + } +} + +func TestHandleEnumerateKeysWithKeys(t *testing.T) { + kem1 := uuid.New() + kem2 := uuid.New() + kemPubKey1 := make([]byte, 32) + kemPubKey2 := make([]byte, 32) + // BindingPubKey no longer used in response + for i := range kemPubKey1 { + kemPubKey1[i] = byte(i) + kemPubKey2[i] = byte(i + 50) + } + + mockEnum := &mockKEMKeyEnumerator{ + keys: []kpskcc.KEMKeyInfo{ + { + ID: kem1, + Algorithm: &algorithms.HpkeAlgorithm{ + Kem: algorithms.KemAlgorithm_KEM_ALGORITHM_DHKEM_X25519_HKDF_SHA256, + Kdf: algorithms.KdfAlgorithm_KDF_ALGORITHM_HKDF_SHA256, + Aead: algorithms.AeadAlgorithm_AEAD_ALGORITHM_AES_256_GCM, + }, + KEMPubKey: kemPubKey1, + RemainingLifespanSecs: 3500, + }, + { + ID: kem2, + Algorithm: &algorithms.HpkeAlgorithm{ + Kem: algorithms.KemAlgorithm_KEM_ALGORITHM_DHKEM_X25519_HKDF_SHA256, + Kdf: algorithms.KdfAlgorithm_KDF_ALGORITHM_HKDF_SHA256, + Aead: algorithms.AeadAlgorithm_AEAD_ALGORITHM_AES_256_GCM, + }, + KEMPubKey: kemPubKey2, + RemainingLifespanSecs: 7100, + }, + }, + } + + srv := NewServer( + &mockBindingKeyGen{}, + &mockKEMKeyGen{}, + mockEnum, + ) + + req := httptest.NewRequest(http.MethodGet, "/v1/keys", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp EnumerateKeysResponse + if err := json.NewDecoder(w.Body).Decode(&resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if len(resp.KeyInfos) != 2 { + t.Fatalf("expected 2 key infos, got %d", len(resp.KeyInfos)) + } + + // Verify both keys appear (order-independent). + found := make(map[string]*BoundKEMInfo) + for _, ki := range resp.KeyInfos { + if ki.BoundKemInfo == nil { + t.Fatal("expected non-nil boundKemInfo") + } + found[ki.BoundKemInfo.KeyHandle.Handle] = ki.BoundKemInfo + } + + // Verify key 1. + info1, ok := found[kem1.String()] + if !ok { + t.Fatalf("expected kem1 %s in response", kem1) + } + if info1.KemPubKey.Algorithm != KemAlgorithmDHKEMX25519HKDFSHA256 { + t.Fatalf("expected algorithm %v, got %v", KemAlgorithmDHKEMX25519HKDFSHA256, info1.KemPubKey.Algorithm) + } + if info1.KemPubKey.PublicKey != base64.StdEncoding.EncodeToString(kemPubKey1) { + t.Fatalf("KEM pub key mismatch for kem1") + } + // BindingPubKey check removed + if info1.RemainingLifespan.Seconds != 3500 { + t.Fatalf("expected remaining lifespan 3500, got %d", info1.RemainingLifespan.Seconds) + } + + // Verify key 2. + info2, ok := found[kem2.String()] + if !ok { + t.Fatalf("expected kem2 %s in response", kem2) + } + if info2.RemainingLifespan.Seconds != 7100 { + t.Fatalf("expected remaining lifespan 7100, got %d", info2.RemainingLifespan.Seconds) + } +} + +func TestHandleEnumerateKeysMethodNotAllowed(t *testing.T) { + srv := NewServer( + &mockBindingKeyGen{}, + &mockKEMKeyGen{}, + &mockKEMKeyEnumerator{}, + ) + + req := httptest.NewRequest(http.MethodPost, "/v1/keys", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Fatalf("expected status 405, got %d", w.Code) + } +} + +func TestHandleEnumerateKeysError(t *testing.T) { + srv := NewServer( + &mockBindingKeyGen{}, + &mockKEMKeyGen{}, + &mockKEMKeyEnumerator{err: fmt.Errorf("enumerate error")}, + ) + + req := httptest.NewRequest(http.MethodGet, "/v1/keys", nil) + w := httptest.NewRecorder() + srv.Handler().ServeHTTP(w, req) + + if w.Code != http.StatusInternalServerError { + t.Fatalf("expected status 500, got %d: %s", w.Code, w.Body.String()) + } +} + func TestHandleGenerateKemMapUniqueness(t *testing.T) { bindingPubKey := make([]byte, 32) @@ -294,7 +462,7 @@ func TestHandleGenerateKemMapUniqueness(t *testing.T) { bindingGen := &mockBindingKeyGen{} kemGen := &mockKEMKeyGen{} - srv := NewServer(bindingGen, kemGen) + srv := NewServer(bindingGen, kemGen, &mockKEMKeyEnumerator{}) // First call. bindingGen.uuid = bindingUUID1