Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion workers/openrelik-worker-timesketch/src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@
import redis
from celery.app import Celery

REDIS_URL = os.getenv("REDIS_URL")
REDIS_URL = os.getenv("REDIS_URL") or "redis://127.0.0.1:6379/0"
celery = Celery(broker=REDIS_URL, backend=REDIS_URL, include=["src.tasks"])
redis_client = redis.Redis.from_url(REDIS_URL)
172 changes: 134 additions & 38 deletions workers/openrelik-worker-timesketch/src/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
import requests
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add this dependency with `uv add`, and include the updated pyproject.toml and uv.lock files in this PR.

uv add requests

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ack, I need to do this.


from openrelik_worker_common.task_utils import create_task_result, get_input_files
from timesketch_api_client import client as timesketch_client
Expand All @@ -29,43 +30,93 @@ def get_or_create_sketch(
workflow_id=None,
):
"""
Retrieves or creates a sketch, handling locking if needed.
This uses Redis distrubuted lock to avoid race conditions.
Retrieves an existing Timesketch sketch or creates a new one.

If `sketch_id` is provided, it attempts to fetch that specific sketch.
If `sketch_name` is provided (and `sketch_id` is not), it attempts to create
a sketch with that name.
If neither `sketch_id` nor `sketch_name` is provided, it generates a default
sketch name based on the `workflow_id`. In this default case, a Redis
distributed lock is used to prevent race conditions if multiple workers
attempt to create the same sketch concurrently. The lock ensures that only
one worker will create the sketch if it doesn't already exist.

Args:
client: Timesketch API client.
redis_client: Redis client.
sketch_id: ID of the sketch to retrieve.
sketch_name: Name of the sketch to create.
workflow_id: ID of the workflow.
timesketch_api_client: An instance of the Timesketch API client.
redis_client: An instance of the Redis client, used for distributed locking.
sketch_id (int, optional): The ID of an existing sketch to retrieve.
sketch_name (str, optional): The name for a new sketch to be created.
workflow_id (str, optional): The ID of the workflow, used to generate
a default sketch name if `sketch_id` and `sketch_name` are not provided.

Returns:
Timesketch sketch object or None if failed
timesketch_api_client.Sketch: The retrieved or created Timesketch sketch object.

Raises:
ValueError: If sketch_id is provided but the sketch is not found,
or if workflow_id is missing when required for default naming.
RuntimeError: If sketch creation or retrieval fails for other reasons,
such as API errors.
"""
sketch = None

if sketch_id:
sketch = timesketch_api_client.get_sketch(int(sketch_id))
try:
sketch = timesketch_api_client.get_sketch(int(sketch_id))
if not sketch:
raise ValueError(f"Sketch with ID '{sketch_id}' not found.")
return sketch
except (ValueError, RuntimeError, requests.exceptions.RequestException) as e:
if isinstance(e, ValueError):
raise
raise RuntimeError(
f"Failed to retrieve sketch with ID '{sketch_id}': {e}"
) from e
elif sketch_name:
sketch = timesketch_api_client.create_sketch(sketch_name)
try:
sketch = timesketch_api_client.create_sketch(sketch_name)
if not sketch:
raise RuntimeError(
f"Failed to create sketch with name '{sketch_name}' "
f"(API returned no sketch object)."
)
return sketch
except (RuntimeError, requests.exceptions.RequestException) as e:
raise RuntimeError(
f"Failed to create sketch with name '{sketch_name}': {e}"
) from e
else:
sketch_name = f"openrelik-workflow-{workflow_id}"
if not workflow_id:
raise ValueError(
"workflow_id is required to generate a default sketch name when "
"sketch_id and sketch_name are not provided."
)
default_sketch_name = f"openrelik-workflow-{workflow_id}"
# Prevent multiple distributed workers from concurrently creating the same
# sketch. This Redis-based lock ensures only one worker proceeds at a time, even
# across different machines. The code will block until the lock is acquired.
# The lock automatically expires after 60 seconds to prevent deadlocks.
with redis_client.lock(sketch_name, timeout=60, blocking_timeout=5):
with redis_client.lock(default_sketch_name, timeout=60, blocking_timeout=5):
# Search for an existing sketch while having the lock
for _sketch in timesketch_api_client.list_sketches():
if _sketch.name == sketch_name:
sketch = _sketch
break

# If not found, create a new one
if not sketch:
sketch = timesketch_api_client.create_sketch(sketch_name)

return sketch
try:
for _sketch in timesketch_api_client.list_sketches():
if _sketch.name == default_sketch_name:
sketch = _sketch
break
# If not found, create a new one
if not sketch:
sketch = timesketch_api_client.create_sketch(default_sketch_name)
if not sketch:
raise RuntimeError(
f"Failed to create default sketch '{default_sketch_name}' "
f"after acquiring lock."
)
return sketch
except (RuntimeError, requests.exceptions.RequestException) as e:
raise RuntimeError(
f"Failed to retrieve or create default sketch "
f"'{default_sketch_name}': {e}"
) from e


# Task name used to register and route the task to the correct queue.
Expand Down Expand Up @@ -97,6 +148,14 @@ def get_or_create_sketch(
"type": "text",
"required": False,
},
{
"name": "make_sketch_public",
"label": "Make sketch public",
"description": "Set the sketch to be publicly accessible in Timesketch.",
"type": "boolean",
"required": False,
"default": False,
},
],
}

Expand All @@ -110,17 +169,41 @@ def upload(
workflow_id: str = None,
task_config: dict = None,
) -> str:
"""Export files to Timesketch.
"""
Uploads files to a Timesketch instance, creating or updating a sketch and timelines.

Args:
pipe_result: Base64-encoded result from the previous Celery task, if any.
input_files: List of input file dictionaries (unused if pipe_result exists).
output_path: Path to the output directory.
workflow_id: ID of the workflow.
task_config: User configuration for the task.
self: The Celery task instance.
pipe_result (str, optional): Base64-encoded string representing the result
from a previous Celery task.
input_files (list, optional): A list of dictionaries representing input files.
output_path (str, optional): Path to the output directory.
workflow_id (str, optional): The ID of the OpenRelik workflow.
task_config (dict, optional): A dictionary containing user configuration.

Returns:
Base64-encoded dictionary containing task results.
str: A Base64-encoded dictionary string containing the task results.
"""
return _upload(
pipe_result=pipe_result,
input_files=input_files,
output_path=output_path,
workflow_id=workflow_id,
task_config=task_config,
)


def _upload(
pipe_result: str = None,
input_files: list = None,
output_path: str = None,
workflow_id: str = None,
task_config: dict = None,
) -> str:
"""Helper function to perform the upload.

Note: This is separated from the @celery.task decorated 'upload' function
to allow for clean unit testing without Celery proxy/decorator interference.
"""
input_files = get_input_files(pipe_result, input_files or [])

Expand All @@ -130,10 +213,19 @@ def upload(
timesketch_username = os.environ.get("TIMESKETCH_USERNAME")
timesketch_password = os.environ.get("TIMESKETCH_PASSWORD")

# Validate required environment variables
assert timesketch_server_url, "Missing TIMESKETCH_SERVER_URL"
assert timesketch_server_public_url, "Missing TIMESKETCH_SERVER_PUBLIC_URL"
assert timesketch_username, "Missing TIMESKETCH_USERNAME"
assert timesketch_password, "Missing TIMESKETCH_PASSWORD"

# User supplied config.
sketch_id = task_config.get("sketch_id")
sketch_name = task_config.get("sketch_name")
sketch_identifier = {"sketch_id": sketch_id} if sketch_id else {"sketch_name": sketch_name}
sketch_identifier = (
{"sketch_id": sketch_id} if sketch_id else {"sketch_name": sketch_name}
)
make_sketch_public = task_config.get("make_sketch_public", False)

# Create a Timesketch API client.
timesketch_api_client = timesketch_client.TimesketchApi(
Expand All @@ -150,17 +242,21 @@ def upload(
workflow_id=workflow_id,
)

if not sketch:
raise Exception(f"Failed to create or retrieve sketch '{sketch_name}'")

# Make the sketch public.
# TODO: Make this user configurable.
sketch.add_to_acl(make_public=True)
# Make the sketch public if configured.
if make_sketch_public:
try:
sketch.add_to_acl(make_public=True)
except (RuntimeError, requests.exceptions.RequestException) as e:
raise RuntimeError(
f"Failed to make sketch {sketch.id} ('{sketch.name}') public: {e}"
) from e

# Import each input file to it's own index.
# Import each input file to its own index.
for input_file in input_files:
input_file_path = input_file.get("path")
timeline_name = task_config.get("timeline_name") or input_file.get("display_name")
timeline_name = task_config.get("timeline_name") or input_file.get(
"display_name"
)
with importer.ImportStreamer() as streamer:
streamer.set_sketch(sketch)
streamer.set_timeline_name(timeline_name)
Expand Down
Loading
Loading