Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions crate/cli/src/actions/kms/azure/byok/import_kek.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::path::PathBuf;

use clap::Parser;
use cosmian_kmip::kmip_2_1::kmip_types::UniqueIdentifier;
use cosmian_kms_client::{
KmsClient,
reexport::cosmian_kms_client_utils::import_utils::{ImportKeyFormat, KeyUsage},
Expand Down Expand Up @@ -30,7 +31,7 @@ pub struct ImportKekAction {
}

impl ImportKekAction {
pub async fn run(&self, kms_client: KmsClient) -> KmsCliResult<()> {
pub async fn run(&self, kms_client: KmsClient) -> KmsCliResult<UniqueIdentifier> {
let import_action = ImportSecretDataOrKeyAction {
key_file: self.kek_file.clone(),
key_id: self.key_id.clone(),
Expand All @@ -45,6 +46,6 @@ impl ImportKekAction {
wrapping_key_id: None,
};

import_action.run(kms_client).await.map(|_| ())
import_action.run(kms_client).await
}
}
12 changes: 7 additions & 5 deletions crate/cli/src/actions/kms/azure/byok/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ mod import_kek;

use clap::Subcommand;
use cosmian_kms_client::KmsClient;
pub(crate) use export_byok::ExportByokAction;
pub(crate) use import_kek::ImportKekAction;

use crate::{
actions::kms::azure::byok::{export_byok::ExportByokAction, import_kek::ImportKekAction},
error::result::KmsCliResult,
};
use crate::error::result::KmsCliResult;

/// Azure BYOK support.
/// See: <https://learn.microsoft.com/en-us/azure/key-vault/keys/byok-specification>
Expand All @@ -20,7 +19,10 @@ pub enum ByokCommands {
impl ByokCommands {
pub async fn process(&self, kms_rest_client: KmsClient) -> KmsCliResult<()> {
match self {
Self::Import(action) => action.run(kms_rest_client).await,
Self::Import(action) => {
action.run(kms_rest_client).await?;
Ok(())
}
Self::Export(action) => action.run(kms_rest_client).await,
}
}
Expand Down
2 changes: 1 addition & 1 deletion crate/cli/src/actions/kms/azure/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
mod byok;
pub(crate) mod byok;

use clap::Parser;
use cosmian_kms_client::KmsClient;
Expand Down
108 changes: 108 additions & 0 deletions crate/cli/src/tests/kms/azure/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
use std::fs;

use openssl::{
pkey::{PKey, Private, Public},
rsa::Rsa,
};
use tempfile::TempDir;
use test_kms_server::start_default_test_kms_server;

use crate::{
actions::kms::{
azure::byok::{ExportByokAction, ImportKekAction},
symmetric::keys::create_key::CreateKeyAction,
},
error::{KmsCliError, result::KmsCliResult},
};

/// Generate RSA keypair using OpenSSL (random size from 2048, 3072, or 4096 bits).
///
/// This mirrors AWS KMS "get-parameters-for-import" wrapping key specs and keeps
/// the test independent from KMS RSA key generation/export actions.
fn generate_rsa_keypair() -> KmsCliResult<(PKey<Private>, PKey<Public>)> {
let key_sizes = [2048_u32, 3072_u32, 4096_u32];
// Avoid introducing new RNG deps in the CLI crate's dev-deps.
let bits = key_sizes[std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| {
let len_u32 = u32::try_from(key_sizes.len()).unwrap_or(1);
let idx_u32 = d.subsec_nanos() % len_u32;
usize::try_from(idx_u32).unwrap_or(0)
})
.unwrap_or(0)];

let rsa = Rsa::generate(bits)
.map_err(|e| KmsCliError::Default(format!("Failed to generate RSA key: {e}")))?;
let private_key = PKey::from_rsa(rsa.clone())
.map_err(|e| KmsCliError::Default(format!("Failed to build private key: {e}")))?;
let public_key = PKey::from_rsa(
Rsa::from_public_components(
rsa.n()
.to_owned()
.map_err(|e| KmsCliError::Default(format!("Failed to clone modulus: {e}")))?,
rsa.e()
.to_owned()
.map_err(|e| KmsCliError::Default(format!("Failed to clone exponent: {e}")))?,
)
.map_err(|e| KmsCliError::Default(format!("Failed to build public RSA key: {e}")))?,
)
.map_err(|e| KmsCliError::Default(format!("Failed to build public key: {e}")))?;

Ok((private_key, public_key))
}

#[tokio::test]
async fn test_azure_byok_import_kek_then_export_byok() -> KmsCliResult<()> {
// 1. Instantiate a default KMS server
let ctx = start_default_test_kms_server().await;
let kms_client = ctx.get_owner_client();

let tmp_dir = TempDir::new()?;
let kek_pem_path = tmp_dir.path().join("kek_pub.pem");

// 2. Generate an RSA key pair locally, write the public key in PKCS#8 PEM, then import it as Azure KEK
let (_private_key, public_key) = generate_rsa_keypair()?;
let public_key_pem = public_key
.public_key_to_pem()
.map_err(|e| KmsCliError::Default(format!("Failed to serialize public key PEM: {e}")))?;
fs::write(&kek_pem_path, &public_key_pem)?;

let kid = "https://unit.test/keys/KEK/00000000000000000000000000000000".to_owned();
let imported_kek_id = ImportKekAction {
kek_file: kek_pem_path,
kid: kid.clone(),
key_id: None,
}
.run(kms_client.clone())
.await?;

// The import action writes to stdout and does not return the imported id; locate it via tag.
// Tag is `kid:<kid>`.

// 3. Generate a symmetric key and run ExportByokAction using it as wrapped_key_id
let sym_key_id = CreateKeyAction {
number_of_bits: Some(256),
tags: vec!["test".to_owned()],
..CreateKeyAction::default()
}
.run(kms_client.clone())
.await?
.to_string();

let byok_file = tmp_dir.path().join("out.byok");

ExportByokAction {
wrapped_key_id: sym_key_id,
kek_id: imported_kek_id.to_string(),
byok_file: Some(byok_file.clone()),
}
.run(kms_client)
.await?;

// Assert byok file written
let contents = std::fs::read_to_string(&byok_file)?;
assert!(contents.contains("\"ciphertext\""));
assert!(contents.contains("\"kid\""));

Ok(())
}
1 change: 1 addition & 0 deletions crate/cli/src/tests/kms/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod access;
mod attributes;
mod auth_tests;
mod azure;
mod certificates;
#[cfg(feature = "non-fips")]
mod cover_crypt;
Expand Down
26 changes: 26 additions & 0 deletions crate/wasm/src/wasm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1670,6 +1670,32 @@ pub fn get_attributes_ttlv_request(unique_identifier: String) -> Result<JsValue,
serde_wasm_bindgen::to_value(&objects).map_err(|e| JsValue::from(e.to_string()))
}

/// Same as `get_attributes_ttlv_request`, but can force requesting tags.
///
/// Some callers (notably UI/WASM) rely on tags being returned, but the server may not include
/// `Tag::Tag` unless explicitly requested.
#[wasm_bindgen]
pub fn get_attributes_ttlv_request_with_options(
unique_identifier: String,
force_tags: bool,
) -> Result<JsValue, JsValue> {
let unique_identifier = UniqueIdentifier::TextString(unique_identifier);

let attribute_reference = if force_tags {
Some(vec![AttributeReference::Standard(Tag::Tag)])
} else {
None
};

let request = GetAttributes {
unique_identifier: Some(unique_identifier),
attribute_reference,
};

let objects = to_ttlv(&request).map_err(|e| JsValue::from(e.to_string()))?;
serde_wasm_bindgen::to_value(&objects).map_err(|e| JsValue::from(e.to_string()))
}

#[allow(clippy::needless_pass_by_value)]
#[wasm_bindgen]
pub fn parse_get_attributes_ttlv_response(
Expand Down
4 changes: 2 additions & 2 deletions ui/src/AzureExportByok.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import {useAuth} from "./AuthContext";
import {downloadFile, sendKmipRequest} from "./utils";
import {
export_ttlv_request,
get_attributes_ttlv_request,
get_attributes_ttlv_request_with_options,
parse_export_ttlv_response,
parse_get_attributes_ttlv_response
} from "./wasm/pkg";
Expand Down Expand Up @@ -59,7 +59,7 @@ const ExportAzureBYOKForm: React.FC = () => {
setRes(undefined);
try {
// Step 1: Get the KEK attributes to retrieve the Azure kid
const getAttrsRequest = get_attributes_ttlv_request(values.kekId);
const getAttrsRequest = get_attributes_ttlv_request_with_options(values.kekId, true);
const attrsResultStr = await sendKmipRequest(getAttrsRequest, idToken, serverUrl);

if (!attrsResultStr) {
Expand Down
Loading