diff --git a/contentctl/actions/detection_testing/DetectionTestingManager.py b/contentctl/actions/detection_testing/DetectionTestingManager.py index 8a8dd741..ae0df1e3 100644 --- a/contentctl/actions/detection_testing/DetectionTestingManager.py +++ b/contentctl/actions/detection_testing/DetectionTestingManager.py @@ -1,7 +1,16 @@ +import concurrent.futures +import datetime +import signal +import traceback +from dataclasses import dataclass from typing import List, Union -from contentctl.objects.config import test, test_servers, Container, Infrastructure + +import docker +from pydantic import BaseModel + from contentctl.actions.detection_testing.infrastructures.DetectionTestingInfrastructure import ( DetectionTestingInfrastructure, + DetectionTestingManagerOutputDto, ) from contentctl.actions.detection_testing.infrastructures.DetectionTestingInfrastructureContainer import ( DetectionTestingInfrastructureContainer, @@ -9,24 +18,12 @@ from contentctl.actions.detection_testing.infrastructures.DetectionTestingInfrastructureServer import ( DetectionTestingInfrastructureServer, ) -import signal -import datetime - -# from queue import Queue -from dataclasses import dataclass - -# import threading -from contentctl.actions.detection_testing.infrastructures.DetectionTestingInfrastructure import ( - DetectionTestingManagerOutputDto, -) from contentctl.actions.detection_testing.views.DetectionTestingView import ( DetectionTestingView, ) -from contentctl.objects.enums import PostTestBehavior -from pydantic import BaseModel +from contentctl.objects.config import Container, Infrastructure, test, test_servers from contentctl.objects.detection import Detection -import concurrent.futures -import docker +from contentctl.objects.enums import PostTestBehavior @dataclass(frozen=False) @@ -63,12 +60,14 @@ def sigint_handler(signum, frame): # a newline '\r\n' which will cause that wait to stop print("*******************************") print( - "If testing is paused and you are debugging a detection, you MUST hit CTRL-D at the prompt to complete shutdown." + "If testing is paused and you are debugging a detection, you MUST hit CTRL-D " + "at the prompt to complete shutdown." ) print("*******************************") signal.signal(signal.SIGINT, sigint_handler) + # TODO (#337): futures can be hard to maintain/debug; let's consider alternatives with ( concurrent.futures.ThreadPoolExecutor( max_workers=len(self.input_dto.config.test_instances), @@ -80,10 +79,19 @@ def sigint_handler(signum, frame): max_workers=len(self.input_dto.config.test_instances), ) as view_shutdowner, ): + # Capture any errors for reporting at the end after all threads have been gathered + errors: dict[str, list[Exception]] = { + "INSTANCE SETUP ERRORS": [], + "TESTING ERRORS": [], + "ERRORS DURING VIEW SHUTDOWN": [], + "ERRORS DURING VIEW EXECUTION": [], + } + # Start all the views future_views = { view_runner.submit(view.setup): view for view in self.input_dto.views } + # Configure all the instances future_instances_setup = { instance_pool.submit(instance.setup): instance @@ -96,7 +104,11 @@ def sigint_handler(signum, frame): future.result() except Exception as e: self.output_dto.terminate = True - print(f"Error setting up container: {str(e)}") + # Output the traceback if we encounter errors in verbose mode + if self.input_dto.config.verbose: + tb = traceback.format_exc() + print(tb) + errors["INSTANCE SETUP ERRORS"].append(e) # Start and wait for all tests to run if not self.output_dto.terminate: @@ -111,7 +123,11 @@ def sigint_handler(signum, frame): future.result() except Exception as e: self.output_dto.terminate = True - print(f"Error running in container: {str(e)}") + # Output the traceback if we encounter errors in verbose mode + if self.input_dto.config.verbose: + tb = traceback.format_exc() + print(tb) + errors["TESTING ERRORS"].append(e) self.output_dto.terminate = True @@ -123,14 +139,34 @@ def sigint_handler(signum, frame): try: future.result() except Exception as e: - print(f"Error stopping view: {str(e)}") + # Output the traceback if we encounter errors in verbose mode + if self.input_dto.config.verbose: + tb = traceback.format_exc() + print(tb) + errors["ERRORS DURING VIEW SHUTDOWN"].append(e) # Wait for original view-related threads to complete for future in concurrent.futures.as_completed(future_views): try: future.result() except Exception as e: - print(f"Error running container: {str(e)}") + # Output the traceback if we encounter errors in verbose mode + if self.input_dto.config.verbose: + tb = traceback.format_exc() + print(tb) + errors["ERRORS DURING VIEW EXECUTION"].append(e) + + # Log any errors + for error_type in errors: + if len(errors[error_type]) > 0: + print() + print(f"[{error_type}]:") + for error in errors[error_type]: + print(f"\t❌ {str(error)}") + if isinstance(error, ExceptionGroup): + for suberror in error.exceptions: # type: ignore + print(f"\t\t❌ {str(suberror)}") # type: ignore + print() return self.output_dto @@ -154,12 +190,15 @@ def create_DetectionTestingInfrastructureObjects(self): ) if len(parts) != 2: raise Exception( - f"Expected to find a name:tag in {self.input_dto.config.container_settings.full_image_path}, " - f"but instead found {parts}. Note that this path MUST include the tag, which is separated by ':'" + "Expected to find a name:tag in " + f"{self.input_dto.config.container_settings.full_image_path}, " + f"but instead found {parts}. Note that this path MUST include the " + "tag, which is separated by ':'" ) print( - f"Getting the latest version of the container image [{self.input_dto.config.container_settings.full_image_path}]...", + "Getting the latest version of the container image " + f"[{self.input_dto.config.container_settings.full_image_path}]...", end="", flush=True, ) @@ -168,7 +207,8 @@ def create_DetectionTestingInfrastructureObjects(self): break except Exception as e: raise Exception( - f"Failed to pull docker container image [{self.input_dto.config.container_settings.full_image_path}]: {str(e)}" + "Failed to pull docker container image " + f"[{self.input_dto.config.container_settings.full_image_path}]: {str(e)}" ) already_staged_container_files = False diff --git a/contentctl/actions/detection_testing/infrastructures/DetectionTestingInfrastructure.py b/contentctl/actions/detection_testing/infrastructures/DetectionTestingInfrastructure.py index b981589e..5f9fbdae 100644 --- a/contentctl/actions/detection_testing/infrastructures/DetectionTestingInfrastructure.py +++ b/contentctl/actions/detection_testing/infrastructures/DetectionTestingInfrastructure.py @@ -11,13 +11,20 @@ from ssl import SSLEOFError, SSLZeroReturnError from sys import stdout from tempfile import TemporaryDirectory, mktemp -from typing import Optional, Union +from typing import Callable, Optional, Union import requests # type: ignore import splunklib.client as client # type: ignore -import splunklib.results import tqdm # type: ignore -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, dataclasses +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PrivateAttr, + computed_field, + dataclasses, +) +from semantic_version import Version from splunklib.binding import HTTPError # type: ignore from splunklib.results import JSONResultsReader, Message # type: ignore from urllib3 import disable_warnings @@ -31,7 +38,8 @@ from contentctl.helper.utils import Utils from contentctl.objects.base_test import BaseTest from contentctl.objects.base_test_result import TestResultStatus -from contentctl.objects.config import Infrastructure, test_common +from contentctl.objects.config import All, Infrastructure, test_common +from contentctl.objects.content_versioning_service import ContentVersioningService from contentctl.objects.correlation_search import CorrelationSearch, PbarData from contentctl.objects.detection import Detection from contentctl.objects.enums import AnalyticsType, PostTestBehavior @@ -42,6 +50,9 @@ from contentctl.objects.unit_test import UnitTest from contentctl.objects.unit_test_result import UnitTestResult +# The app name of ES; needed to check ES version +ES_APP_NAME = "SplunkEnterpriseSecuritySuite" + class SetupTestGroupResults(BaseModel): exception: Union[Exception, None] = None @@ -136,21 +147,27 @@ def setup(self): ) self.start_time = time.time() + + # Init the list of setup functions we always need + primary_setup_functions: list[ + tuple[Callable[[], None | client.Service], str] + ] = [ + (self.start, "Starting"), + (self.get_conn, "Waiting for App Installation"), + (self.configure_conf_file_datamodels, "Configuring Datamodels"), + (self.create_replay_index, f"Create index '{self.sync_obj.replay_index}'"), + (self.get_all_indexes, "Getting all indexes from server"), + (self.check_for_es_install, "Checking for ES Install"), + (self.configure_imported_roles, "Configuring Roles"), + (self.configure_delete_indexes, "Configuring Indexes"), + (self.configure_hec, "Configuring HEC"), + (self.wait_for_ui_ready, "Finishing Primary Setup"), + ] + + # Execute and report on each setup function try: - for func, msg in [ - (self.start, "Starting"), - (self.get_conn, "Waiting for App Installation"), - (self.configure_conf_file_datamodels, "Configuring Datamodels"), - ( - self.create_replay_index, - f"Create index '{self.sync_obj.replay_index}'", - ), - (self.get_all_indexes, "Getting all indexes from server"), - (self.configure_imported_roles, "Configuring Roles"), - (self.configure_delete_indexes, "Configuring Indexes"), - (self.configure_hec, "Configuring HEC"), - (self.wait_for_ui_ready, "Finishing Setup"), - ]: + # Run the primary setup functions + for func, msg in primary_setup_functions: self.format_pbar_string( TestReportingType.SETUP, self.get_name(), @@ -160,18 +177,114 @@ def setup(self): func() self.check_for_teardown() + # Run any setup functions only applicable to content versioning validation + if self.should_test_content_versioning: + self.pbar.write( + self.format_pbar_string( + TestReportingType.SETUP, + self.get_name(), + "Beginning Content Versioning Validation...", + set_pbar=False, + ) + ) + for func, msg in self.content_versioning_service.setup_functions: + self.format_pbar_string( + TestReportingType.SETUP, + self.get_name(), + msg, + update_sync_status=True, + ) + func() + self.check_for_teardown() + except Exception as e: - self.pbar.write(str(e)) + msg = f"[{self.get_name()}]: {str(e)}" self.finish() - return + if isinstance(e, ExceptionGroup): + raise ExceptionGroup(msg, e.exceptions) from e # type: ignore + raise Exception(msg) from e - self.format_pbar_string( - TestReportingType.SETUP, self.get_name(), "Finished Setup!" + self.pbar.write( + self.format_pbar_string( + TestReportingType.SETUP, + self.get_name(), + "Finished Setup!", + set_pbar=False, + ) ) def wait_for_ui_ready(self): self.get_conn() + @computed_field + @property + def content_versioning_service(self) -> ContentVersioningService: + """ + A computed field returning a handle to the content versioning service, used by ES to + version detections. We use this model to validate that all detections have been installed + compatibly with ES versioning. + + :return: a handle to the content versioning service on the instance + :rtype: :class:`contentctl.objects.content_versioning_service.ContentVersioningService` + """ + return ContentVersioningService( + global_config=self.global_config, + infrastructure=self.infrastructure, + service=self.get_conn(), + detections=self.sync_obj.inputQueue, + ) + + @property + def should_test_content_versioning(self) -> bool: + """ + Indicates whether we should test content versioning. Content versioning + should be tested when integration testing is enabled, the mode is all, and ES is at least + version 8.0.0. + + :return: a bool indicating whether we should test content versioning + :rtype: bool + """ + es_version = self.es_version + return ( + self.global_config.enable_integration_testing + and isinstance(self.global_config.mode, All) + and es_version is not None + and es_version >= Version("8.0.0") + ) + + @property + def es_version(self) -> Version | None: + """ + Returns the version of Enterprise Security installed on the instance; None if not installed. + + :return: the version of ES, as a semver aware object + :rtype: :class:`semantic_version.Version` + """ + if not self.es_installed: + return None + return Version(self.get_conn().apps[ES_APP_NAME]["version"]) # type: ignore + + @property + def es_installed(self) -> bool: + """ + Indicates whether ES is installed on the instance. + + :return: a bool indicating whether ES is installed or not + :rtype: bool + """ + return ES_APP_NAME in self.get_conn().apps + + def check_for_es_install(self) -> None: + """ + Validating function which raises an error if Enterprise Security is not installed and + integration testing is enabled. + """ + if not self.es_installed and self.global_config.enable_integration_testing: + raise Exception( + "Enterprise Security does not appear to be installed on this instance and " + "integration testing is enabled." + ) + def configure_hec(self): self.hec_channel = str(uuid.uuid4()) try: @@ -298,14 +411,14 @@ def configure_imported_roles( imported_roles: list[str] = ["user", "power", "can_delete"], enterprise_security_roles: list[str] = ["ess_admin", "ess_analyst", "ess_user"], ): - try: - # Set which roles should be configured. For Enterprise Security/Integration Testing, - # we must add some extra foles. - if self.global_config.enable_integration_testing: - roles = imported_roles + enterprise_security_roles - else: - roles = imported_roles + # Set which roles should be configured. For Enterprise Security/Integration Testing, + # we must add some extra foles. + if self.global_config.enable_integration_testing: + roles = imported_roles + enterprise_security_roles + else: + roles = imported_roles + try: self.get_conn().roles.post( self.infrastructure.splunk_app_username, imported_roles=roles, @@ -314,16 +427,9 @@ def configure_imported_roles( ) return except Exception as e: - self.pbar.write( - f"The following role(s) do not exist:'{enterprise_security_roles}: {str(e)}" - ) - - self.get_conn().roles.post( - self.infrastructure.splunk_app_username, - imported_roles=imported_roles, - srchIndexesAllowed=";".join(self.all_indexes_on_server), - srchIndexesDefault=self.sync_obj.replay_index, - ) + msg = f"Error configuring roles: {str(e)}" + self.pbar.write(msg) + raise Exception(msg) from e def configure_delete_indexes(self): endpoint = "/services/properties/authorize/default/deleteIndexesAllowed" @@ -1170,8 +1276,6 @@ def retry_search_until_timeout( # on a field. In this case, the field will appear but will not contain any values current_empty_fields: set[str] = set() - # TODO (cmcginley): @ljstella is this something we're keeping for testing as - # well? for field in full_rba_field_set: if result.get(field, "null") == "null": if field in risk_object_fields_set: @@ -1257,7 +1361,7 @@ def delete_attack_data(self, attack_data_files: list[TestAttackData]): job = self.get_conn().jobs.create(splunk_search, **kwargs) results_stream = job.results(output_mode="json") # TODO: should we be doing something w/ this reader? - _ = splunklib.results.JSONResultsReader(results_stream) + _ = JSONResultsReader(results_stream) except Exception as e: raise ( diff --git a/contentctl/actions/detection_testing/views/DetectionTestingViewCLI.py b/contentctl/actions/detection_testing/views/DetectionTestingViewCLI.py index 2050fc10..2246783f 100644 --- a/contentctl/actions/detection_testing/views/DetectionTestingViewCLI.py +++ b/contentctl/actions/detection_testing/views/DetectionTestingViewCLI.py @@ -47,6 +47,8 @@ def showStatus(self, interval: int = 1): while True: summary = self.getSummaryObject() + # TODO (#338): there's a 1-off error here I think (we show one more than we + # actually have during testing) total = len( summary.get("tested_detections", []) + summary.get("untested_detections", []) diff --git a/contentctl/actions/test.py b/contentctl/actions/test.py index 96b070fe..90ac3951 100644 --- a/contentctl/actions/test.py +++ b/contentctl/actions/test.py @@ -1,44 +1,36 @@ +import pathlib from dataclasses import dataclass from typing import List -from contentctl.objects.config import test_common, Selected, Changes -from contentctl.objects.detection import Detection - - from contentctl.actions.detection_testing.DetectionTestingManager import ( DetectionTestingManager, DetectionTestingManagerInputDto, ) - - from contentctl.actions.detection_testing.infrastructures.DetectionTestingInfrastructure import ( DetectionTestingManagerOutputDto, ) - - -from contentctl.actions.detection_testing.views.DetectionTestingViewWeb import ( - DetectionTestingViewWeb, -) - from contentctl.actions.detection_testing.views.DetectionTestingViewCLI import ( DetectionTestingViewCLI, ) - from contentctl.actions.detection_testing.views.DetectionTestingViewFile import ( DetectionTestingViewFile, ) - +from contentctl.actions.detection_testing.views.DetectionTestingViewWeb import ( + DetectionTestingViewWeb, +) +from contentctl.objects.config import Changes, Selected +from contentctl.objects.config import test as test_ +from contentctl.objects.config import test_servers +from contentctl.objects.detection import Detection from contentctl.objects.integration_test import IntegrationTest -import pathlib - MAXIMUM_CONFIGURATION_TIME_SECONDS = 600 @dataclass(frozen=True) class TestInputDto: detections: List[Detection] - config: test_common + config: test_ | test_servers class Test: @@ -77,8 +69,8 @@ def execute(self, input_dto: TestInputDto) -> bool: if len(input_dto.detections) == 0: print( - f"With Detection Testing Mode '{input_dto.config.mode.mode_name}', there were [0] detections found to test." - "\nAs such, we will quit immediately." + f"With Detection Testing Mode '{input_dto.config.mode.mode_name}', there were " + "[0] detections found to test.\nAs such, we will quit immediately." ) # Directly call stop so that the summary.yml will be generated. Of course it will not # have any test results, but we still want it to contain a summary showing that now @@ -109,6 +101,10 @@ def execute(self, input_dto: TestInputDto) -> bool: try: summary_results = file.getSummaryObject() summary = summary_results.get("summary", {}) + if not isinstance(summary, dict): + raise ValueError( + f"Summary in results was an unexpected type ({type(summary)}): {summary}" + ) print(f"Test Summary (mode: {summary.get('mode', 'Error')})") print(f"\tSuccess : {summary.get('success', False)}") @@ -152,7 +148,7 @@ def execute(self, input_dto: TestInputDto) -> bool: "detection types (e.g. Correlation), but there may be overlap between these\n" "categories." ) - return summary_results.get("summary", {}).get("success", False) + return summary.get("success", False) except Exception as e: print(f"Error determining if whole test was successful: {str(e)}") diff --git a/contentctl/helper/utils.py b/contentctl/helper/utils.py index ba458b8b..027c7cae 100644 --- a/contentctl/helper/utils.py +++ b/contentctl/helper/utils.py @@ -6,6 +6,7 @@ import string from timeit import default_timer import pathlib +import logging from typing import Union, Tuple import tqdm @@ -490,3 +491,58 @@ def getPercent(numerator: float, denominator: float, decimal_places: int) -> str ratio = numerator / denominator percent = ratio * 100 return Utils.getFixedWidth(percent, decimal_places) + "%" + + @staticmethod + def get_logger( + name: str, log_level: int, log_path: str, enable_logging: bool + ) -> logging.Logger: + """ + Gets a logger instance for the given name; logger is configured if not already configured. + The NullHandler is used to suppress loggging when running in production so as not to + conflict w/ contentctl's larger pbar-based logging. The StreamHandler is enabled by setting + enable_logging to True (useful for debugging/testing locally) + + :param name: the logger name + :type name: str + :param log_level: the logging level (e.g. `logging.Debug`) + :type log_level: int + :param log_path: the path for the log file + :type log_path: str + :param enable_logging: a flag indicating whether logging should be redirected from null to + the stream handler + :type enable_logging: bool + + :return: a logger + :rtype: :class:`logging.Logger` + """ + # get logger for module + logger = logging.getLogger(name) + + # set propagate to False if not already set as such (needed to that we do not flow up to any + # root loggers) + if logger.propagate: + logger.propagate = False + + # if logger has no handlers, it needs to be configured for the first time + if not logger.hasHandlers(): + # set level + logger.setLevel(log_level) + + # if logging enabled, use a StreamHandler; else, use the NullHandler to suppress logging + handler: logging.Handler + if enable_logging: + handler = logging.FileHandler(log_path) + else: + handler = logging.NullHandler() + + # Format our output + formatter = logging.Formatter( + "%(asctime)s - %(levelname)s:%(name)s - %(message)s" + ) + handler.setFormatter(formatter) + + # Set handler level and add to logger + handler.setLevel(log_level) + logger.addHandler(handler) + + return logger diff --git a/contentctl/objects/config.py b/contentctl/objects/config.py index ac6cef78..e4149fea 100644 --- a/contentctl/objects/config.py +++ b/contentctl/objects/config.py @@ -97,7 +97,8 @@ def getApp(self, config: test, stage_file: bool = False) -> str: return str(self.getSplunkbasePath()) if self.version is None or self.uid is None: print( - f"Not downloading {self.title} from Splunkbase since uid[{self.uid}] AND version[{self.version}] MUST be defined" + f"Not downloading {self.title} from Splunkbase since uid[{self.uid}] AND " + f"version[{self.version}] MUST be defined" ) elif isinstance(self.hardcoded_path, pathlib.Path): @@ -149,7 +150,10 @@ class CustomApp(App_Base): exclude=True, default=int(datetime.now(UTC).strftime("%Y%m%d%H%M%S")), validate_default=True, - description="Build number for your app. This will always be a number that corresponds to the time of the build in the format YYYYMMDDHHMMSS", + description=( + "Build number for your app. This will always be a number that corresponds to the " + "time of the build in the format YYYYMMDDHHMMSS" + ), ) # id has many restrictions: # * Omit this setting for apps that are for internal use only and not intended @@ -194,7 +198,8 @@ def validate_version(cls, v, values): except Exception as e: raise ( ValueError( - f"The specified version does not follow the semantic versioning spec (https://semver.org/). {str(e)}" + "The specified version does not follow the semantic versioning spec " + f"(https://semver.org/). {str(e)}" ) ) return v @@ -416,8 +421,6 @@ class inspect(build): f"or CLI invocation appropriately] {validate.model_fields['enrichments'].description}" ), ) - # TODO (cmcginley): wording should change here if we want to be able to download any app from - # Splunkbase previous_build: str | None = Field( default=None, description=( @@ -548,7 +551,10 @@ class ContainerSettings(BaseModel): ) full_image_path: str = Field( default="registry.hub.docker.com/splunk/splunk:9.3", - title="Full path to the container image to be used. We are currently pinned to 9.3 as we resolve an issue with waiting to run until app installation completes.", + title=( + "Full path to the container image to be used. We are currently pinned to 9.3 as we " + "resolve an issue with waiting to run until app installation completes." + ), ) def getContainers(self) -> List[Container]: @@ -577,7 +583,10 @@ class Changes(BaseModel): mode_name: str = "Changes" target_branch: str = Field( ..., - description="The target branch to diff against. Note that this includes uncommitted changes in the working directory as well.", + description=( + "The target branch to diff against. Note that this includes uncommitted changes in the " + "working directory as well." + ), ) @@ -821,7 +830,8 @@ class test_common(build): f"'{PostTestBehavior.always_pause}' - the state of " "the test will always pause after a test, allowing the user to log into the " "server and experiment with the search and data before it is removed.\n\n" - f"'{PostTestBehavior.pause_on_failure}' - pause execution ONLY when a test fails. The user may press ENTER in the terminal " + f"'{PostTestBehavior.pause_on_failure}' - pause execution ONLY when a test fails. " + "The user may press ENTER in the terminal " "running the test to move on to the next test.\n\n" f"'{PostTestBehavior.never_pause}' - never stop testing, even if a test fails.\n\n" "***SPECIAL NOTE FOR CI/CD*** 'never_pause' MUST be used for a test to " diff --git a/contentctl/objects/content_versioning_service.py b/contentctl/objects/content_versioning_service.py new file mode 100644 index 00000000..68a529ba --- /dev/null +++ b/contentctl/objects/content_versioning_service.py @@ -0,0 +1,508 @@ +import json +import logging +import re +import time +import uuid +from functools import cached_property +from typing import Any, Callable + +import splunklib.client as splunklib # type: ignore +from pydantic import BaseModel, Field, PrivateAttr, computed_field +from splunklib.binding import HTTPError, ResponseReader # type: ignore +from splunklib.data import Record # type: ignore + +from contentctl.helper.utils import Utils +from contentctl.objects.config import Infrastructure, test_common +from contentctl.objects.correlation_search import ResultIterator +from contentctl.objects.detection import Detection + +# Suppress logging by default; enable for local testing +ENABLE_LOGGING = False +LOG_LEVEL = logging.DEBUG +LOG_PATH = "content_versioning_service.log" + + +class ContentVersioningService(BaseModel): + """ + A model representing the content versioning service used in ES 8.0.0+. This model can be used + to validate that detections have been installed in a way that is compatible with content + versioning. + """ + + # The global contentctl config + global_config: test_common + + # The instance specific infra config + infrastructure: Infrastructure + + # The splunklib service + service: splunklib.Service + + # The list of detections + detections: list[Detection] + + # The logger to use (logs all go to a null pipe unless ENABLE_LOGGING is set to True, so as not + # to conflict w/ tqdm) + logger: logging.Logger = Field( + default_factory=lambda: Utils.get_logger( + __name__, LOG_LEVEL, LOG_PATH, ENABLE_LOGGING + ) + ) + + def model_post_init(self, __context: Any) -> None: + super().model_post_init(__context) + + # Log instance details + self.logger.info( + f"[{self.infrastructure.instance_name} ({self.infrastructure.instance_address})] " + "Initing ContentVersioningService" + ) + + # The cached job on the splunk instance of the cms events + _cms_main_job: splunklib.Job | None = PrivateAttr(default=None) + + class Config: + # We need to allow arbitrary type for the splunklib service + arbitrary_types_allowed = True + + @computed_field + @property + def setup_functions(self) -> list[tuple[Callable[[], None], str]]: + """ + Returns the list of setup functions needed for content versioning testing + """ + return [ + (self.activate_versioning, "Activating Content Versioning"), + (self.wait_for_cms_main, "Waiting for CMS Parser"), + (self.validate_content_against_cms, "Validating Against CMS"), + ] + + def _query_content_versioning_service( + self, method: str, body: dict[str, Any] = {} + ) -> Record: + """ + Queries the SA-ContentVersioning service. Output mode defaults to JSON. + + :param method: HTTP request method (e.g. GET) + :type method: str + :param body: the payload/data/body of the request + :type body: dict[str, Any] + + :returns: a splunklib Record object (wrapper around dict) indicating the response + :rtype: :class:`splunklib.data.Record` + """ + # Add output mode to body + if "output_mode" not in body: + body["output_mode"] = "json" + + # Query the content versioning service + try: + response = self.service.request( # type: ignore + method=method, + path_segment="configs/conf-feature_flags/general", + body=body, + app="SA-ContentVersioning", + ) + except HTTPError as e: + # Raise on any HTTP errors + raise HTTPError(f"Error querying content versioning service: {e}") from e + + return response + + @property + def is_versioning_activated(self) -> bool: + """ + Indicates whether the versioning service is activated or not + + :returns: a bool indicating if content versioning is activated or not + :rtype: bool + """ + # Query the SA-ContentVersioning service for versioning status + response = self._query_content_versioning_service(method="GET") + + # Grab the response body and check for errors + if "body" not in response: + raise KeyError( + f"Cannot retrieve versioning status, 'body' was not found in JSON response: {response}" + ) + body: Any = response["body"] # type: ignore + if not isinstance(body, ResponseReader): + raise ValueError( + "Cannot retrieve versioning status, value at 'body' in JSON response had an unexpected" + f" type: expected '{ResponseReader}', received '{type(body)}'" + ) + + # Read the JSON and parse it into a dictionary + json_ = body.readall() + try: + data = json.loads(json_) + except json.JSONDecodeError as e: + raise ValueError(f"Unable to parse response body as JSON: {e}") from e + + # Find the versioning_activated field and report any errors + try: + for entry in data["entry"]: + if entry["name"] == "general": + return bool(int(entry["content"]["versioning_activated"])) + except KeyError as e: + raise KeyError( + "Cannot retrieve versioning status, unable to determine versioning status using " + f"the expected keys: {e}" + ) from e + raise ValueError( + "Cannot retrieve versioning status, unable to find an entry matching 'general' in the " + "response." + ) + + def activate_versioning(self) -> None: + """ + Activate the content versioning service + """ + # Post to the SA-ContentVersioning service to set versioning status + self._query_content_versioning_service( + method="POST", body={"versioning_activated": True} + ) + + # Confirm versioning has been enabled + if not self.is_versioning_activated: + raise Exception( + "Something went wrong, content versioning is still disabled." + ) + + self.logger.info( + f"[{self.infrastructure.instance_name}] Versioning service successfully activated" + ) + + @computed_field + @cached_property + def cms_fields(self) -> list[str]: + """ + Property listing the fields we want to pull from the cms_main index + + :returns: a list of strings, the fields we want + :rtype: list[str] + """ + return [ + "app_name", + "detection_id", + "version", + "action.correlationsearch.label", + "sourcetype", + ] + + @property + def is_cms_parser_enabled(self) -> bool: + """ + Indicates whether the cms_parser mod input is enabled or not. + + :returns: a bool indicating if cms_parser mod input is activated or not + :rtype: bool + """ + # Get the data input entity + cms_parser = self.service.input("data/inputs/cms_parser/main") # type: ignore + + # Convert the 'disabled' field to an int, then a bool, and then invert to be 'enabled' + return not bool(int(cms_parser.content["disabled"])) # type: ignore + + def force_cms_parser(self) -> None: + """ + Force the cms_parser to run by disabling and re-enabling it. + """ + # Get the data input entity + cms_parser = self.service.input("data/inputs/cms_parser/main") # type: ignore + + # Disable and re-enable + cms_parser.disable() + cms_parser.enable() + + # Confirm the cms_parser is enabled + if not self.is_cms_parser_enabled: + raise Exception("Something went wrong, cms_parser is still disabled.") + + self.logger.info( + f"[{self.infrastructure.instance_name}] cms_parser successfully toggled to force run" + ) + + def wait_for_cms_main(self) -> None: + """ + Checks the cms_main index until it has the expected number of events, or it times out. + """ + # Force the cms_parser to start parsing our savedsearches.conf + self.force_cms_parser() + + # Set counters and limits for out exp. backoff timer + elapsed_sleep_time = 0 + num_tries = 0 + time_to_sleep = 2**num_tries + max_sleep = 600 + + # Loop until timeout + while elapsed_sleep_time < max_sleep: + # Sleep, and add the time to the elapsed counter + self.logger.info( + f"[{self.infrastructure.instance_name}] Waiting {time_to_sleep} for cms_parser to " + "finish" + ) + time.sleep(time_to_sleep) + elapsed_sleep_time += time_to_sleep + self.logger.info( + f"[{self.infrastructure.instance_name}] Checking cms_main (attempt #{num_tries + 1}" + f" - {elapsed_sleep_time} seconds elapsed of {max_sleep} max)" + ) + + # Check if the number of CMS events matches or exceeds the number of detections + if self.get_num_cms_events() >= len(self.detections): + self.logger.info( + f"[{self.infrastructure.instance_name}] Found " + f"{self.get_num_cms_events(use_cache=True)} events in cms_main which " + f"meets or exceeds the expected {len(self.detections)}." + ) + break + else: + self.logger.info( + f"[{self.infrastructure.instance_name}] Found " + f"{self.get_num_cms_events(use_cache=True)} matching events in cms_main; " + f"expecting {len(self.detections)}. Continuing to wait..." + ) + # Update the number of times we've tried, and increment the time to sleep + num_tries += 1 + time_to_sleep = 2**num_tries + + # If the computed time to sleep will exceed max_sleep, adjust appropriately + if (elapsed_sleep_time + time_to_sleep) > max_sleep: + time_to_sleep = max_sleep - elapsed_sleep_time + + def _query_cms_main(self, use_cache: bool = False) -> splunklib.Job: + """ + Queries the cms_main index, optionally appending the provided query suffix. + + :param use_cache: a flag indicating whether the cached job should be returned + :type use_cache: bool + + :returns: a search Job entity + :rtype: :class:`splunklib.client.Job` + """ + # Use the cached job if asked to do so + if use_cache: + if self._cms_main_job is not None: + return self._cms_main_job + raise Exception( + "Attempting to return a cached job against the cms_main index, but no job has been" + " cached yet." + ) + + # Construct the query looking for CMS events matching the content app name + query = ( + f"search index=cms_main sourcetype=stash_common_detection_model " + f'app_name="{self.global_config.app.appid}" | fields {", ".join(self.cms_fields)}' + ) + self.logger.debug( + f"[{self.infrastructure.instance_name}] Query on cms_main: {query}" + ) + + # Get the job as a blocking operation, set the cache, and return + self._cms_main_job = self.service.search(query, exec_mode="blocking") # type: ignore + return self._cms_main_job + + def get_num_cms_events(self, use_cache: bool = False) -> int: + """ + Gets the number of matching events in the cms_main index + + :param use_cache: a flag indicating whether the cached job should be returned + :type use_cache: bool + + :returns: the count of matching events + :rtype: int + """ + # Query the cms_main index + job = self._query_cms_main(use_cache=use_cache) + + # Convert the result count to an int + return int(job["resultCount"]) + + def validate_content_against_cms(self) -> None: + """ + Using the cms_main index, validate content against the index to ensure our + savedsearches.conf is compatible with ES content versioning features. **NOTE**: while in + the future, this function may validate more types of content, currently, we only validate + detections against the cms_main index. + """ + # Get the cached job and result count + result_count = self.get_num_cms_events(use_cache=True) + job = self._query_cms_main(use_cache=True) + + # Create a running list of validation errors + exceptions: list[Exception] = [] + + # Generate an error for the count mismatch + if result_count != len(self.detections): + msg = ( + f"[{self.infrastructure.instance_name}] Expected {len(self.detections)} matching " + f"events in cms_main, but found {result_count}." + ) + self.logger.error(msg) + exceptions.append(Exception(msg)) + self.logger.info( + f"[{self.infrastructure.instance_name}] Expecting {len(self.detections)} matching " + f"events in cms_main, found {result_count}." + ) + + # Init some counters and a mapping of detections to their names + count = 100 + offset = 0 + remaining_detections = { + x.get_action_dot_correlationsearch_dot_label(self.global_config.app): x + for x in self.detections + } + matched_detections: dict[str, Detection] = {} + + # Create a filter for a specific memory error we're ok ignoring + sub_second_order_pattern = re.compile( + r".*Events might not be returned in sub-second order due to search memory limits.*" + ) + + # Iterate over the results until we've gone through them all + while offset < result_count: + iterator = ResultIterator( + response_reader=job.results( # type: ignore + output_mode="json", count=count, offset=offset + ), + error_filters=[sub_second_order_pattern], + ) + + # Iterate over the currently fetched results + for cms_event in iterator: + # Increment the offset for each result + offset += 1 + + # Get the name of the search in the CMS event + cms_entry_name = cms_event["action.correlationsearch.label"] + self.logger.info( + f"[{self.infrastructure.instance_name}] {offset}: Matching cms_main entry " + f"'{cms_entry_name}' against detections" + ) + + # If CMS entry name matches one of the detections already matched, we've got an + # unexpected repeated entry + if cms_entry_name in matched_detections: + msg = ( + f"[{self.infrastructure.instance_name}] [{cms_entry_name}]: Detection " + f"appears more than once in the cms_main index." + ) + self.logger.error(msg) + exceptions.append(Exception(msg)) + continue + + # Iterate over the detections and compare the CMS entry name against each + result_matches_detection = False + for detection_cs_label in remaining_detections: + # If we find a match, break this loop, set the found flag and move the detection + # from those that still need to matched to those already matched + if cms_entry_name == detection_cs_label: + self.logger.info( + f"[{self.infrastructure.instance_name}] {offset}: Succesfully matched " + f"cms_main entry against detection ('{detection_cs_label}')!" + ) + + # Validate other fields of the cms_event against the detection + exception = self.validate_detection_against_cms_event( + cms_event, remaining_detections[detection_cs_label] + ) + + # Save the exception if validation failed + if exception is not None: + exceptions.append(exception) + + # Delete the matched detection and move it to the matched list + result_matches_detection = True + matched_detections[detection_cs_label] = remaining_detections[ + detection_cs_label + ] + del remaining_detections[detection_cs_label] + break + + # Generate an exception if we couldn't match the CMS main entry to a detection + if result_matches_detection is False: + msg = ( + f"[{self.infrastructure.instance_name}] [{cms_entry_name}]: Could not " + "match entry in cms_main against any of the expected detections." + ) + self.logger.error(msg) + exceptions.append(Exception(msg)) + + # If we have any remaining detections, they could not be matched against an entry in + # cms_main and there may have been a parsing issue with savedsearches.conf + if len(remaining_detections) > 0: + # Generate exceptions for the unmatched detections + for detection_cs_label in remaining_detections: + msg = ( + f"[{self.infrastructure.instance_name}] [{detection_cs_label}]: Detection not " + "found in cms_main; there may be an issue with savedsearches.conf" + ) + self.logger.error(msg) + exceptions.append(Exception(msg)) + + # Raise exceptions as a group + if len(exceptions) > 0: + raise ExceptionGroup( + "1 or more issues validating our detections against the cms_main index", + exceptions, + ) + + # Else, we've matched/validated all detections against cms_main + self.logger.info( + f"[{self.infrastructure.instance_name}] Matched and validated all detections against " + "cms_main!" + ) + + def validate_detection_against_cms_event( + self, cms_event: dict[str, Any], detection: Detection + ) -> Exception | None: + """ + Given an event from the cms_main index and the matched detection, compare fields and look + for any inconsistencies + + :param cms_event: The event from the cms_main index + :type cms_event: dict[str, Any] + :param detection: The matched detection + :type detection: :class:`contentctl.objects.detection.Detection` + + :return: The generated exception, or None + :rtype: Exception | None + """ + # TODO (PEX-509): validate additional fields between the cms_event and the detection + + cms_uuid = uuid.UUID(cms_event["detection_id"]) + rule_name_from_detection = detection.get_action_dot_correlationsearch_dot_label( + self.global_config.app + ) + + # Compare the correlation search label + if cms_event["action.correlationsearch.label"] != rule_name_from_detection: + msg = ( + f"[{self.infrastructure.instance_name}][{detection.name}]: Correlation search " + f"label in cms_event ('{cms_event['action.correlationsearch.label']}') does not " + "match detection name" + ) + self.logger.error(msg) + return Exception(msg) + elif cms_uuid != detection.id: + # Compare the UUIDs + msg = ( + f"[{self.infrastructure.instance_name}] [{detection.name}]: UUID in cms_event " + f"('{cms_uuid}') does not match UUID in detection ('{detection.id}')" + ) + self.logger.error(msg) + return Exception(msg) + elif cms_event["version"] != f"{detection.version}.1": + # Compare the versions (we append '.1' to the detection version to be in line w/ the + # internal representation in ES) + msg = ( + f"[{self.infrastructure.instance_name}] [{detection.name}]: Version in cms_event " + f"('{cms_event['version']}') does not match version in detection " + f"('{detection.version}.1')" + ) + self.logger.error(msg) + return Exception(msg) + + return None diff --git a/contentctl/objects/correlation_search.py b/contentctl/objects/correlation_search.py index 0cebf5cc..4306cc8e 100644 --- a/contentctl/objects/correlation_search.py +++ b/contentctl/objects/correlation_search.py @@ -1,35 +1,36 @@ +import json import logging +import re import time -import json -from typing import Any -from enum import StrEnum, IntEnum +from enum import IntEnum, StrEnum from functools import cached_property +from typing import Any -from pydantic import ConfigDict, BaseModel, computed_field, Field, PrivateAttr -from splunklib.results import JSONResultsReader, Message # type: ignore -from splunklib.binding import HTTPError, ResponseReader # type: ignore import splunklib.client as splunklib # type: ignore +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, computed_field +from splunklib.binding import HTTPError, ResponseReader # type: ignore +from splunklib.results import JSONResultsReader, Message # type: ignore from tqdm import tqdm # type: ignore -from contentctl.objects.risk_analysis_action import RiskAnalysisAction -from contentctl.objects.notable_action import NotableAction -from contentctl.objects.base_test_result import TestResultStatus -from contentctl.objects.integration_test_result import IntegrationTestResult from contentctl.actions.detection_testing.progress_bar import ( - format_pbar_string, # type: ignore - TestReportingType, TestingStates, + TestReportingType, + format_pbar_string, # type: ignore ) +from contentctl.helper.utils import Utils +from contentctl.objects.base_test_result import TestResultStatus +from contentctl.objects.detection import Detection from contentctl.objects.errors import ( + ClientError, IntegrationTestingError, ServerError, - ClientError, ValidationFailed, ) -from contentctl.objects.detection import Detection -from contentctl.objects.risk_event import RiskEvent +from contentctl.objects.integration_test_result import IntegrationTestResult +from contentctl.objects.notable_action import NotableAction from contentctl.objects.notable_event import NotableEvent - +from contentctl.objects.risk_analysis_action import RiskAnalysisAction +from contentctl.objects.risk_event import RiskEvent # Suppress logging by default; enable for local testing ENABLE_LOGGING = False @@ -37,46 +38,6 @@ LOG_PATH = "correlation_search.log" -def get_logger() -> logging.Logger: - """ - Gets a logger instance for the module; logger is configured if not already configured. The - NullHandler is used to suppress loggging when running in production so as not to conflict w/ - contentctl's larger pbar-based logging. The StreamHandler is enabled by setting ENABLE_LOGGING - to True (useful for debugging/testing locally) - """ - # get logger for module - logger = logging.getLogger(__name__) - - # set propagate to False if not already set as such (needed to that we do not flow up to any - # root loggers) - if logger.propagate: - logger.propagate = False - - # if logger has no handlers, it needs to be configured for the first time - if not logger.hasHandlers(): - # set level - logger.setLevel(LOG_LEVEL) - - # if logging enabled, use a StreamHandler; else, use the NullHandler to suppress logging - handler: logging.Handler - if ENABLE_LOGGING: - handler = logging.FileHandler(LOG_PATH) - else: - handler = logging.NullHandler() - - # Format our output - formatter = logging.Formatter( - "%(asctime)s - %(levelname)s:%(name)s - %(message)s" - ) - handler.setFormatter(formatter) - - # Set handler level and add to logger - handler.setLevel(LOG_LEVEL) - logger.addHandler(handler) - - return logger - - class SavedSearchKeys(StrEnum): """ Various keys into the SavedSearch content @@ -135,34 +96,58 @@ class ResultIterator: Given a ResponseReader, constructs a JSONResultsReader and iterates over it; when Message instances are encountered, they are logged if the message is anything other than "error", in which case an error is raised. Regular results are returned as expected + :param response_reader: a ResponseReader object - :param logger: a Logger object + :type response_reader: :class:`splunklib.binding.ResponseReader` + :param error_filters: set of re Patterns used to filter out errors we're ok ignoring + :type error_filters: list[:class:`re.Pattern[str]`] """ - def __init__(self, response_reader: ResponseReader) -> None: + def __init__( + self, response_reader: ResponseReader, error_filters: list[re.Pattern[str]] = [] + ) -> None: # init the results reader self.results_reader: JSONResultsReader = JSONResultsReader(response_reader) + # the list of patterns for errors to ignore + self.error_filters: list[re.Pattern[str]] = error_filters + # get logger - self.logger: logging.Logger = get_logger() + self.logger: logging.Logger = Utils.get_logger( + __name__, LOG_LEVEL, LOG_PATH, ENABLE_LOGGING + ) def __iter__(self) -> "ResultIterator": return self - def __next__(self) -> dict[Any, Any]: + def __next__(self) -> dict[str, Any]: # Use a reader for JSON format so we can iterate over our results for result in self.results_reader: # log messages, or raise if error if isinstance(result, Message): # convert level string to level int - level_name = result.type.strip().upper() # type: ignore + level_name: str = result.type.strip().upper() # type: ignore + # TODO (PEX-510): this method is deprecated; replace with our own enum level: int = logging.getLevelName(level_name) # log message at appropriate level and raise if needed message = f"SPLUNK: {result.message}" # type: ignore self.logger.log(level, message) + filtered = False if level == logging.ERROR: - raise ServerError(message) + # if the error matches any of the filters, flag it + for filter in self.error_filters: + self.logger.debug(f"Filter: {filter}; message: {message}") + if filter.match(message) is not None: + self.logger.debug( + f"Error matched filter {filter}; continuing" + ) + filtered = True + break + + # if no filter was matched, raise + if not filtered: + raise ServerError(message) # if dict, just return elif isinstance(result, dict): @@ -218,7 +203,12 @@ class CorrelationSearch(BaseModel): # The logger to use (logs all go to a null pipe unless ENABLE_LOGGING is set to True, so as not # to conflict w/ tqdm) - logger: logging.Logger = Field(default_factory=get_logger, init=False) + logger: logging.Logger = Field( + default_factory=lambda: Utils.get_logger( + __name__, LOG_LEVEL, LOG_PATH, ENABLE_LOGGING + ), + init=False, + ) # The set of indexes to clear on cleanup indexes_to_purge: set[str] = Field(default=set(), init=False) diff --git a/contentctl/output/conf_output.py b/contentctl/output/conf_output.py index c96ad581..bcbff3db 100644 --- a/contentctl/output/conf_output.py +++ b/contentctl/output/conf_output.py @@ -93,6 +93,10 @@ def writeMiscellaneousAppFiles(self) -> set[pathlib.Path]: return written_files + # TODO (#339): we could have a discrepancy between detections tested and those delivered + # based on the jinja2 template + # {% if (detection.type == 'TTP' or detection.type == 'Anomaly' or + # detection.type == 'Hunting' or detection.type == 'Correlation') %} def writeDetections(self, objects: list[Detection]) -> set[pathlib.Path]: written_files: set[pathlib.Path] = set() for output_app_path, template_name in [