diff --git a/src/cyvest/cyvest.py b/src/cyvest/cyvest.py index 15342b0..ec53f98 100644 --- a/src/cyvest/cyvest.py +++ b/src/cyvest/cyvest.py @@ -39,7 +39,7 @@ ) from cyvest.io_visualization import generate_network_graph from cyvest.levels import Level -from cyvest.model import Check, Enrichment, Observable, Tag, Taxonomy, ThreatIntel +from cyvest.model import Check, Enrichment, Observable, Tag, Taxonomy, ThreatIntel, round_score_decimal from cyvest.model_enums import ObservableType, PropagationMode, RelationshipDirection, RelationshipType from cyvest.model_schema import InvestigationSchema, StatisticsSchema from cyvest.proxies import CheckProxy, EnrichmentProxy, ObservableProxy, TagProxy, ThreatIntelProxy @@ -1063,6 +1063,83 @@ def io_load_dict(data: dict[str, Any]) -> Cyvest: """ return load_investigation_dict(data) + def io_load_threat_intel_draft( + self, + report: dict[str, Any], + *, + preprocessor: Callable[[dict[str, Any]], dict[str, Any]] | None = None, + ) -> ThreatIntel: + """ + Load a ThreatIntel draft from an external API report dict. + + Extracts standard threat-intel fields (source, score, level, comment, + extra, taxonomies) from *report*, rounds the score to two decimal + places, and returns a validated :class:`ThreatIntel` instance that is + **not yet bound** to any observable. Attach it afterwards with + :meth:`observable_with_ti_draft` or + :meth:`ObservableProxy.with_ti_draft`. + + Args: + report: Dictionary with threat-intel fields coming from an + external service (e.g. a SOAR/TIP API response). + preprocessor: Optional callback that receives a **shallow copy** + of *report* and returns a (possibly modified) dict before + validation. Useful for source-specific normalisation such + as overriding the level for warning-list entries. + + Returns: + Unbound ThreatIntel instance (observable_key is empty). + + Raises: + TypeError: If *report* is not a dict. + pydantic.ValidationError: If the extracted payload fails + ThreatIntel model validation. + + Examples: + Basic usage:: + + report = {"source": "virustotal", "score": 4.256, "level": "SUSPICIOUS"} + ti = cv.io_load_threat_intel_draft(report) + obs.with_ti_draft(ti) + + With a preprocessor that forces MISP warning-list reports to SAFE:: + + def misp_warning_list_preprocessor(data: dict) -> dict: + extra = data.get("extra") + task_name = str(extra.get("task_name", "")) if isinstance(extra, dict) else "" + warning_list_tasks = {"MISP.analyzer.DBWarningList", "MISP.analyzer.SearchWarningList"} + if task_name in warning_list_tasks and data.get("level") not in ("INFO", "SAFE"): + data["level"] = "SAFE" + data["score"] = 0.0 + return data + + ti = cv.io_load_threat_intel_draft(report, preprocessor=misp_warning_list_preprocessor) + """ + if not isinstance(report, dict): + raise TypeError(f"report must be a dict, got {type(report).__name__}") + + data: dict[str, Any] = dict(report) # shallow copy + if preprocessor is not None: + data = preprocessor(data) + + raw_score = data.get("score") + if raw_score is not None: + rounded = round_score_decimal(Decimal(str(raw_score))) + else: + rounded = None # let ThreatIntel validators decide + + ti_payload: dict[str, Any] = { + "source": str(data.get("source", "")), + "observable_key": "", + "comment": str(data.get("comment", "") or ""), + "extra": data.get("extra"), + "score": rounded, + "level": data.get("level"), + "taxonomies": data.get("taxonomies", []), + } + + return ThreatIntel.model_validate(ti_payload) + # Shared context, investigation merging, finalization, comparison def shared_context( diff --git a/src/cyvest/model.py b/src/cyvest/model.py index d285895..a77a106 100644 --- a/src/cyvest/model.py +++ b/src/cyvest/model.py @@ -49,17 +49,22 @@ def model_dump_json(self, *, by_alias: bool = True, **kwargs: Any) -> str: return super().model_dump_json(by_alias=by_alias, **kwargs) -def _format_score_decimal(value: Decimal | None, *, places: int = _DEFAULT_SCORE_PLACES) -> str: - if value is None: - return "-" +def round_score_decimal(value: Decimal, *, places: int = _DEFAULT_SCORE_PLACES) -> Decimal: + """Round a Decimal score to *places* decimal places (ROUND_HALF_UP).""" if places < 0: raise ValueError("places must be >= 0") quantizer = Decimal("1").scaleb(-places) + quantized = value.quantize(quantizer, rounding=ROUND_HALF_UP) + if quantized == 0: + quantized = Decimal("0").quantize(quantizer) + return quantized + + +def _format_score_decimal(value: Decimal | None, *, places: int = _DEFAULT_SCORE_PLACES) -> str: + if value is None: + return "-" try: - quantized = value.quantize(quantizer, rounding=ROUND_HALF_UP) - if quantized == 0: - quantized = Decimal("0").quantize(quantizer) - return format(quantized, "f") + return format(round_score_decimal(value, places=places), "f") except InvalidOperation: return str(value) diff --git a/tests/test_cyvest.py b/tests/test_cyvest.py index a72801a..d81170a 100644 --- a/tests/test_cyvest.py +++ b/tests/test_cyvest.py @@ -809,3 +809,93 @@ def test_io_save_json_with_invalid_path() -> None: # Try to write to a directory that doesn't exist with pytest.raises(OSError): cv.io_save_json("/nonexistent/directory/investigation.json") + + +# ── threat_intel_from_report ────────────────────────────────────────── + + +def test_threat_intel_from_report_basic() -> None: + """Valid report dict returns an unbound ThreatIntel with rounded score.""" + cv = Cyvest() + report = { + "source": "virustotal", + "score": 4.256, + "level": "SUSPICIOUS", + "comment": "Detected by 12 engines", + "extra": {"engines": 12}, + "taxonomies": [{"level": "SUSPICIOUS", "name": "VT", "value": "12/70"}], + } + ti = cv.io_load_threat_intel_draft(report) + + assert ti.source == "virustotal" + assert ti.score == Decimal("4.26") # rounded + assert ti.level.value == "SUSPICIOUS" + assert ti.comment == "Detected by 12 engines" + assert ti.extra == {"engines": 12} + assert ti.observable_key == "" + assert len(ti.taxonomies) == 1 + assert ti.taxonomies[0].name == "VT" + + +def test_threat_intel_from_report_invalid_type() -> None: + """Non-dict input raises TypeError.""" + cv = Cyvest() + with pytest.raises(TypeError, match="report must be a dict"): + cv.io_load_threat_intel_draft("not a dict") # type: ignore[arg-type] + with pytest.raises(TypeError, match="report must be a dict"): + cv.io_load_threat_intel_draft(42) # type: ignore[arg-type] + + +def test_threat_intel_from_report_validation_error() -> None: + """Missing required score raises ValidationError.""" + from pydantic import ValidationError + + cv = Cyvest() + with pytest.raises(ValidationError): + cv.io_load_threat_intel_draft({"source": "test"}) + + +def test_threat_intel_from_report_preprocessor() -> None: + """Preprocessor callback can modify data before validation.""" + cv = Cyvest() + + def force_safe(data: dict) -> dict: + data["level"] = "SAFE" + data["score"] = 0.0 + return data + + ti = cv.io_load_threat_intel_draft( + {"source": "misp", "score": 7.5, "level": "MALICIOUS"}, + preprocessor=force_safe, + ) + assert ti.level.value == "SAFE" + assert ti.score == Decimal("0.00") + + +def test_threat_intel_from_report_does_not_mutate_input() -> None: + """Original report dict is not modified.""" + cv = Cyvest() + report = {"source": "test", "score": 3.0} + original = dict(report) + + def mutating_preprocessor(data: dict) -> dict: + data["level"] = "MALICIOUS" + return data + + cv.io_load_threat_intel_draft(report, preprocessor=mutating_preprocessor) + assert report == original + + +def test_threat_intel_from_report_attach_to_observable() -> None: + """Round-trip: from_report → observable_with_ti_draft → TI attached and scored.""" + cv = Cyvest() + obs = cv.observable_create("domain", "evil.example.com") + ti = cv.io_load_threat_intel_draft({"source": "abuse.ch", "score": 8.0}) + + bound = cv.observable_with_ti_draft(obs, ti) + assert bound.source == "abuse.ch" + assert bound.score == Decimal("8.00") + + refreshed = cv.observable_get(obs.key) + assert refreshed is not None + assert refreshed.score >= Decimal("8.00")