diff --git a/workers/openrelik-worker-timesketch/src/app.py b/workers/openrelik-worker-timesketch/src/app.py
index 0e6d561..2cb08ab 100644
--- a/workers/openrelik-worker-timesketch/src/app.py
+++ b/workers/openrelik-worker-timesketch/src/app.py
@@ -17,6 +17,6 @@ import redis
 
 from celery.app import Celery
 
-REDIS_URL = os.getenv("REDIS_URL")
+REDIS_URL = os.getenv("REDIS_URL") or "redis://127.0.0.1:6379/0"
 celery = Celery(broker=REDIS_URL, backend=REDIS_URL, include=["src.tasks"])
 redis_client = redis.Redis.from_url(REDIS_URL)
diff --git a/workers/openrelik-worker-timesketch/src/tasks.py b/workers/openrelik-worker-timesketch/src/tasks.py
index 0df7b0d..5770f3c 100644
--- a/workers/openrelik-worker-timesketch/src/tasks.py
+++ b/workers/openrelik-worker-timesketch/src/tasks.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import os
 
+import requests
 from openrelik_worker_common.task_utils import create_task_result, get_input_files
 from timesketch_api_client import client as timesketch_client
@@ -29,43 +30,97 @@ def get_or_create_sketch(
     workflow_id=None,
 ):
     """
-    Retrieves or creates a sketch, handling locking if needed.
-    This uses Redis distrubuted lock to avoid race conditions.
+    Retrieves an existing Timesketch sketch or creates a new one.
+
+    If `sketch_id` is provided, it attempts to fetch that specific sketch.
+    If `sketch_name` is provided (and `sketch_id` is not), it attempts to create
+    a sketch with that name.
+    If neither `sketch_id` nor `sketch_name` is provided, it generates a default
+    sketch name based on the `workflow_id`. In this default case, a Redis
+    distributed lock is used to prevent race conditions if multiple workers
+    attempt to create the same sketch concurrently. The lock ensures that only
+    one worker will create the sketch if it doesn't already exist.
 
     Args:
-        client: Timesketch API client.
-        redis_client: Redis client.
-        sketch_id: ID of the sketch to retrieve.
-        sketch_name: Name of the sketch to create.
-        workflow_id: ID of the workflow.
+        timesketch_api_client: An instance of the Timesketch API client.
+        redis_client: An instance of the Redis client, used for distributed locking.
+        sketch_id (int, optional): The ID of an existing sketch to retrieve.
+        sketch_name (str, optional): The name for a new sketch to be created.
+        workflow_id (str, optional): The ID of the workflow, used to generate
+            a default sketch name if `sketch_id` and `sketch_name` are not provided.
 
     Returns:
-        Timesketch sketch object or None if failed
+        timesketch_api_client.Sketch: The retrieved or created Timesketch sketch object.
+
+    Raises:
+        ValueError: If sketch_id is provided but the sketch is not found,
+            or if workflow_id is missing when required for default naming.
+        RuntimeError: If sketch creation or retrieval fails for other reasons,
+            such as API errors.
+        redis.exceptions.LockError: If the default-name lock cannot be acquired
+            within the 5 second blocking timeout.
     """
     sketch = None
 
     if sketch_id:
-        sketch = timesketch_api_client.get_sketch(int(sketch_id))
+        try:
+            sketch = timesketch_api_client.get_sketch(int(sketch_id))
+        except (RuntimeError, requests.exceptions.RequestException) as e:
+            raise RuntimeError(
+                f"Failed to retrieve sketch with ID '{sketch_id}': {e}"
+            ) from e
+        # Checked outside the try block so this ValueError is never re-wrapped.
+        if not sketch:
+            raise ValueError(f"Sketch with ID '{sketch_id}' not found.")
+        return sketch
     elif sketch_name:
-        sketch = timesketch_api_client.create_sketch(sketch_name)
+        try:
+            sketch = timesketch_api_client.create_sketch(sketch_name)
+        except (RuntimeError, requests.exceptions.RequestException) as e:
+            raise RuntimeError(
+                f"Failed to create sketch with name '{sketch_name}': {e}"
+            ) from e
+        # Raised outside the try block so the message is not wrapped twice.
+        if not sketch:
+            raise RuntimeError(
+                f"Failed to create sketch with name '{sketch_name}' "
+                f"(API returned no sketch object)."
+            )
+        return sketch
     else:
-        sketch_name = f"openrelik-workflow-{workflow_id}"
+        if not workflow_id:
+            raise ValueError(
+                "workflow_id is required to generate a default sketch name when "
+                "sketch_id and sketch_name are not provided."
+            )
+        default_sketch_name = f"openrelik-workflow-{workflow_id}"
         # Prevent multiple distributed workers from concurrently creating the same
         # sketch. This Redis-based lock ensures only one worker proceeds at a time, even
-        # across different machines. The code will block until the lock is acquired.
+        # across different machines. Acquiring waits at most 5 seconds
+        # (blocking_timeout=5); redis-py raises LockError if it is still held.
         # The lock automatically expires after 60 seconds to prevent deadlocks.
-        with redis_client.lock(sketch_name, timeout=60, blocking_timeout=5):
+        with redis_client.lock(default_sketch_name, timeout=60, blocking_timeout=5):
             # Search for an existing sketch while having the lock
-            for _sketch in timesketch_api_client.list_sketches():
-                if _sketch.name == sketch_name:
-                    sketch = _sketch
-                    break
-
-            # If not found, create a new one
-            if not sketch:
-                sketch = timesketch_api_client.create_sketch(sketch_name)
-
-    return sketch
+            try:
+                for _sketch in timesketch_api_client.list_sketches():
+                    if _sketch.name == default_sketch_name:
+                        sketch = _sketch
+                        break
+                # If not found, create a new one
+                if not sketch:
+                    sketch = timesketch_api_client.create_sketch(default_sketch_name)
+            except (RuntimeError, requests.exceptions.RequestException) as e:
+                raise RuntimeError(
+                    f"Failed to retrieve or create default sketch "
+                    f"'{default_sketch_name}': {e}"
+                ) from e
+        # Raised outside the try block so the message is not wrapped twice.
+        if not sketch:
+            raise RuntimeError(
+                f"Failed to create default sketch '{default_sketch_name}' "
+                f"after acquiring lock."
+            )
+        return sketch
 
 
 # Task name used to register and route the task to the correct queue.
@@ -97,6 +152,14 @@ def get_or_create_sketch(
             "type": "text",
             "required": False,
         },
+        {
+            "name": "make_sketch_public",
+            "label": "Make sketch public",
+            "description": "Set the sketch to be publicly accessible in Timesketch.",
+            "type": "boolean",
+            "required": False,
+            "default": False,
+        },
     ],
 }
 
@@ -110,17 +173,45 @@ def upload(
     workflow_id: str = None,
     task_config: dict = None,
 ) -> str:
-    """Export files to Timesketch.
+    """
+    Uploads files to a Timesketch instance, creating or updating a sketch and timelines.
 
     Args:
-        pipe_result: Base64-encoded result from the previous Celery task, if any.
-        input_files: List of input file dictionaries (unused if pipe_result exists).
-        output_path: Path to the output directory.
-        workflow_id: ID of the workflow.
-        task_config: User configuration for the task.
+        self: The Celery task instance.
+        pipe_result (str, optional): Base64-encoded string representing the result
+            from a previous Celery task.
+        input_files (list, optional): A list of dictionaries representing input files.
+        output_path (str, optional): Path to the output directory.
+        workflow_id (str, optional): The ID of the OpenRelik workflow.
+        task_config (dict, optional): A dictionary containing user configuration.
 
     Returns:
-        Base64-encoded dictionary containing task results.
+        str: A Base64-encoded dictionary string containing the task results.
+    """
+    return _upload(
+        pipe_result=pipe_result,
+        input_files=input_files,
+        output_path=output_path,
+        workflow_id=workflow_id,
+        task_config=task_config,
+    )
+
+
+def _upload(
+    pipe_result: str = None,
+    input_files: list = None,
+    output_path: str = None,
+    workflow_id: str = None,
+    task_config: dict = None,
+) -> str:
+    """Helper function to perform the upload.
+
+    Note: This is separated from the @celery.task decorated 'upload' function
+    to allow for clean unit testing without Celery proxy/decorator interference.
+
+    Raises:
+        RuntimeError: If a required TIMESKETCH_* environment variable is not set,
+            or if the sketch cannot be made public.
     """
     input_files = get_input_files(pipe_result, input_files or [])
 
@@ -130,10 +221,28 @@ def upload(
     timesketch_username = os.environ.get("TIMESKETCH_USERNAME")
     timesketch_password = os.environ.get("TIMESKETCH_PASSWORD")
 
+    # Validate required settings with an explicit exception rather than `assert`,
+    # which is stripped when Python runs with -O.
+    required_env = {
+        "TIMESKETCH_SERVER_URL": timesketch_server_url,
+        "TIMESKETCH_SERVER_PUBLIC_URL": timesketch_server_public_url,
+        "TIMESKETCH_USERNAME": timesketch_username,
+        "TIMESKETCH_PASSWORD": timesketch_password,
+    }
+    missing_env = [name for name, value in required_env.items() if not value]
+    if missing_env:
+        raise RuntimeError(
+            f"Missing required environment variables: {', '.join(missing_env)}"
+        )
+
     # User supplied config.
+    task_config = task_config or {}
     sketch_id = task_config.get("sketch_id")
     sketch_name = task_config.get("sketch_name")
-    sketch_identifier = {"sketch_id": sketch_id} if sketch_id else {"sketch_name": sketch_name}
+    sketch_identifier = (
+        {"sketch_id": sketch_id} if sketch_id else {"sketch_name": sketch_name}
+    )
+    make_sketch_public = task_config.get("make_sketch_public", False)
 
     # Create a Timesketch API client.
     timesketch_api_client = timesketch_client.TimesketchApi(
@@ -150,17 +259,21 @@ def upload(
         workflow_id=workflow_id,
     )
 
-    if not sketch:
-        raise Exception(f"Failed to create or retrieve sketch '{sketch_name}'")
-
-    # Make the sketch public.
-    # TODO: Make this user configurable.
-    sketch.add_to_acl(make_public=True)
+    # Make the sketch public if configured.
+    if make_sketch_public:
+        try:
+            sketch.add_to_acl(make_public=True)
+        except (RuntimeError, requests.exceptions.RequestException) as e:
+            raise RuntimeError(
+                f"Failed to make sketch {sketch.id} ('{sketch.name}') public: {e}"
+            ) from e
 
-    # Import each input file to it's own index.
+    # Import each input file to its own index.
     for input_file in input_files:
         input_file_path = input_file.get("path")
-        timeline_name = task_config.get("timeline_name") or input_file.get("display_name")
+        timeline_name = task_config.get("timeline_name") or input_file.get(
+            "display_name"
+        )
         with importer.ImportStreamer() as streamer:
             streamer.set_sketch(sketch)
             streamer.set_timeline_name(timeline_name)
diff --git a/workers/openrelik-worker-timesketch/tests/test_tasks.py b/workers/openrelik-worker-timesketch/tests/test_tasks.py
index 8325719..15ee09e 100644
--- a/workers/openrelik-worker-timesketch/tests/test_tasks.py
+++ b/workers/openrelik-worker-timesketch/tests/test_tasks.py
@@ -12,16 +12,281 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 """Tests tasks."""
 
-# Note: Use pytest for writing tests!
-import pytest
-
-# from src.tasks import command
-
-
-def test_task_command():
-    """Test command task."""
-
-    ret = "some dummy return value"
-    assert isinstance(ret, str)
+import os
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+# Patch src.app before importing tasks to avoid redis connection
+with patch("redis.Redis.from_url"), patch("celery.Celery"):
+    from src.tasks import _upload, get_or_create_sketch
+
+# Environment variables required by _upload.
+TIMESKETCH_ENV = {
+    "TIMESKETCH_SERVER_URL": "http://localhost",
+    "TIMESKETCH_SERVER_PUBLIC_URL": "http://public",
+    "TIMESKETCH_USERNAME": "user",
+    "TIMESKETCH_PASSWORD": "pass",
+}
+
+
+class TestGetOrCreateSketch:
+    @patch("src.tasks.redis_client")
+    def test_get_or_create_sketch_by_id_success(self, mock_redis):
+        """Tests retrieving an existing sketch by its ID."""
+        mock_timesketch_client = MagicMock()
+        mock_sketch = MagicMock()
+        mock_timesketch_client.get_sketch.return_value = mock_sketch
+
+        sketch = get_or_create_sketch(
+            mock_timesketch_client, mock_redis, sketch_id=123
+        )
+
+        mock_timesketch_client.get_sketch.assert_called_once_with(123)
+        assert sketch == mock_sketch
+
+    @patch("src.tasks.redis_client")
+    def test_get_or_create_sketch_by_id_not_found(self, mock_redis):
+        """Tests handling the case where a sketch ID is not found."""
+        mock_timesketch_client = MagicMock()
+        mock_timesketch_client.get_sketch.return_value = None
+
+        with pytest.raises(ValueError) as excinfo:
+            get_or_create_sketch(
+                mock_timesketch_client, mock_redis, sketch_id=123
+            )
+
+        assert "Sketch with ID '123' not found." in str(excinfo.value)
+
+    @patch("src.tasks.redis_client")
+    def test_get_or_create_sketch_by_id_api_error(self, mock_redis):
+        """Tests handling an API error when retrieving by ID."""
+        mock_timesketch_client = MagicMock()
+        mock_timesketch_client.get_sketch.side_effect = RuntimeError("API Error")
+
+        with pytest.raises(RuntimeError) as excinfo:
+            get_or_create_sketch(
+                mock_timesketch_client, mock_redis, sketch_id=123
+            )
+
+        assert "Failed to retrieve sketch with ID '123': API Error" in str(excinfo.value)
+
+    @patch("src.tasks.redis_client")
+    def test_get_or_create_sketch_by_name_success(self, mock_redis):
+        """Tests creating a new sketch by a given name."""
+        mock_timesketch_client = MagicMock()
+        mock_sketch = MagicMock()
+        mock_timesketch_client.create_sketch.return_value = mock_sketch
+
+        sketch = get_or_create_sketch(
+            mock_timesketch_client, mock_redis, sketch_name="Test Sketch"
+        )
+
+        mock_timesketch_client.create_sketch.assert_called_once_with("Test Sketch")
+        assert sketch == mock_sketch
+
+    @patch("src.tasks.redis_client")
+    def test_get_or_create_sketch_by_name_failure(self, mock_redis):
+        """Tests handling the failure of sketch creation by name."""
+        mock_timesketch_client = MagicMock()
+        mock_timesketch_client.create_sketch.return_value = None
+
+        with pytest.raises(RuntimeError) as excinfo:
+            get_or_create_sketch(
+                mock_timesketch_client, mock_redis, sketch_name="Test Sketch"
+            )
+
+        message = str(excinfo.value)
+        assert "Failed to create sketch with name 'Test Sketch'" in message
+        # The failure must be reported once, not wrapped in itself.
+        assert message.count("Failed to create sketch") == 1
+
+    @patch("src.tasks.redis_client")
+    def test_get_or_create_sketch_default_name_existing(self, mock_redis):
+        """Tests retrieving an existing sketch using the default naming convention."""
+        mock_timesketch_client = MagicMock()
+        mock_sketch = MagicMock()
+        mock_sketch.name = "openrelik-workflow-123"
+        mock_timesketch_client.list_sketches.return_value = [mock_sketch]
+
+        # Mocking the context manager for the lock
+        mock_lock = MagicMock()
+        mock_redis.lock.return_value = mock_lock
+        mock_lock.__enter__.return_value = MagicMock()
+
+        sketch = get_or_create_sketch(
+            mock_timesketch_client, mock_redis, workflow_id="123"
+        )
+
+        assert sketch == mock_sketch
+        mock_redis.lock.assert_called_once_with(
+            "openrelik-workflow-123", timeout=60, blocking_timeout=5
+        )
+
+    @patch("src.tasks.redis_client")
+    def test_get_or_create_sketch_default_name_new(self, mock_redis):
+        """Tests creating a new sketch using the default naming convention."""
+        mock_timesketch_client = MagicMock()
+        mock_sketch = MagicMock()
+        mock_sketch.name = "openrelik-workflow-123"
+        mock_timesketch_client.list_sketches.return_value = []
+        mock_timesketch_client.create_sketch.return_value = mock_sketch
+
+        mock_lock = MagicMock()
+        mock_redis.lock.return_value = mock_lock
+        mock_lock.__enter__.return_value = MagicMock()
+
+        sketch = get_or_create_sketch(
+            mock_timesketch_client, mock_redis, workflow_id="123"
+        )
+        assert sketch == mock_sketch
+
+    @patch("src.tasks.redis_client")
+    def test_get_or_create_sketch_default_name_create_failure(self, mock_redis):
+        """Tests handling the failure of default sketch creation."""
+        mock_timesketch_client = MagicMock()
+        mock_timesketch_client.list_sketches.return_value = []
+        mock_timesketch_client.create_sketch.return_value = None
+
+        mock_lock = MagicMock()
+        mock_redis.lock.return_value = mock_lock
+        mock_lock.__enter__.return_value = MagicMock()
+
+        with pytest.raises(RuntimeError) as excinfo:
+            get_or_create_sketch(
+                mock_timesketch_client, mock_redis, workflow_id="123"
+            )
+        assert "after acquiring lock" in str(excinfo.value)
+
+    @patch("src.tasks.redis_client")
+    def test_get_or_create_sketch_default_name_list_error(self, mock_redis):
+        """Tests handling an error when listing sketches for default name."""
+        mock_timesketch_client = MagicMock()
+        mock_timesketch_client.list_sketches.side_effect = RuntimeError("List Error")
+
+        mock_lock = MagicMock()
+        mock_redis.lock.return_value = mock_lock
+        mock_lock.__enter__.return_value = MagicMock()
+
+        with pytest.raises(RuntimeError) as excinfo:
+            get_or_create_sketch(
+                mock_timesketch_client, mock_redis, workflow_id="123"
+            )
+        assert "Failed to retrieve or create default sketch" in str(excinfo.value)
+
+    @patch("src.tasks.redis_client")
+    def test_get_or_create_sketch_missing_workflow_id(self, mock_redis):
+        """Tests that ValueError is raised if no identification is provided and workflow_id is missing."""
+        mock_timesketch_client = MagicMock()
+
+        with pytest.raises(ValueError) as excinfo:
+            get_or_create_sketch(
+                mock_timesketch_client, mock_redis
+            )
+
+        assert "workflow_id is required" in str(excinfo.value)
+
+
+class TestUpload:
+    @patch.dict(os.environ, TIMESKETCH_ENV)
+    @patch("src.tasks.timesketch_client.TimesketchApi", autospec=True)
+    @patch("src.tasks.get_input_files")
+    @patch("src.tasks.importer.ImportStreamer", autospec=True)
+    @patch("src.tasks.create_task_result")
+    @patch("src.tasks.redis_client")
+    def test_upload_success(
+        self,
+        mock_redis,
+        mock_create_task_result,
+        mock_import_streamer_class,
+        mock_get_input_files,
+        mock_timesketch_api_class,
+    ):
+        """Tests a successful upload task using the real get_or_create_sketch logic."""
+        # Setup input files
+        mock_get_input_files.return_value = [{"path": "/tmp/file", "display_name": "file"}]
+
+        # Setup Timesketch client and sketch
+        mock_ts_client = mock_timesketch_api_class.return_value
+        mock_sketch = MagicMock()
+        mock_sketch.id = 1
+        mock_ts_client.get_sketch.return_value = mock_sketch
+
+        # Setup mock for context manager importer.ImportStreamer()
+        mock_streamer_instance = mock_import_streamer_class.return_value.__enter__.return_value
+
+        # Call the helper function directly
+        # By not mocking get_or_create_sketch, we test its integration
+        _upload(
+            pipe_result=None,
+            input_files=[],
+            output_path="/tmp",
+            workflow_id="123",
+            task_config={"sketch_id": "1", "make_sketch_public": True},
+        )
+
+        # Verify get_sketch was called (part of get_or_create_sketch logic)
+        mock_ts_client.get_sketch.assert_called_once_with(1)
+        mock_sketch.add_to_acl.assert_called_once_with(make_public=True)
+        mock_streamer_instance.add_file.assert_called_once_with("/tmp/file")
+        mock_create_task_result.assert_called_once()
+
+    @patch.dict(os.environ, TIMESKETCH_ENV)
+    @patch("src.tasks.timesketch_client.TimesketchApi", autospec=True)
+    @patch("src.tasks.get_input_files")
+    @patch("src.tasks.importer.ImportStreamer", autospec=True)
+    @patch("src.tasks.create_task_result")
+    @patch("src.tasks.redis_client")
+    def test_upload_without_task_config(
+        self,
+        mock_redis,
+        mock_create_task_result,
+        mock_import_streamer_class,
+        mock_get_input_files,
+        mock_timesketch_api_class,
+    ):
+        """Tests that a missing task_config falls back to the default sketch."""
+        mock_get_input_files.return_value = [{"path": "/tmp/file", "display_name": "file"}]
+        mock_ts_client = mock_timesketch_api_class.return_value
+        mock_sketch = MagicMock()
+        mock_sketch.id = 1
+        mock_ts_client.list_sketches.return_value = []
+        mock_ts_client.create_sketch.return_value = mock_sketch
+
+        _upload(workflow_id="123", task_config=None)
+
+        mock_ts_client.create_sketch.assert_called_once_with("openrelik-workflow-123")
+        mock_sketch.add_to_acl.assert_not_called()
+        mock_create_task_result.assert_called_once()
+
+    @patch.dict(os.environ, {}, clear=True)
+    def test_upload_missing_env(self):
+        """Tests upload failure when environment variables are missing."""
+        with pytest.raises(RuntimeError) as excinfo:
+            _upload(task_config={})
+        assert "Missing required environment variables" in str(excinfo.value)
+
+    @patch.dict(os.environ, TIMESKETCH_ENV)
+    @patch("src.tasks.timesketch_client.TimesketchApi", autospec=True)
+    @patch("src.tasks.get_input_files")
+    def test_upload_acl_error(
+        self,
+        mock_get_input_files,
+        mock_timesketch_api_class,
+    ):
+        """Tests upload failure when ACL update fails."""
+        mock_get_input_files.return_value = []
+        mock_ts_client = mock_timesketch_api_class.return_value
+        mock_sketch = MagicMock()
+        mock_sketch.id = 1
+        mock_sketch.name = "test"
+        mock_sketch.add_to_acl.side_effect = RuntimeError("ACL Error")
+        mock_ts_client.get_sketch.return_value = mock_sketch
+
+        with pytest.raises(RuntimeError) as excinfo:
+            _upload(
+                task_config={"sketch_id": "1", "make_sketch_public": True},
+            )
+        assert "Failed to make sketch 1 ('test') public" in str(excinfo.value)