diff --git a/src/stackone_defender/core/prompt_defense.py b/src/stackone_defender/core/prompt_defense.py
index 9018cb0..d036ce2 100644
--- a/src/stackone_defender/core/prompt_defense.py
+++ b/src/stackone_defender/core/prompt_defense.py
@@ -55,7 +55,7 @@ def __init__(
         if block_high_risk:
             self._config.block_high_risk = True
 
-        tool_rules = self._config.tool_rules if use_default_tool_rules else []
+        tool_rules = (config or {}).get("tool_rules") or (self._config.tool_rules if use_default_tool_rules else [])
 
         self._tool_sanitizer: ToolResultSanitizer = create_tool_result_sanitizer(
             risky_fields=self._config.risky_fields,
@@ -120,7 +120,20 @@ def defend_tool_result(self, value: Any, tool_name: str) -> DefenseResult:
         tier2_idx = _RISK_LEVELS.index(tier2_risk)
         risk_level = _RISK_LEVELS[max(tier1_idx, tier2_idx)]
 
-        allowed = not self._config.block_high_risk or risk_level not in ("high", "critical")
+        # Determine whether any threat signals were found (Tier 1 or Tier 2).
+        # fields_sanitized captures sanitization methods (role stripping, encoding detection, etc.)
+        # that may fire without adding named pattern detections, so we include it here.
+        has_threats = (
+            len(detections) > 0
+            or len(fields_sanitized) > 0
+            or (tier2_score is not None and tier2_score >= self._config.tier2.high_risk_threshold)
+        )
+
+        # Three cases for allowed:
+        # 1. block_high_risk is off -> always allow
+        # 2. No threat signals found -> allow (base risk from tool rules alone does not block)
+        # 3. Risk did not reach high/critical -> allow
+        allowed = not self._config.block_high_risk or not has_threats or risk_level not in ("high", "critical")
 
         return DefenseResult(
             allowed=allowed,
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 29ab83f..50735e0 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -147,6 +147,44 @@ def test_returns_latency(self):
         assert result.latency_ms > 0
 
 
+class TestUseDefaultToolRules:
+    def test_does_not_apply_tool_rules_by_default(self):
+        defense = create_prompt_defense()
+        data = {"subject": "Weekly team update", "body": "Reminder about the meeting tomorrow at 10am.", "thread_id": "thread123"}
+        result = defense.defend_tool_result(data, "gmail_get_message")
+        # Without use_default_tool_rules, gmail tool rule should NOT seed risk_level to 'high'
+        assert result.risk_level not in ("high", "critical")
+
+    def test_does_not_apply_tool_rules_when_explicitly_false(self):
+        defense = create_prompt_defense(use_default_tool_rules=False)
+        data = {"subject": "Weekly team update", "body": "Reminder about the meeting tomorrow at 10am.", "thread_id": "thread123"}
+        result = defense.defend_tool_result(data, "gmail_get_message")
+        assert result.risk_level not in ("high", "critical")
+
+    def test_applies_tool_rules_when_true(self):
+        defense = create_prompt_defense(use_default_tool_rules=True, block_high_risk=True)
+        data = {"subject": "Weekly team update", "body": "Reminder about the meeting tomorrow at 10am.", "thread_id": "thread123"}
+        result = defense.defend_tool_result(data, "gmail_get_message")
+        # With use_default_tool_rules, gmail tool rule seeds risk_level: 'high' as base risk,
+        # but safe content with no detections should still be allowed through.
+        assert result.risk_level == "high"
+        assert result.allowed is True
+
+    def test_always_applies_custom_tool_rules_from_config(self):
+        from stackone_defender.types import ToolSanitizationRule
+        defense = create_prompt_defense(
+            use_default_tool_rules=False,
+            config={"tool_rules": [ToolSanitizationRule(tool_pattern="custom_*", sanitization_level="high")]},
+            block_high_risk=True,
+        )
+        data = {"name": "Safe content"}
+        result = defense.defend_tool_result(data, "custom_tool")
+        # Custom rules set base risk_level: 'high', but safe content with no detections
+        # should still be allowed through — base risk alone does not block.
+        assert result.risk_level == "high"
+        assert result.allowed is True
+
+
 class TestRealWorldScenarios:
     def setup_method(self):
         self.defense = create_prompt_defense()