Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ pub struct Config {
/// Policy configuration
pub policy: PolicyConfig,

/// SSRF protection configuration
pub ssrf: SsrfConfig,

/// Identity registry configuration
pub identity: IdentityConfig,

Expand All @@ -32,6 +35,18 @@ pub struct Config {
pub logging: LoggingConfig,
}

/// SSRF protection configuration
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct SsrfConfig {
/// Allowed endpoints that bypass SSRF protection (host:port format)
/// Example: ["172.30.192.1:11434", "127.0.0.1:9200"]
pub allowed_endpoints: Vec<String>,

/// Disable SSRF protection entirely (not recommended)
pub disabled: bool,
}

/// Server configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
Expand Down Expand Up @@ -72,6 +87,14 @@ pub struct PolicyConfig {

/// Hot-reload check interval in seconds
pub hot_reload_interval_s: u64,

/// Secret required for /policy/reload endpoint (bearer token)
/// If set, requests must include `Authorization: Bearer <secret>`
pub reload_secret: Option<String>,

/// Disable the /policy/reload endpoint entirely
/// If true, the endpoint returns 404
pub disable_reload: bool,
}

/// Identity registry configuration
Expand Down Expand Up @@ -361,6 +384,14 @@ shutdown_timeout_s = 30
# file = "/path/to/policy.json"
hot_reload = false
hot_reload_interval_s = 30
# reload_secret = "your-secret-here" # Require bearer token for /policy/reload
# disable_reload = false # Set to true to disable /policy/reload entirely

# SSRF Protection Configuration
# By default, SSRF blocks private IPs, localhost, and cloud metadata endpoints
[ssrf]
# allowed_endpoints = ["172.30.192.1:11434", "127.0.0.1:9200"] # Bypass SSRF for these
# disabled = false # Set to true to disable all SSRF protection (not recommended)

[identity]
# file = "~/.predicate/local-identity-registry.json"
Expand Down
211 changes: 203 additions & 8 deletions src/http/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ pub struct AppState {
pub start_time: std::time::Instant,
pub mode: String,
pub identity_mode: String,
/// Secret required for /policy/reload endpoint (if set)
pub policy_reload_secret: Option<String>,
/// Whether /policy/reload endpoint is disabled
pub policy_reload_disabled: bool,
}

impl AppState {
Expand All @@ -55,9 +59,21 @@ impl AppState {
start_time: std::time::Instant::now(),
mode: mode.to_string(),
identity_mode: "local".to_string(),
policy_reload_secret: None,
policy_reload_disabled: false,
}
}

pub fn with_policy_reload_secret(mut self, secret: Option<String>) -> Self {
self.policy_reload_secret = secret;
self
}

pub fn with_policy_reload_disabled(mut self, disabled: bool) -> Self {
self.policy_reload_disabled = disabled;
self
}

pub fn with_identity_registry(mut self, registry: LocalIdentityRegistry) -> Self {
self.identity_registry = Some(Arc::new(registry));
self
Expand Down Expand Up @@ -105,13 +121,16 @@ pub fn create_router(state: AppState) -> Router {
router = router.route("/v1/execute", post(execute_handler));
}

// Conditionally add policy reload endpoint (unless disabled)
if !state.policy_reload_disabled {
router = router.route("/policy/reload", post(policy_reload_handler));
}

router
// Operations
.route("/health", get(health_handler))
.route("/status", get(status_handler))
.route("/metrics", get(metrics_handler))
// Policy management
.route("/policy/reload", post(policy_reload_handler))
// Identity management
.route("/identity/task", post(identity_task_handler))
.route("/identity/revoke", post(identity_revoke_handler))
Expand Down Expand Up @@ -523,18 +542,54 @@ struct PolicyReloadResponse {

async fn policy_reload_handler(
State(state): State<AppState>,
headers: HeaderMap,
Json(request): Json<PolicyReloadRequest>,
) -> Json<PolicyReloadResponse> {
) -> impl IntoResponse {
// Check authentication if a reload secret is configured
if let Some(ref expected_secret) = state.policy_reload_secret {
let provided_token = extract_bearer_token(&headers);
match provided_token {
Some(token) if token == *expected_secret => {
// Authentication successful
}
Some(_) => {
warn!("Policy reload rejected: invalid bearer token");
return (
StatusCode::UNAUTHORIZED,
Json(PolicyReloadResponse {
success: false,
rule_count: 0,
message: "Invalid bearer token".to_string(),
}),
);
}
None => {
warn!("Policy reload rejected: missing bearer token");
return (
StatusCode::UNAUTHORIZED,
Json(PolicyReloadResponse {
success: false,
rule_count: 0,
message: "Authorization required: Bearer <token>".to_string(),
}),
);
}
}
}

let rule_count = request.rules.len();

info!("Reloading policy with {} rules", rule_count);
state.policy_engine.replace_rules(request.rules);

Json(PolicyReloadResponse {
success: true,
rule_count,
message: format!("Loaded {} rules", rule_count),
})
(
StatusCode::OK,
Json(PolicyReloadResponse {
success: true,
rule_count,
message: format!("Loaded {} rules", rule_count),
}),
)
}

// --- Identity Management ---
Expand Down Expand Up @@ -1299,4 +1354,144 @@ mod tests {
assert!(response_json["allowed"].as_bool().unwrap());
assert!(response_json["mandate_token"].is_string());
}

// --- Policy reload authentication tests (Issue #26) ---

fn test_state_with_policy_reload_secret() -> AppState {
AppState::new(PolicyEngine::new(), "test")
.with_policy_reload_secret(Some("test-reload-secret".to_string()))
}

fn test_state_with_policy_reload_disabled() -> AppState {
AppState::new(PolicyEngine::new(), "test").with_policy_reload_disabled(true)
}

#[tokio::test]
async fn test_policy_reload_requires_auth_when_secret_set() {
let app = create_router(test_state_with_policy_reload_secret());

// Without Authorization header - should fail
let body = r#"{"rules": []}"#;
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/policy/reload")
.header("content-type", "application/json")
.body(Body::from(body))
.unwrap(),
)
.await
.unwrap();

assert_eq!(response.status(), StatusCode::UNAUTHORIZED);

let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert!(!response_json["success"].as_bool().unwrap());
assert!(response_json["message"]
.as_str()
.unwrap()
.contains("Authorization required"));
}

#[tokio::test]
async fn test_policy_reload_rejects_invalid_token() {
let app = create_router(test_state_with_policy_reload_secret());

let body = r#"{"rules": []}"#;
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/policy/reload")
.header("content-type", "application/json")
.header("authorization", "Bearer wrong-secret")
.body(Body::from(body))
.unwrap(),
)
.await
.unwrap();

assert_eq!(response.status(), StatusCode::UNAUTHORIZED);

let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert!(!response_json["success"].as_bool().unwrap());
assert!(response_json["message"]
.as_str()
.unwrap()
.contains("Invalid bearer token"));
}

#[tokio::test]
async fn test_policy_reload_succeeds_with_valid_token() {
let app = create_router(test_state_with_policy_reload_secret());

let body = r#"{"rules": []}"#;
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/policy/reload")
.header("content-type", "application/json")
.header("authorization", "Bearer test-reload-secret")
.body(Body::from(body))
.unwrap(),
)
.await
.unwrap();

assert_eq!(response.status(), StatusCode::OK);

let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert!(response_json["success"].as_bool().unwrap());
}

#[tokio::test]
async fn test_policy_reload_disabled_returns_404() {
let app = create_router(test_state_with_policy_reload_disabled());

let body = r#"{"rules": []}"#;
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/policy/reload")
.header("content-type", "application/json")
.body(Body::from(body))
.unwrap(),
)
.await
.unwrap();

// When route is not mounted, axum returns 404
assert_eq!(response.status(), StatusCode::NOT_FOUND);
}

#[tokio::test]
async fn test_policy_reload_no_auth_when_no_secret() {
// Without policy_reload_secret, should work without auth
let app = create_router(test_state());

let body = r#"{"rules": []}"#;
let response = app
.oneshot(
Request::builder()
.method("POST")
.uri("/policy/reload")
.header("content-type", "application/json")
.body(Body::from(body))
.unwrap(),
)
.await
.unwrap();

assert_eq!(response.status(), StatusCode::OK);

let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert!(response_json["success"].as_bool().unwrap());
}
}
61 changes: 60 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,29 @@ struct Cli {
/// Maximum delegation depth (default: 5)
#[arg(long, env = "PREDICATE_MAX_DELEGATION_DEPTH", default_value = "5")]
max_delegation_depth: u32,

// --- Policy reload security options ---
/// Secret required for /policy/reload endpoint (bearer token)
#[arg(long, env = "PREDICATE_POLICY_RELOAD_SECRET")]
policy_reload_secret: Option<String>,

/// Disable the /policy/reload endpoint entirely
#[arg(long, env = "PREDICATE_DISABLE_POLICY_RELOAD")]
disable_policy_reload: bool,

// --- SSRF protection options ---
/// Allowed endpoints that bypass SSRF protection (comma-separated, host:port format)
/// Example: --ssrf-allow 172.30.192.1:11434,127.0.0.1:9200
#[arg(
long = "ssrf-allow",
env = "PREDICATE_SSRF_ALLOW",
value_delimiter = ','
)]
ssrf_allow: Vec<String>,

/// Disable SSRF protection entirely (not recommended)
#[arg(long, env = "PREDICATE_SSRF_DISABLED")]
ssrf_disabled: bool,
}

#[derive(Subcommand, Debug)]
Expand Down Expand Up @@ -461,6 +484,28 @@ async fn main() -> anyhow::Result<()> {
// Initialize policy engine
let policy_engine = PolicyEngine::new();

// Configure SSRF protection
let ssrf_disabled = cli.ssrf_disabled || file_config.ssrf.disabled;
let ssrf_allowed_endpoints: Vec<String> = if !cli.ssrf_allow.is_empty() {
cli.ssrf_allow.clone()
} else {
file_config.ssrf.allowed_endpoints.clone()
};

if ssrf_disabled {
policy_engine.set_ssrf_protection(None);
warn!("SSRF protection disabled - all endpoints allowed");
} else if !ssrf_allowed_endpoints.is_empty() {
use predicate_authorityd::ssrf::SsrfProtection;
let ssrf = SsrfProtection::new().with_allowed_endpoints(ssrf_allowed_endpoints.clone());
policy_engine.set_ssrf_protection(Some(ssrf));
info!(
"SSRF protection enabled with {} allowed endpoints: {:?}",
ssrf_allowed_endpoints.len(),
ssrf_allowed_endpoints
);
}

// Load policy file if specified (supports JSON and YAML formats)
if let Some(ref policy_path) = policy_file {
let format = policy_loader::detect_format(policy_path);
Expand Down Expand Up @@ -503,8 +548,22 @@ async fn main() -> anyhow::Result<()> {
info!("Audit mode enabled via --audit-mode flag");
}

// Merge policy reload config
let policy_reload_secret = cli
.policy_reload_secret
.or(file_config.policy.reload_secret);
let disable_policy_reload = cli.disable_policy_reload || file_config.policy.disable_reload;

// Create application state
let mut state = AppState::new(policy_engine, &mode);
let mut state = AppState::new(policy_engine, &mode)
.with_policy_reload_secret(policy_reload_secret.clone())
.with_policy_reload_disabled(disable_policy_reload);

if disable_policy_reload {
info!("Policy reload endpoint disabled");
} else if policy_reload_secret.is_some() {
info!("Policy reload endpoint protected with bearer token");
}

// Initialize IdP bridge based on identity_mode
let local_idp_signing_key = std::env::var(&local_idp_signing_key_env)
Expand Down
Loading
Loading