diff --git a/cli/planoai/config_generator.py b/cli/planoai/config_generator.py index 277685466..68f7524df 100644 --- a/cli/planoai/config_generator.py +++ b/cli/planoai/config_generator.py @@ -192,6 +192,7 @@ def validate_and_render_schema(): llms_with_usage = [] model_name_keys = set() model_usage_name_keys = set() + bedrock_providers_present = False print("listeners: ", listeners) @@ -240,6 +241,8 @@ def validate_and_render_schema(): f"Invalid model name {model_name}. Please provide model name in the format / or /* for wildcards." ) provider = model_name_tokens[0].strip() + if provider == "amazon_bedrock": + bedrock_providers_present = True # Check if this is a wildcard (provider/*) is_wildcard = model_name_tokens[-1].strip() == "*" @@ -436,6 +439,20 @@ def validate_and_render_schema(): f"Model alias 2 - '{alias_name}' targets '{target}' which is not defined as a model. Available models: {', '.join(sorted(model_name_keys))}" ) + aws_credentials_config = {} + for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]: + value = os.getenv(key) + if value: + aws_credentials_config[key] = value + + if aws_credentials_config: + config_yaml["aws_credentials"] = aws_credentials_config + if not bedrock_providers_present: + print( + "WARNING: AWS credentials were detected but no amazon_bedrock model_providers were found. " + "These credentials will be unused unless Bedrock providers are configured." + ) + arch_config_string = yaml.dump(config_yaml) arch_llm_config_string = yaml.dump(config_yaml) diff --git a/config/arch_config_schema.yaml b/config/arch_config_schema.yaml index 0f3cefb73..d90ed606c 100644 --- a/config/arch_config_schema.yaml +++ b/config/arch_config_schema.yaml @@ -128,6 +128,16 @@ properties: timeout: type: string additionalProperties: false + aws_credentials: + type: object + properties: + AWS_ACCESS_KEY_ID: + type: string + AWS_SECRET_ACCESS_KEY: + type: string + AWS_SESSION_TOKEN: + type: string + additionalProperties: false endpoints: type: object patternProperties: diff --git a/crates/Cargo.lock b/crates/Cargo.lock index f2744ad28..5b51d21fc 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -138,6 +138,51 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "aws-credential-types" +version = "1.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faf26925f4a5b59eb76722b63c2892b1d70d06fa053c72e4a100ec308c1d47bc" +dependencies = [ + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "zeroize", +] + +[[package]] +name = "aws-sigv4" +version = "1.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bffc03068fbb9c8dd5ce1c6fb240678a5cffb86fb2b7b1985c999c4b83c8df68" +dependencies = [ + "aws-credential-types", + "aws-smithy-http", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "form_urlencoded", + "hex", + "hmac", + "http 0.2.12", + "http 1.3.1", + "percent-encoding", + "sha2", + "time", + "tracing", +] + +[[package]] +name = "aws-smithy-async" +version = "1.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "127fcfad33b7dfc531141fda7e1c402ac65f88aca5511a4d31e2e3d2cd01ce9c" +dependencies = [ + "futures-util", + "pin-project-lite", + "tokio", +] + [[package]] name = "aws-smithy-eventstream" version = "0.60.12" @@ -149,6 +194,43 @@ dependencies = [ "crc32fast", ] +[[package]] +name = "aws-smithy-http" +version = "0.62.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3feafd437c763db26aa04e0cc7591185d0961e64c61885bece0fb9d50ceac671" +dependencies = [ + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "bytes-utils", + "futures-core", + "http 0.2.12", + "http 1.3.1", + "http-body 0.4.6", + "percent-encoding", + "pin-project-lite", + "pin-utils", + "tracing", +] + +[[package]] +name = "aws-smithy-runtime-api" +version = "1.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3683c5b152d2ad753607179ed71988e8cfd52964443b4f74fd8e552d0bbfeb46" +dependencies = [ + "aws-smithy-async", + "aws-smithy-types", + "bytes", + "http 0.2.12", + "http 1.3.1", + "pin-project-lite", + "tokio", + "tracing", + "zeroize", +] + [[package]] name = "aws-smithy-types" version = "1.3.3" @@ -158,6 +240,8 @@ dependencies = [ "base64-simd", "bytes", "bytes-utils", + "http 0.2.12", + "http-body 0.4.6", "itoa", "num-integer", "pin-project-lite", @@ -167,6 +251,20 @@ dependencies = [ "time", ] +[[package]] +name = "aws-types" +version = "1.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2fd329bf0e901ff3f60425691410c69094dc2a1f34b331f37bfc4e9ac1565a1" +dependencies = [ + "aws-credential-types", + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "rustc_version", + "tracing", +] + [[package]] name = "axum" version = "0.7.9" @@ -435,12 +533,18 @@ dependencies = [ name = "common" version = "0.1.0" dependencies = [ + "aws-credential-types", + "aws-sigv4", + "aws-smithy-runtime-api", + "aws-types", "axum", + "bytes", "derivative", "duration-string", "governor", "hermesllm", "hex", + "http 1.3.1", "log", "pretty_assertions", "proxy-wasm", @@ -2400,6 +2504,15 @@ version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "1.0.7" @@ -2614,6 +2727,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + [[package]] name = "serde" version = "1.0.219" diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index cb471bd6a..816d8599e 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -20,9 +20,24 @@ urlencoding = "2.1.3" url = "2.5.4" hermesllm = { version = "0.1.0", path = "../hermesllm" } serde_with = "3.13.0" +aws-sigv4 = { version = "1", optional = true } +aws-credential-types = { version = "1", optional = true } +aws-types = { version = "1", optional = true } +http = { version = "1", optional = true } +bytes = { version = "1", optional = true } +aws-smithy-runtime-api = { version = "1.9", optional = true } [features] default = [] +trace-collection = ["tokio", "reqwest", "tracing"] +aws-sigv4 = [ + "dep:aws-sigv4", + "dep:aws-credential-types", + "dep:aws-types", + "dep:http", + "dep:bytes", + "dep:aws-smithy-runtime-api", +] [dev-dependencies] pretty_assertions = "1.4.1" diff --git a/crates/common/src/aws_credentials.rs b/crates/common/src/aws_credentials.rs new file mode 100644 index 000000000..3baa8cceb --- /dev/null +++ b/crates/common/src/aws_credentials.rs @@ -0,0 +1,24 @@ +use crate::configuration::AwsCredentialsConfig; +use crate::errors::AwsError; + +pub fn get_credentials_from_config( + config: &AwsCredentialsConfig, +) -> Result<(String, String, Option), AwsError> { + let access_key_id = config + .access_key_id + .as_ref() + .ok_or_else(|| AwsError::CredentialError("AWS_ACCESS_KEY_ID not found".to_string()))? + .clone(); + + let secret_access_key = config + .secret_access_key + .as_ref() + .ok_or_else(|| AwsError::CredentialError("AWS_SECRET_ACCESS_KEY not found".to_string()))? + .clone(); + + Ok(( + access_key_id, + secret_access_key, + config.session_token.clone(), + )) +} diff --git a/crates/common/src/aws_sigv4.rs b/crates/common/src/aws_sigv4.rs new file mode 100644 index 000000000..3a421f6ff --- /dev/null +++ b/crates/common/src/aws_sigv4.rs @@ -0,0 +1,139 @@ +use crate::errors::AwsError; +use std::collections::BTreeMap; + +pub struct SigV4Params { + pub access_key_id: String, + pub secret_access_key: String, + pub session_token: Option, + pub region: String, + pub service: String, + pub method: String, + pub uri: String, + pub query_string: String, + pub headers: BTreeMap, + pub payload: Vec, +} + +#[cfg(feature = "aws-sigv4")] +pub fn sign_request(params: SigV4Params) -> Result<(String, String), AwsError> { + use aws_credential_types::Credentials; + use aws_sigv4::http_request::{sign, SignableBody, SignableRequest, SigningSettings}; + use aws_sigv4::sign::v4; + use aws_smithy_runtime_api::client::identity::Identity; + use std::time::SystemTime; + + let credentials = Credentials::new( + ¶ms.access_key_id, + ¶ms.secret_access_key, + params.session_token.clone(), + None, + "plano", + ); + + let settings = SigningSettings::default(); + let identity: Identity = credentials.into(); + + let signing_params = v4::SigningParams::builder() + .identity(&identity) + .region(¶ms.region) + .name(¶ms.service) + .time(SystemTime::now()) + .settings(settings) + .build() + .map_err(|e| AwsError::SigningError(format!("Failed to build signing params: {}", e)))?; + + let host = params.headers.get("host").cloned().unwrap_or_default(); + let url = if params.query_string.is_empty() { + format!("https://{}{}", host, params.uri) + } else { + format!("https://{}{}?{}", host, params.uri, params.query_string) + }; + + let header_pairs: Vec<(String, String)> = params + .headers + .iter() + .filter(|(k, _)| { + let k = k.as_str(); + k != "host" && k != "x-amz-date" && k != "x-amz-security-token" + }) + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + + let signable_body = SignableBody::Bytes(¶ms.payload); + + let signable_request = SignableRequest::new( + ¶ms.method, + &url, + header_pairs.iter().map(|(k, v)| (k.as_str(), v.as_str())), + signable_body, + ) + .map_err(|e| AwsError::SigningError(format!("Failed to create signable request: {}", e)))?; + + let signing_output = sign(signable_request, &signing_params.into()) + .map_err(|e| AwsError::SigningError(format!("Failed to sign request: {}", e)))?; + + let mut authorization = String::new(); + let mut amz_date = String::new(); + let (instructions, _) = signing_output.into_parts(); + for (name, value) in instructions.headers() { + match name { + "authorization" => authorization = value.to_string(), + "x-amz-date" => amz_date = value.to_string(), + _ => {} + } + } + + if authorization.is_empty() { + return Err(AwsError::SigningError( + "Authorization header not produced by signing".to_string(), + )); + } + + Ok((authorization, amz_date)) +} + +#[cfg(not(feature = "aws-sigv4"))] +pub fn sign_request(_params: SigV4Params) -> Result<(String, String), AwsError> { + Err(AwsError::SigningError( + "aws-signing feature not enabled".to_string(), + )) +} + +#[cfg(all(test, feature = "aws-sigv4"))] +mod tests { + use super::*; + use std::collections::BTreeMap; + + #[test] + fn test_sign_request_produces_authorization() { + let mut headers = BTreeMap::new(); + headers.insert( + "host".to_string(), + "bedrock-runtime.us-east-1.amazonaws.com".to_string(), + ); + headers.insert("content-type".to_string(), "application/json".to_string()); + + let params = SigV4Params { + access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(), + secret_access_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(), + session_token: Some("test-session-token".to_string()), + region: "us-east-1".to_string(), + service: "bedrock-runtime".to_string(), + method: "POST".to_string(), + uri: "/model/test-model/converse".to_string(), + query_string: String::new(), + headers, + payload: b"{}".to_vec(), + }; + + let result = sign_request(params); + assert!(result.is_ok(), "sign_request should succeed"); + + let (authorization, amz_date) = result.unwrap(); + assert!( + authorization.starts_with("AWS4-HMAC-SHA256"), + "Should use AWS4-HMAC-SHA256 algorithm" + ); + assert!(!amz_date.is_empty(), "amz_date should not be empty"); + } +} diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index ccca89c38..ee5a79a27 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -58,6 +58,16 @@ pub enum StateStorageType { Postgres, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AwsCredentialsConfig { + #[serde(rename = "AWS_ACCESS_KEY_ID")] + pub access_key_id: Option, + #[serde(rename = "AWS_SECRET_ACCESS_KEY")] + pub secret_access_key: Option, + #[serde(rename = "AWS_SESSION_TOKEN")] + pub session_token: Option, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Configuration { pub version: String, @@ -77,6 +87,8 @@ pub struct Configuration { pub filters: Option>, pub listeners: Vec, pub state_storage: Option, + #[serde(rename = "aws_credentials")] + pub aws_credentials: Option, } #[derive(Debug, Clone, Serialize, Deserialize, Default)] diff --git a/crates/common/src/errors.rs b/crates/common/src/errors.rs index 21af3c94f..352bbb547 100644 --- a/crates/common/src/errors.rs +++ b/crates/common/src/errors.rs @@ -43,3 +43,15 @@ pub enum ServerError { #[error("error parsing openai message: {0}")] OpenAIPError(#[from] OpenAIError), } + +#[derive(thiserror::Error, Debug)] +pub enum AwsError { + #[error("Failed to get credentials from environment: {0}")] + CredentialError(String), + #[error("STS AssumeRole failed: {0}")] + StsError(String), + #[error("AWS Signature V4 signing failed: {0}")] + SigningError(String), + #[error("Invalid AWS configuration: {0}")] + ConfigError(String), +} diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index aba27b9b2..228befd53 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -1,4 +1,6 @@ pub mod api; +pub mod aws_credentials; +pub mod aws_sigv4; pub mod configuration; pub mod consts; pub mod errors; diff --git a/crates/llm_gateway/Cargo.toml b/crates/llm_gateway/Cargo.toml index 281e05beb..7c427ecfb 100644 --- a/crates/llm_gateway/Cargo.toml +++ b/crates/llm_gateway/Cargo.toml @@ -27,3 +27,7 @@ bytes = "1.10" [dev-dependencies] serial_test = "3.1.1" + +[features] +default = [] +aws-sigv4 = ["common/aws-sigv4"] diff --git a/crates/llm_gateway/src/filter_context.rs b/crates/llm_gateway/src/filter_context.rs index 4bcd9955a..c385616b4 100644 --- a/crates/llm_gateway/src/filter_context.rs +++ b/crates/llm_gateway/src/filter_context.rs @@ -1,7 +1,6 @@ use crate::metrics::Metrics; use crate::stream_context::StreamContext; -use common::configuration::Configuration; -use common::configuration::Overrides; +use common::configuration::{AwsCredentialsConfig, Configuration, Overrides}; use common::http::Client; use common::llm_providers::LlmProviders; use common::ratelimit; @@ -24,6 +23,7 @@ pub struct FilterContext { callouts: RefCell>, llm_providers: Option>, overrides: Rc>, + aws_credentials: Rc>, } impl FilterContext { @@ -33,6 +33,7 @@ impl FilterContext { metrics: Rc::new(Metrics::new()), llm_providers: None, overrides: Rc::new(None), + aws_credentials: Rc::new(None), } } } @@ -63,6 +64,7 @@ impl RootContext for FilterContext { ratelimit::ratelimits(Some(config.ratelimits.unwrap_or_default())); self.overrides = Rc::new(config.overrides); + self.aws_credentials = Rc::new(config.aws_credentials); match config.model_providers.try_into() { Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)), @@ -86,6 +88,7 @@ impl RootContext for FilterContext { .expect("LLM Providers must exist when Streams are being created"), ), Rc::clone(&self.overrides), + Rc::clone(&self.aws_credentials), ))) } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index d31162333..40b2377dd 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -10,7 +10,9 @@ use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use crate::metrics::Metrics; -use common::configuration::{LlmProvider, LlmProviderType, Overrides}; +use common::aws_credentials::get_credentials_from_config; +use common::aws_sigv4::{sign_request, SigV4Params}; +use common::configuration::{AwsCredentialsConfig, LlmProvider, LlmProviderType, Overrides}; use common::consts::{ ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, HEALTHZ_PATH, RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, TRACE_PARENT_HEADER, @@ -49,6 +51,7 @@ pub struct StreamContext { traceparent: Option, request_body_sent_time: Option, _overrides: Rc>, + aws_credentials: Rc>, user_message: Option, upstream_status_code: Option, binary_frame_decoder: Option>, @@ -56,6 +59,9 @@ pub struct StreamContext { http_protocol: Option, sse_buffer: Option, sse_chunk_processor: Option, + needs_bedrock_sigv4: bool, + bedrock_sigv4_signed: bool, + bedrock_signed_headers: Option<(String, String, Option)>, } impl StreamContext { @@ -63,10 +69,12 @@ impl StreamContext { metrics: Rc, llm_providers: Rc, overrides: Rc>, + aws_credentials: Rc>, ) -> Self { StreamContext { metrics, _overrides: overrides, + aws_credentials, ratelimit_selector: None, streaming_response: false, response_tokens: 0, @@ -87,6 +95,9 @@ impl StreamContext { http_protocol: None, sse_buffer: None, sse_chunk_processor: None, + needs_bedrock_sigv4: false, + bedrock_sigv4_signed: false, + bedrock_signed_headers: None, } } @@ -176,6 +187,8 @@ impl StreamContext { } fn modify_auth_headers(&mut self) -> Result<(), ServerError> { + self.needs_bedrock_sigv4 = false; + if self.llm_provider().passthrough_auth == Some(true) { // Check if client provided an Authorization header if self.get_http_request_header("Authorization").is_none() { @@ -193,6 +206,27 @@ impl StreamContext { return Ok(()); } + if matches!( + self.resolved_api.as_ref(), + Some( + SupportedUpstreamAPIs::AmazonBedrockConverse(_) + | SupportedUpstreamAPIs::AmazonBedrockConverseStream(_) + ) + ) && self.aws_credentials.is_some() + { + if cfg!(feature = "aws-sigv4") { + self.needs_bedrock_sigv4 = true; + self.remove_http_request_header("Authorization"); + self.remove_http_request_header("x-api-key"); + return Ok(()); + } + + warn!( + "[PLANO_REQ_ID:{}] BEDROCK_SIGV4_DISABLED: aws_credentials configured but aws-sigv4 feature is off; falling back to bearer auth", + self.request_identifier() + ); + } + let llm_provider_api_key_value = self.llm_provider() .access_key @@ -231,6 +265,87 @@ impl StreamContext { Ok(()) } + fn extract_host_for_sigv4(&self) -> Result { + if let Some(endpoint) = self.llm_provider().endpoint.as_ref() { + let mut host = endpoint.as_str(); + if let Some(stripped) = host.strip_prefix("https://") { + host = stripped; + } else if let Some(stripped) = host.strip_prefix("http://") { + host = stripped; + } + + let host = host.split('/').next().unwrap_or(host); + return Ok(host.to_string()); + } + + self.get_http_request_header(":authority") + .ok_or_else(|| ServerError::BadRequest { + why: "Unable to determine host for SigV4 signing".to_string(), + }) + } + + fn sign_bedrock_request_with_sigv4(&mut self, payload: &[u8]) -> Result<(String, String, Option), ServerError> { + let creds_config = + self.aws_credentials + .as_ref() + .as_ref() + .ok_or_else(|| ServerError::BadRequest { + why: "AWS credentials not configured".to_string(), + })?; + + let (access_key_id, secret_access_key, session_token) = + get_credentials_from_config(creds_config).map_err(|e| ServerError::BadRequest { + why: format!("Failed to load AWS credentials: {}", e), + })?; + + let method = self + .get_http_request_header(":method") + .unwrap_or_else(|| "POST".to_string()); + let path = self + .get_http_request_header(":path") + .unwrap_or_else(|| "/".to_string()); + + let (uri, query_string) = if let Some(q_pos) = path.find('?') { + (path[..q_pos].to_string(), path[q_pos + 1..].to_string()) + } else { + (path, String::new()) + }; + + let host = self.extract_host_for_sigv4()?; + + let mut headers = std::collections::BTreeMap::new(); + headers.insert("host".to_string(), host); + if let Some(ct) = self.get_http_request_header("content-type") { + headers.insert("content-type".to_string(), ct); + } + + let region = + extract_region_from_host(self.llm_provider().endpoint.as_deref().unwrap_or("")) + .or_else(|| { + self.get_http_request_header(":authority") + .and_then(|h| extract_region_from_host(&h)) + }) + .unwrap_or_else(|| "us-east-1".to_string()); + + let (authorization, amz_date) = sign_request(SigV4Params { + access_key_id, + secret_access_key, + session_token: session_token.clone(), + region, + service: "bedrock-runtime".to_string(), + method, + uri, + query_string, + headers, + payload: payload.to_vec(), + }) + .map_err(|e| ServerError::BadRequest { + why: format!("Failed to sign AWS request: {}", e), + })?; + + Ok((authorization, amz_date, session_token)) + } + fn delete_content_length_header(&mut self) { // Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it. // Server's generally throw away requests whose body length do not match the Content-Length header. @@ -797,6 +912,26 @@ impl StreamContext { } } +fn extract_region_from_host(host: &str) -> Option { + let host = host + .trim() + .trim_start_matches("https://") + .trim_start_matches("http://") + .trim_end_matches('/'); + + let host = host.split('/').next().unwrap_or(host); + if let Some(domain) = host.strip_prefix("bedrock-runtime.") { + if let Some(region) = domain.strip_suffix(".amazonaws.com.cn") { + return Some(region.to_string()); + } + if let Some(region) = domain.strip_suffix(".amazonaws.com") { + return Some(region.to_string()); + } + } + + None +} + // HttpContext is the trait that allows the Rust code to interact with HTTP objects. impl HttpContext for StreamContext { // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto @@ -886,6 +1021,15 @@ impl HttpContext for StreamContext { self.request_id = self.get_http_request_header(REQUEST_ID_HEADER); self.traceparent = self.get_http_request_header(TRACE_PARENT_HEADER); + // Pause request if we need to compute SigV4 signature (requires body) + if self.needs_bedrock_sigv4 { + debug!( + "[PLANO_REQ_ID:{}] BEDROCK_SIGV4_PAUSE: Pausing in headers to wait for body", + self.request_identifier() + ); + return Action::Pause; + } + Action::Continue } @@ -1099,6 +1243,37 @@ impl HttpContext for StreamContext { } }; + if self.needs_bedrock_sigv4 && !self.bedrock_sigv4_signed { + match self.sign_bedrock_request_with_sigv4(&serialized_body_bytes_upstream) { + Ok((auth, date, token)) => { + self.bedrock_signed_headers = Some((auth, date, token)); + self.bedrock_sigv4_signed = true; + debug!( + "[PLANO_REQ_ID:{}] BEDROCK_SIGV4_SIGNED: Headers computed, will set before resume", + self.request_identifier() + ); + } + Err(e) => { + self.send_server_error(e, Some(StatusCode::BAD_REQUEST)); + return Action::Pause; + } + } + return Action::Pause; + } + + if self.needs_bedrock_sigv4 && self.bedrock_sigv4_signed { + if let Some((ref auth, ref date, ref token)) = self.bedrock_signed_headers { + self.set_http_request_header("Authorization", Some(auth)); + self.set_http_request_header("x-amz-date", Some(date)); + if let Some(ref t) = token { + self.set_http_request_header("x-amz-security-token", Some(t)); + } + } + self.set_http_request_body(0, body_size, &serialized_body_bytes_upstream); + self.resume_http_request(); + return Action::Continue; + } + self.set_http_request_body(0, body_size, &serialized_body_bytes_upstream); Action::Continue }