diff --git a/.github/workflows/testing-coverage.yml b/.github/workflows/testing-coverage.yml index 0d68a93ec..44a358175 100644 --- a/.github/workflows/testing-coverage.yml +++ b/.github/workflows/testing-coverage.yml @@ -17,10 +17,15 @@ jobs: if: ${{ !(contains(github.event.pull_request.labels.*.name, 'WIP (no-ci)')) && !(contains(github.event.pull_request.labels.*.name, 'WIP (lint-only)')) }} name: BEE Integration Test strategy: + fail-fast: false matrix: bee_worker: [Slurmrestd, SlurmCommands, Flux] + # TODO get neo4j test to work again + # gdb_backend: [neo4j, sqlite3] + gdb_backend: [sqlite3] env: BEE_WORKER: ${{ matrix.bee_worker }} + GDB_BACKEND: ${{ matrix.gdb_backend }} # Note: Needs to run on 22.04 or later since slurmrestd doesn't seem to be # available on 20.04 runs-on: ubuntu-22.04 @@ -37,13 +42,20 @@ jobs: run: | . ./ci/env.sh ./ci/integration_test.sh - mv .coverage.integration ".coverage.${{ matrix.bee_worker }}" + mv .coverage.integration ".coverage.${{ matrix.bee_worker }}.${{ matrix.gdb_backend }}" + - name: Upload bee error logs + if: failure() # only when the job fails + uses: actions/upload-artifact@v4 + with: + name: bee-error-logs-${{ matrix.bee_worker }}-${{ matrix.gdb_backend }}.log + path: bee-error-*.log + if-no-files-found: ignore - name: Upload coverage artifact uses: actions/upload-artifact@v4 with: - name: coverage-${{ matrix.bee_worker }} + name: coverage-${{ matrix.bee_worker }}-${{ matrix.gdb_backend }} include-hidden-files: true - path: ".coverage.${{ matrix.bee_worker }}" + path: ".coverage.${{ matrix.bee_worker }}.${{ matrix.gdb_backend }}" unit-test: if: ${{ !(contains(github.event.pull_request.labels.*.name, 'WIP (no-ci)')) && !(contains(github.event.pull_request.labels.*.name, 'WIP (lint-only)')) }} @@ -93,22 +105,40 @@ jobs: name: coverage-unit-tests path: ./ - - name: Get Slurmrestd coverage + # - name: Get Slurmrestd+neo4j coverage + # uses: actions/download-artifact@v4 + # with: + # name: coverage-Slurmrestd-neo4j + # path: ./ + + - name: Get Slurmrestd+sqlite3 coverage uses: actions/download-artifact@v4 with: - name: coverage-Slurmrestd + name: coverage-Slurmrestd-sqlite3 path: ./ - - - name: Get SlurmCommands coverage + + # - name: Get SlurmCommands+neo4j coverage + # uses: actions/download-artifact@v4 + # with: + # name: coverage-SlurmCommands-neo4j + # path: ./ + + - name: Get SlurmCommands+sqlite3 coverage uses: actions/download-artifact@v4 with: - name: coverage-SlurmCommands + name: coverage-SlurmCommands-sqlite3 path: ./ - - - name: Get Flux coverage + + # - name: Get Flux+neo4j coverage + # uses: actions/download-artifact@v4 + # with: + # name: coverage-Flux-neo4j + # path: ./ + + - name: Get Flux+sqlite3 coverage uses: actions/download-artifact@v4 with: - name: coverage-Flux + name: coverage-Flux-sqlite3 path: ./ - name: Combine coverage diff --git a/beeflow/client/core.py b/beeflow/client/core.py index 7bc2152a8..9bdde50f8 100644 --- a/beeflow/client/core.py +++ b/beeflow/client/core.py @@ -202,6 +202,11 @@ def need_slurmrestd(): and not bc.get('slurm', 'use_commands')) +def need_neo4j(): + """Check if neo4j is needed.""" + return bc.get('graphdb', 'type').lower() == 'neo4j' + + def init_components(remote=False): """Initialize the components and component manager.""" mgr = ComponentManager() @@ -255,14 +260,15 @@ def celery(): # container_path = paths.redis_container() # If it exists, we assume that it actually has a valid container - if not container_manager.check_container_dir('neo4j'): - print('Unpacking Neo4j image...') - 
container_manager.create_image('neo4j') + if need_neo4j(): + if not container_manager.check_container_dir('neo4j'): + print('Unpacking Neo4j image...') + container_manager.create_image('neo4j') - @mgr.component('neo4j-database', ('wf_manager',)) - def start_neo4j(): - """Start the neo4j graph database.""" - return neo4j_manager.start() + @mgr.component('neo4j-database', ('wf_manager',)) + def start_neo4j(): + """Start the neo4j graph database.""" + return neo4j_manager.start() @mgr.component('redis', ()) def start_redis(): @@ -673,14 +679,17 @@ def pull_deps(outdir: str = typer.Option('.', '--outdir', '-o', help='directory to store containers in')): """Pull required BEE containers and store in outdir.""" load_check_charliecloud() - neo4j_path = os.path.join(os.path.realpath(outdir), 'neo4j.tar.gz') - neo4j_dockerfile = str(Path(REPO_PATH, "beeflow/data/dockerfiles/Dockerfile.neo4j")) - build_to_tar('neo4j_image', neo4j_dockerfile, neo4j_path) + if need_neo4j(): + neo4j_path = os.path.join(os.path.realpath(outdir), 'neo4j.tar.gz') + neo4j_dockerfile = str(Path(REPO_PATH, "beeflow/data/dockerfiles/Dockerfile.neo4j")) + build_to_tar('neo4j_image', neo4j_dockerfile, neo4j_path) redis_path = os.path.join(os.path.realpath(outdir), 'redis.tar.gz') pull_to_tar('redis', redis_path) - AlterConfig(changes={'DEFAULT': {'neo4j_image': neo4j_path, - 'redis_image': redis_path}}).save() + if need_neo4j(): + AlterConfig(changes={'DEFAULT': {'neo4j_image': neo4j_path}}).save() + + AlterConfig(changes={'DEFAULT': {'redis_image': redis_path}}).save() dep_dir = container_manager.get_dep_dir() if os.path.isdir(dep_dir): diff --git a/beeflow/common/config_driver.py b/beeflow/common/config_driver.py index 701aa3f9a..3a09d6f38 100644 --- a/beeflow/common/config_driver.py +++ b/beeflow/common/config_driver.py @@ -397,6 +397,8 @@ def validate_chrun_opts(opts): info='extra Charliecloud setup to put in a job script') # Graph Database VALIDATOR.section('graphdb', info='Main graph database configuration section.') +VALIDATOR.option('graphdb', 'type', default='sqlite3', choices=('neo4j', 'sqlite3'), + info='type of graph database to use', prompt=False) VALIDATOR.option('graphdb', 'hostname', default='localhost', prompt=False, info='hostname of database') diff --git a/beeflow/common/db/bdb.py b/beeflow/common/db/bdb.py index f38dc47d2..e7c72f03b 100644 --- a/beeflow/common/db/bdb.py +++ b/beeflow/common/db/bdb.py @@ -15,9 +15,10 @@ def create_connection(db_file): conn = None try: conn = sqlite3.connect(db_file) + conn.execute("PRAGMA foreign_keys = ON;") return conn except Error as error: - print(error) + print("Error connecting to database: ", error) return conn @@ -42,9 +43,24 @@ def run(db_file, stmt, params=None): cursor.execute(stmt) conn.commit() except Error as error: + print("Error running: ", stmt) print(error) + +def runscript(db_file, script): + """Run the sql script on the database. 
Doesn't return anything.""" + with create_connection(db_file) as conn: + try: + cursor = conn.cursor() + cursor.executescript(script) + conn.commit() + except Error as error: + print("Error running script") + print(error) + + + def getone(db_file, stmt, params=None): """Run the sql statement on the database and return the result.""" with create_connection(db_file) as conn: @@ -55,7 +71,9 @@ def getone(db_file, stmt, params=None): else: cursor.execute(stmt) result = cursor.fetchone() - except Error: + except Error as error: + print("Error fetching one: ", stmt) + print(error) result = None return result @@ -70,7 +88,9 @@ def getall(db_file, stmt, params=None): else: cursor.execute(stmt) result = cursor.fetchall() - except Error: + except Error as error: + print("Error fetching all: ", stmt) + print(error) result = None return result diff --git a/beeflow/common/db/gdb_db.py b/beeflow/common/db/gdb_db.py new file mode 100644 index 000000000..e41d93ddf --- /dev/null +++ b/beeflow/common/db/gdb_db.py @@ -0,0 +1,804 @@ +"""Graph Database SQL implementation.""" +# pylint:disable=C0103 + +import json +from typing import Optional +from beeflow.common.db import bdb +from beeflow.common.object_models import (Workflow, Task, Requirement, Hint, +InputParameter, OutputParameter, StepInput, StepOutput) +from beeflow.wf_manager.models import WorkflowInfo + +failed_task_states = ['FAILED', 'SUBMIT_FAIL', 'BUILD_FAIL', 'DEP_FAIL', 'TIMEOUT', 'CANCELLED'] +final_task_states = ['COMPLETED', 'RESTARTED'] + failed_task_states + + +class SQL_GDB: + """Graph database implementation using SQLite.""" + def __init__(self, db_file): + self.db_file = db_file + self._init_tables() + + def _init_tables(self): + wfs_stmt = """CREATE TABLE IF NOT EXISTS workflow ( + id TEXT PRIMARY KEY, + name TEXT, + state TEXT, + workdir TEXT, + main_cwl TEXT, + wf_path TEXT, + yaml TEXT, + reqs JSON, + hints JSON, + restart INTEGER DEFAULT 0 + );""" + + wf_inputs_stmt = """CREATE TABLE IF NOT EXISTS workflow_input ( + id TEXT, + workflow_id TEXT, + type TEXT, + value TEXT, + FOREIGN KEY (workflow_id) REFERENCES workflow(id) ON DELETE CASCADE, + PRIMARY KEY (workflow_id , id) + );""" + + wf_outputs_stmt = """CREATE TABLE IF NOT EXISTS workflow_output ( + id TEXT, + workflow_id TEXT, + type TEXT, + value TEXT, + source TEXT, + FOREIGN KEY (workflow_id) REFERENCES workflow(id) ON DELETE CASCADE, + PRIMARY KEY (workflow_id , id) + );""" + + tasks_stmt = """CREATE TABLE IF NOT EXISTS task ( + id TEXT PRIMARY KEY, + workflow_id TEXT, + name TEXT, + state TEXT, + workdir TEXT, + base_command JSON, + stdout TEXT, + stderr TEXT, + reqs JSON, + hints JSON, + metadata JSON, + FOREIGN KEY (workflow_id) REFERENCES workflow(id) ON DELETE CASCADE + );""" + + task_inputs_stmt = """CREATE TABLE IF NOT EXISTS task_input ( + id TEXT, + task_id TEXT, + type TEXT, + value TEXT, + default_val TEXT, + source TEXT, + prefix TEXT, + position INTEGER, + value_from TEXT, + FOREIGN KEY (task_id) REFERENCES task(id) ON DELETE CASCADE, + PRIMARY KEY (task_id, id) + );""" + + task_outputs_stmt = """CREATE TABLE IF NOT EXISTS task_output ( + id TEXT, + task_id TEXT, + type TEXT, + value TEXT, + glob TEXT, + FOREIGN KEY (task_id) REFERENCES task(id) ON DELETE CASCADE, + PRIMARY KEY (task_id, id) + );""" + + task_deps_stmt = """CREATE TABLE IF NOT EXISTS task_dep ( + depending_task_id TEXT NOT NULL, + depends_on_task_id TEXT NOT NULL, + PRIMARY KEY (depending_task_id, depends_on_task_id), + FOREIGN KEY (depending_task_id) REFERENCES task(id) ON DELETE CASCADE, + 
FOREIGN KEY (depends_on_task_id) REFERENCES task(id) ON DELETE CASCADE + );""" + + task_rst_stmt = """CREATE TABLE IF NOT EXISTS task_restart ( + restarting_task_id TEXT NOT NULL, + restarted_from_task_id TEXT NOT NULL, + PRIMARY KEY (restarting_task_id, restarted_from_task_id), + FOREIGN KEY (restarting_task_id) REFERENCES task(id) ON DELETE CASCADE, + FOREIGN KEY (restarted_from_task_id) REFERENCES task(id) ON DELETE CASCADE + );""" + + add_indexes_stmt = """CREATE INDEX IF NOT EXISTS idx_task_wf_id ON task(workflow_id); + CREATE INDEX IF NOT EXISTS idx_task_wf_state ON task(workflow_id, state); + + CREATE INDEX IF NOT EXISTS idx_task_input_task_id ON task_input(task_id); + CREATE INDEX IF NOT EXISTS idx_task_output_task_id ON task_output(task_id); + + CREATE INDEX IF NOT EXISTS idx_task_dep_depends_on ON task_dep(depends_on_task_id); + CREATE INDEX IF NOT EXISTS idx_task_dep_depending ON task_dep(depending_task_id); + """ + + bdb.create_table(self.db_file, wfs_stmt) + bdb.create_table(self.db_file, wf_inputs_stmt) + bdb.create_table(self.db_file, wf_outputs_stmt) + bdb.create_table(self.db_file, tasks_stmt) + bdb.create_table(self.db_file, task_inputs_stmt) + bdb.create_table(self.db_file, task_outputs_stmt) + bdb.create_table(self.db_file, task_deps_stmt) + bdb.create_table(self.db_file, task_rst_stmt) + bdb.runscript(self.db_file, add_indexes_stmt) + + def create_workflow(self, workflow: Workflow): + """Create a workflow in the db""" + wf_stmt = """INSERT INTO workflow (id, name, state, workdir, main_cwl, + wf_path, yaml, reqs, hints, restart) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?);""" + wf_input_stmt = """INSERT INTO workflow_input (id, workflow_id, type, value) + VALUES (?, ?, ?, ?);""" + wf_output_stmt = """INSERT INTO workflow_output (id, workflow_id, type, value, source) + VALUES (?, ?, ?, ?, ?);""" + + + hints_json = json.dumps([h.model_dump() for h in workflow.hints]) + reqs_json = json.dumps([r.model_dump() for r in workflow.requirements]) + bdb.run(self.db_file, wf_stmt, (workflow.id, workflow.name, workflow.state, + workflow.workdir, workflow.main_cwl, workflow.wf_path, + workflow.yaml, reqs_json, hints_json, 0)) + + + for inp in workflow.inputs: + bdb.run(self.db_file, wf_input_stmt, (inp.id, workflow.id, inp.type, inp.value)) + for outp in workflow.outputs: + bdb.run(self.db_file, wf_output_stmt, (outp.id, workflow.id, outp.type, + outp.value, outp.source)) + + + def set_init_task_inputs(self, wf_id: str): + """Set initial workflow task inputs from workflow inputs or defaults""" + + inputs_query = """ + UPDATE task_input + SET value = ( + SELECT wi.value + FROM task AS t + JOIN workflow_input AS wi + ON wi.workflow_id = t.workflow_id + WHERE + t.workflow_id = :wf_id + AND task_input.task_id = t.id + AND task_input.source = wi.id + AND wi.value IS NOT NULL + ) + WHERE EXISTS ( + SELECT 1 + FROM task AS t + JOIN workflow_input AS wi + ON wi.workflow_id = t.workflow_id + WHERE + t.workflow_id = :wf_id + AND task_input.task_id = t.id + AND task_input.source = wi.id + AND wi.value IS NOT NULL + );""" + + defaults_query = """ + UPDATE task_input + SET value = default_val + WHERE + value IS NULL + AND default_val IS NOT NULL + AND EXISTS ( + SELECT 1 + FROM task AS t + JOIN workflow_input AS wi + ON wi.workflow_id = t.workflow_id + WHERE + t.workflow_id = :wf_id + AND task_input.task_id = t.id + AND task_input.source = wi.id + );""" + + + bdb.run(self.db_file, inputs_query, {'wf_id': wf_id}) + bdb.run(self.db_file, defaults_query, {'wf_id': wf_id}) + + def 
set_runnable_tasks_to_ready(self, wf_id: str): + """Set all tasks with all inputs satisfied to READY state""" + set_runnable_ready_query = """ + UPDATE task + SET state = 'READY' + WHERE workflow_id = :wf_id + AND state = 'WAITING' + AND NOT EXISTS ( + SELECT 1 + FROM task_input AS ti + WHERE ti.task_id = task.id + AND ti.value IS NULL + );""" + bdb.run(self.db_file, set_runnable_ready_query, {'wf_id': wf_id}) + + def set_workflow_state(self, wf_id: str, state: str): + """Set the state of the workflow.""" + set_wf_state_query = """ + UPDATE workflow + SET state = :state + WHERE id = :wf_id;""" + bdb.run(self.db_file, set_wf_state_query, {'wf_id': wf_id, 'state': state}) + + def create_task(self, task: Task): + """Create a task in the db""" + task_stmt = """INSERT INTO task (id, workflow_id, name, state, workdir, base_command, + stdout, stderr, reqs, hints, metadata) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);""" + task_input_stmt = """INSERT INTO task_input (id, task_id, type, value, default_val, source, + prefix, position, value_from) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);""" + task_output_stmt = """INSERT INTO task_output (id, task_id, type, value, glob) + VALUES (?, ?, ?, ?, ?);""" + + + hints_json = json.dumps([h.model_dump() for h in task.hints]) + reqs_json = json.dumps([r.model_dump() for r in task.requirements]) + metadata_json = json.dumps(task.metadata) + bdb.run(self.db_file, task_stmt, (task.id, task.workflow_id, task.name, + task.state, task.workdir, + json.dumps(task.base_command), task.stdout, task.stderr, + reqs_json, hints_json, metadata_json)) + + + for inp in task.inputs: + bdb.run(self.db_file, task_input_stmt, (inp.id, task.id, inp.type, inp.value, + inp.default, inp.source, + inp.prefix, inp.position, + inp.value_from)) + for outp in task.outputs: + bdb.run(self.db_file, task_output_stmt, (outp.id, task.id, outp.type, + outp.value, outp.glob)) + + def set_task_state(self, task_id: str, state: str): + """Set the state of a task.""" + set_task_state_query = """ + UPDATE task + SET state = :state + WHERE id = :task_id;""" + bdb.run(self.db_file, set_task_state_query, {'task_id': task_id, 'state': state}) + + + def add_dependencies(self, task: Task, old_task: Task=None, restarted_task=False): + """Add dependencies for a task based on its inputs and outputs.""" + if restarted_task: + set_restarted_wf = """ + UPDATE workflow + SET restart = 1 + WHERE id = :wf_id;""" + delete_dependencies_query = """ + DELETE FROM task_dep + WHERE depends_on_task_id = :depends_on_task_id;""" + + restarted_query = """ + INSERT INTO task_restart (restarting_task_id, restarted_from_task_id) + VALUES (:restarting_task_id, :restarted_from_task_id);""" + + dependency_query = """ + INSERT OR IGNORE INTO task_dep (depending_task_id, depends_on_task_id) + SELECT DISTINCT t.id AS depending_task_id, s.id AS depends_on_task_id + FROM task AS s + JOIN task_output AS o + ON o.task_id = s.id + JOIN task_input AS i + ON i.source = o.id + JOIN task AS t + ON t.id = i.task_id + WHERE + s.id = :task_id + AND s.workflow_id = t.workflow_id;""" + bdb.run(self.db_file, set_restarted_wf, {'wf_id': task.workflow_id}) + bdb.run(self.db_file, delete_dependencies_query, {'depends_on_task_id': old_task.id}) + bdb.run(self.db_file, restarted_query, {'restarting_task_id': task.id, + 'restarted_from_task_id': old_task.id}) + bdb.run(self.db_file, dependency_query, {'task_id': task.id}) + else: + dependency_query = """ + INSERT OR IGNORE INTO task_dep (depending_task_id, depends_on_task_id) + SELECT DISTINCT s.id AS 
depending_task_id, t.id AS depends_on_task_id + FROM task AS s + JOIN task_input AS i + ON i.task_id = s.id + JOIN task_output AS o + ON o.id = i.source + JOIN task AS t + ON t.id = o.task_id + WHERE + s.id = :task_id + AND s.workflow_id = t.workflow_id;""" + + dependent_query = """ + INSERT OR IGNORE INTO task_dep (depending_task_id, depends_on_task_id) + SELECT DISTINCT t.id AS depending_task_id, s.id AS depends_on_task_id + FROM task AS s + JOIN task_output AS o + ON o.task_id = s.id + JOIN task_input AS i + ON i.source = o.id + JOIN task AS t + ON t.id = i.task_id + WHERE + s.id = :task_id + AND s.workflow_id = t.workflow_id;""" + + bdb.run(self.db_file, dependency_query, {'task_id': task.id}) + bdb.run(self.db_file, dependent_query, {'task_id': task.id}) + + + def copy_task_outputs(self, task: Task): + """Use task outputs to set dependent task inputs or workflow outputs + + or set dependent task inputs to default if necessary""" + + task_inputs_query = """ + UPDATE task_input + SET value = ( + SELECT o.value + FROM task_dep AS d + JOIN task_output AS o + ON o.task_id = d.depends_on_task_id + WHERE + d.depends_on_task_id = :task_id + AND d.depending_task_id = task_input.task_id + AND task_input.source = o.id + AND o.value IS NOT NULL + LIMIT 1 + ) + WHERE EXISTS ( + SELECT 1 + FROM task_dep AS d + JOIN task_output AS o + ON o.task_id = d.depends_on_task_id + WHERE + d.depends_on_task_id = :task_id + AND d.depending_task_id = task_input.task_id + AND task_input.source = o.id + AND o.value IS NOT NULL + );""" + + defaults_query = f""" + UPDATE task_input + SET value = default_val + WHERE + value IS NULL + AND default_val IS NOT NULL + AND EXISTS ( + SELECT 1 + FROM task_dep AS d_down + WHERE + d_down.depending_task_id = task_input.task_id + AND d_down.depends_on_task_id = ? + ) + AND NOT EXISTS ( + SELECT 1 + FROM task_dep AS d_up + JOIN task AS pt + ON pt.id = d_up.depends_on_task_id + WHERE + d_up.depending_task_id = task_input.task_id + AND pt.state NOT IN ({', '.join(['?' 
for _ in final_task_states])}) + );""" + + workflow_output_query = """ + UPDATE workflow_output + SET value = ( + SELECT o.value + FROM task_output AS o + WHERE + o.id = workflow_output.source + AND o.task_id = :task_id + LIMIT 1 + ) + WHERE EXISTS ( + SELECT 1 + FROM task_output AS o + WHERE + o.id = workflow_output.source + AND o.task_id = :task_id + );""" + + bdb.run(self.db_file, task_inputs_query, {'task_id': task.id}) + bdb.run(self.db_file, defaults_query, [task.id, *final_task_states]) + bdb.run(self.db_file, workflow_output_query, {'task_id': task.id}) + + + def get_task(self, task_id: str) -> Optional[Task]: + """Return a reconstructed Task object from the db by its ID.""" + task_data = bdb.getone(self.db_file, 'SELECT * FROM task WHERE id=?', [task_id]) + if not task_data: + return None + + task = Task( + id=task_data[0], + workflow_id=task_data[1], + name=task_data[2], + state=task_data[3], + workdir=task_data[4], + base_command=json.loads(task_data[5]), + stdout=task_data[6], + stderr=task_data[7], + requirements=[Requirement.model_validate(r) for r in json.loads(task_data[8])], + hints=[Hint.model_validate(h) for h in json.loads(task_data[9])], + metadata=json.loads(task_data[10]), + inputs=self.get_task_inputs(task_data[0]), + outputs=self.get_task_outputs(task_data[0]) + ) + return task + + def get_task_inputs(self, task_id: str): + """Return a list of StepInput objects for a task.""" + task_inputs = bdb.getall(self.db_file, 'SELECT * FROM task_input WHERE task_id=?', + [task_id]) + return [StepInput( + id=ti[0], + type=ti[2], + value=ti[3], + default=ti[4], + source=ti[5], + prefix=ti[6], + position=ti[7], + value_from=ti[8] + ) + for ti in task_inputs] if task_inputs else [] + + def get_task_outputs(self, task_id: str): + """Return a list of StepOutput objects for a task.""" + task_outputs = bdb.getall(self.db_file, 'SELECT * FROM task_output WHERE task_id=?', + [task_id]) + return [StepOutput( + id=to[0], + type=to[2], + value=to[3], + glob=to[4] + ) + for to in task_outputs] if task_outputs else [] + + def get_all_workflow_info(self): + """Return a list of all workflows in the db. 
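+
+        A minimal usage sketch (hypothetical database path; ``SQL_GDB`` and the
+        ``WorkflowInfo`` fields come from this module and the import above):
+
+            gdb = SQL_GDB('/tmp/gdb_sql.db')
+            for info in gdb.get_all_workflow_info():
+                print(info.wf_id, info.wf_name, info.wf_status)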
+ + :rtype: list of WorkflowInfo + """ + wf_data = bdb.getall(self.db_file, 'SELECT id, name, state FROM workflow') + wf_info_list = [WorkflowInfo( + wf_id=wf[0], + wf_name=wf[1], + wf_status=wf[2], + ) for wf in wf_data] if wf_data else [] + return wf_info_list + + + def get_workflow(self, wf_id: str) -> Optional[Workflow]: + """Return a reconstructed Workflow object from the db by its ID.""" + wf_data = bdb.getone(self.db_file, 'SELECT * FROM workflow WHERE id=?', [wf_id]) + if not wf_data: + return None + + workflow_object = Workflow( + id=wf_data[0], + name=wf_data[1], + state=wf_data[2], + workdir=wf_data[3], + main_cwl=wf_data[4], + wf_path=wf_data[5], + yaml=wf_data[6], + requirements=[Requirement.model_validate(r) for r in json.loads(wf_data[7])], + hints=[Hint.model_validate(h) for h in json.loads(wf_data[8])], + inputs=self.get_workflow_inputs(wf_data[0]), + outputs=self.get_workflow_outputs(wf_data[0]) + ) + return workflow_object + + def get_workflow_inputs(self, wf_id: str): + """Return a list of InputParameter objects for a workflow.""" + wf_inputs = bdb.getall(self.db_file, 'SELECT * FROM workflow_input WHERE workflow_id=?', + [wf_id]) + return [InputParameter( + id=wi[0], + type=wi[2], + value=wi[3] + ) + for wi in wf_inputs] if wf_inputs else [] + + def get_workflow_outputs(self, wf_id: str): + """Return a list of OutputParameter objects for a workflow.""" + wf_outputs = bdb.getall(self.db_file, 'SELECT * FROM workflow_output WHERE workflow_id=?', + [wf_id]) + return [OutputParameter( + id=wo[0], + type=wo[2], + value=wo[3], + source=wo[4] + ) + for wo in wf_outputs] if wf_outputs else [] + + def get_workflow_state(self, wf_id: str) -> str: + """Return the state of a workflow.""" + state = bdb.getone(self.db_file, 'SELECT state FROM workflow WHERE id=?', [wf_id]) + return state[0] if state else None + + def get_workflow_requirements_and_hints(self, wf_id: str): + """Return all workflow requirements and hints from the db. + + Must return a tuple with the format (requirements, hints) + + :rtype: (list of Requirement, list of Hint) + """ + wf_data = bdb.getone(self.db_file, 'SELECT reqs, hints FROM workflow WHERE id=?', [wf_id]) + if not wf_data: + return ([], []) + + requirements = [Requirement.model_validate(r) for r in json.loads(wf_data[0])] + hints = [Hint.model_validate(h) for h in json.loads(wf_data[1])] + return (requirements, hints) + + def get_workflow_tasks(self, wf_id: str): + """Return a list of all workflow tasks from the db. + + :rtype: list of Task + """ + tasks_data = bdb.getall(self.db_file, 'SELECT id FROM task WHERE workflow_id=?', [wf_id]) + tasks = [self.get_task(t[0]) for t in tasks_data] if tasks_data else [] + return tasks + + def get_ready_tasks(self, wf_id: str): + """Return tasks with state 'READY' from the db. + + :rtype: list of Task + """ + tasks_data = bdb.getall(self.db_file, + "SELECT id FROM task WHERE workflow_id=? AND state='READY'", + [wf_id]) + tasks = [self.get_task(t[0]) for t in tasks_data] if tasks_data else [] + return tasks + + def get_dependent_tasks(self, task_id: str): + """Return the dependent tasks of a workflow task in the db. + + :param task_id: the id of the task to get dependents for + :type task_id: str + :rtype: list of Task + """ + deps_data = bdb.getall(self.db_file, + "SELECT depending_task_id FROM task_dep WHERE depends_on_task_id=?", + [task_id]) + deps = [self.get_task(d[0]) for d in deps_data] if deps_data else [] + return deps + + def get_task_state(self, task_id: str): + """Return the state of a task in the db. 
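+
+        Usage sketch (hypothetical task ID):
+
+            state = gdb.get_task_state('task-uuid')  # e.g. 'WAITING' or 'READY'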
+ + :param task_id: the id of the task to get the state for + :type task_id: str + :rtype: str + """ + state = bdb.getone(self.db_file, 'SELECT state FROM task WHERE id=?', [task_id]) + return state[0] if state else None + + def get_task_metadata(self, task_id: str): + """Return the metadata of a task in the db. + + :param task_id: the id of the task to get metadata for + :type task_id: str + :rtype: dict + """ + metadata = bdb.getone(self.db_file, 'SELECT metadata FROM task WHERE id=?', [task_id]) + return json.loads(metadata[0]) if metadata and metadata[0] else {} + + def set_task_metadata(self, task_id: str, metadata: dict): + """Set the metadata of a task in the db. + + :param task_id: the id of the task to set metadata for + :type task_id: str + :param metadata: the job description metadata + :type metadata: dict + """ + prior_metadata = self.get_task_metadata(task_id) + prior_metadata.update(metadata) + metadata_json = json.dumps(prior_metadata) + bdb.run(self.db_file, 'UPDATE task SET metadata=? WHERE id=?', + [metadata_json, task_id]) + + def get_task_input(self, task_id: str, input_id: str): + """Get a task input object. + + :param task_id: the id of the task to get the input for + :type task_id: str + :param input_id: the id of the input to get + :type input_id: str + :rtype: StepInput + """ + ti_data = bdb.getone(self.db_file, + 'SELECT * FROM task_input WHERE task_id=? AND id=?', + [task_id, input_id]) + if not ti_data: + return None + + task_input = StepInput( + id=ti_data[0], + type=ti_data[2], + value=ti_data[3], + default=ti_data[4], + source=ti_data[5], + prefix=ti_data[6], + position=ti_data[7], + value_from=ti_data[8] + ) + return task_input + + def set_task_input(self, task_id: str, input_id: str, value: str): + """Set the value of a task input. + + :param task_id: the id of the task to set the input for + :type task_id: str + :param input_id: the id of the input to set + :type input_id: str + :param value: the new value for the input + :type value: str + """ + bdb.run(self.db_file, + 'UPDATE task_input SET value=? WHERE task_id=? AND id=?', + [value, task_id, input_id]) + + def get_task_output(self, task_id: str, output_id: str): + """Get a task output object. + + :param task_id: the id of the task to get the output for + :type task_id: str + :param output_id: the id of the output to get + :type output_id: str + :rtype: StepOutput + """ + to_data = bdb.getone(self.db_file, + 'SELECT * FROM task_output WHERE task_id=? AND id=?', + [task_id, output_id]) + if not to_data: + return None + + task_output = StepOutput( + id=to_data[0], + type=to_data[2], + value=to_data[3], + glob=to_data[4] + ) + return task_output + + def set_task_output(self, task_id: str, output_id: str, value: str): + """Set the value of a task output. + + :param task_id: the id of the task to set the output for + :type task_id: str + :param output_id: the id of the output to set + :type output_id: str + :param value: the new value for the output + :type value: str + """ + bdb.run(self.db_file, + 'UPDATE task_output SET value=? WHERE task_id=? AND id=?', + [value, task_id, output_id]) + + def set_task_input_type(self, task_id: str, input_id: str, type_: str): + """Set the type of a task input. + + :param task_id: the id of the task to set the input type for + :type task_id: str + :param input_id: the id of the input to set + :type input_id: str + :param type_: the new type for the input + :type type_: str + """ + bdb.run(self.db_file, + 'UPDATE task_input SET type=? WHERE task_id=? 
AND id=?', + [type_, task_id, input_id]) + + def set_task_output_glob(self, task_id: str, output_id: str, glob: str): + """Set the glob of a task output. + + :param task_id: the id of the task to set the output glob for + :type task_id: str + :param output_id: the id of the output to set + :type output_id: str + :param glob: the new glob for the output + :type glob: str + """ + bdb.run(self.db_file, + 'UPDATE task_output SET glob=? WHERE task_id=? AND id=?', + [glob, task_id, output_id]) + + def final_tasks_completed(self, wf_id: str) -> bool: + """Determine if a workflow's final tasks have completed. + + A workflow's final tasks have completed if each of its final tasks has finished or failed. + + :param wf_id: the ID of the workflow to check + :type wf_id: str + :rtype: bool + """ + placeholders = ','.join('?' for _ in final_task_states) + final_tasks_query = f""" + SELECT COUNT(*) + FROM task + WHERE workflow_id = ? + AND state NOT IN ({placeholders}); + """ + + params = [wf_id, *final_task_states] + + result = bdb.getone(self.db_file, final_tasks_query, params) + return result is not None and result[0] == 0 + + def final_tasks_succeeded(self, wf_id: str) -> bool: + """Determine if a workflow's final tasks have succeeded. + + A workflow's final tasks have succeeded if each of its final tasks has + finished successfully. + + :param wf_id: the ID of the workflow to check + :type wf_id: str + :rtype: bool + """ + placeholders = ','.join('?' for _ in failed_task_states) + final_tasks_query = f""" + SELECT COUNT(*) + FROM task + WHERE workflow_id = ? + AND state IN ({placeholders}); + """ + + params = [wf_id, *failed_task_states] + + result = bdb.getone(self.db_file, final_tasks_query, params) + return self.final_tasks_completed(wf_id) and result is not None and result[0] == 0 + + def final_tasks_failed(self, wf_id: str) -> bool: + """Determine if all of a workflow's final tasks have failed. + + :param wf_id: the ID of the workflow to check + :type wf_id: str + :rtype: bool + """ + placeholders = ','.join('?' for _ in failed_task_states) + ',?' + final_tasks_query = f""" + SELECT COUNT(*) + FROM task + WHERE workflow_id = ? + AND state NOT IN ({placeholders}); + """ + + params = [wf_id, *failed_task_states, 'RESTARTED'] + + result = bdb.getone(self.db_file, final_tasks_query, params) + return result is not None and result[0] == 0 + + def cancelled_final_tasks_completed(self, wf_id: str) -> bool: + """Determine if a cancelled workflow's final tasks have completed. + + All of the workflow's scheduled tasks are completed if each of the final task nodes + are not in states 'PENDING', 'RUNNING', or 'COMPLETING'. + + :param wf_id: the ID of the workflow to check + :type wf_id: str + :rtype: bool + """ + incomplete_states = ['SUBMIT', 'PENDING', 'RUNNING', 'COMPLETING'] + placeholders = ','.join('?' for _ in incomplete_states) + final_tasks_query = f""" + SELECT COUNT(*) + FROM task + WHERE workflow_id = ? 
AND state IN ({placeholders});
+            """
+
+        params = [wf_id, *incomplete_states]
+
+        result = bdb.getone(self.db_file, final_tasks_query, params)
+        return result is not None and result[0] == 0
+
+    def remove_workflow(self, wf_id: str):
+        """Remove a workflow and all its associated tasks from the db."""
+        delete_wf_query = """
+            DELETE FROM workflow
+            WHERE id = ?;"""
+        bdb.run(self.db_file, delete_wf_query, [wf_id])
diff --git a/beeflow/common/db/wfm_db.py b/beeflow/common/db/wfm_db.py
index db3f5a0f9..da3cad574 100644
--- a/beeflow/common/db/wfm_db.py
+++ b/beeflow/common/db/wfm_db.py
@@ -2,6 +2,7 @@
 
 from collections import namedtuple
+import time
 
 from beeflow.common.db import bdb
 
@@ -26,9 +27,16 @@ def get_port(self, component):
         """Return port for the specified component."""
         # Need to add code here to make sure we chose a valid component.
         stmt = f"SELECT {component}_port FROM info"
-        result = bdb.getone(self.db_file, stmt)[0]
-        port = result
-        return port
+        # Retry up to `retries` times, sleeping a little longer after each attempt
+        retries = 5
+        for attempt in range(1, retries + 1):
+            result = bdb.getone(self.db_file, stmt)
+            if result:
+                return result[0]
+            time.sleep(attempt)
+        raise RuntimeError(f"Port for {component} not found in table after "
+                           f"{retries} retries; waited "
+                           f"{retries * (retries + 1) // 2} seconds in total.")
 
     def increment_num_workflows(self):
         """Set workflow manager port."""
diff --git a/beeflow/common/gdb/gdb_driver.py b/beeflow/common/gdb/gdb_driver.py
index 8a22fc2ac..a8e813369 100644
--- a/beeflow/common/gdb/gdb_driver.py
+++ b/beeflow/common/gdb/gdb_driver.py
@@ -24,40 +24,35 @@ def initialize_workflow(self, workflow):
         """
 
     @abstractmethod
-    def execute_workflow(self):
+    def get_all_workflow_info(self):
+        """Return a list of all workflows in the graph database.
+
+        :rtype: list of WorkflowInfo
+        """
+
+    @abstractmethod
+    def execute_workflow(self, workflow_id):
         """Begin execution of the stored workflow.
 
         Set the initial tasks' states to 'READY'.
         """
 
     @abstractmethod
-    def pause_workflow(self):
+    def pause_workflow(self, workflow_id):
         """Pause execution of a running workflow.
 
         Set workflow from state 'RUNNING' to 'PAUSED'.
         """
 
     @abstractmethod
-    def resume_workflow(self):
+    def resume_workflow(self, workflow_id):
         """Resume execution of a paused workflow.
 
         Set workflow state from 'PAUSED' to 'RUNNING'.
         """
 
     @abstractmethod
-    def reset_workflow(self, new_id):
-        """Reset the execution state of a stored workflow.
-
-        Set all task states to 'WAITING'.
-        Change the workflow ID of the Workflow and Task nodes to new_id.
-        Delete all task metadata except for task state.
-
-        :param new_id: the new workflow ID
-        :type new_id: str
-        """
-
-    @abstractmethod
-    def load_task(self, task, task_state):
+    def load_task(self, task):
         """Load a task into a stored workflow.
 
         Dependencies should be automatically deduced and generated by the graph database
@@ -69,7 +64,7 @@ def load_task(self, task, task_state):
         """
 
     @abstractmethod
-    def initialize_ready_tasks(self):
+    def initialize_ready_tasks(self, workflow_id):
         """Set runnable tasks to state 'READY'.
 
         Runnable tasks are tasks with all input dependencies fulfilled.
@@ -106,21 +101,21 @@ def get_task_by_id(self, task_id):
         """
 
     @abstractmethod
-    def get_workflow_description(self):
+    def get_workflow_description(self, workflow_id):
         """Return a reconstructed Workflow object from the graph database.
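+
+        A driver-agnostic usage sketch (assumes a connected driver instance):
+
+            wf = driver.get_workflow_description(wf_id)
+            print(wf.name, wf.state)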
:rtype: Workflow """ @abstractmethod - def get_workflow_state(self): + def get_workflow_state(self, workflow_id): """Return the current state of the workflow. :rtype: str """ @abstractmethod - def set_workflow_state(self, state): + def set_workflow_state(self, workflow_id, state): """Set the state of the workflow. :param state: the new state of the workflow @@ -128,14 +123,14 @@ def set_workflow_state(self, state): """ @abstractmethod - def get_workflow_tasks(self): + def get_workflow_tasks(self, workflow_id): """Return a list of all workflow tasks from the graph database. :rtype: list of Task """ @abstractmethod - def get_workflow_requirements_and_hints(self): + def get_workflow_requirements_and_hints(self, workflow_id): """Return all workflow requirements and hints from the graph database. Must return a tuple with the format (requirements, hints) @@ -144,7 +139,7 @@ def get_workflow_requirements_and_hints(self): """ @abstractmethod - def get_workflow_inputs_and_outputs(self): + def get_workflow_inputs_and_outputs(self, workflow_id): """Return all workflow inputs and outputs from the graph database. Returns a tuple of (inputs, outputs). @@ -153,7 +148,7 @@ def get_workflow_inputs_and_outputs(self): """ @abstractmethod - def get_ready_tasks(self): + def get_ready_tasks(self, workflow_id): """Return tasks with state 'READY' from the graph database. :rtype: list of Task @@ -276,7 +271,7 @@ def set_task_output_glob(self, task_id, output_id, glob): """ @abstractmethod - def workflow_completed(self): + def workflow_completed(self, workflow_id): """Determine if a workflow has completed. A workflow has completed if each of its final tasks has finished or failed. @@ -285,14 +280,14 @@ def workflow_completed(self): """ @abstractmethod - def get_workflow_final_state(self): + def get_workflow_final_state(self, workflow_id): """Get the final state of the workflow. :rtype: Optional[str] """ @abstractmethod - def cancelled_workflow_completed(self): + def cancelled_workflow_completed(self, workflow_id): """Determine if a cancelled workflow has completed. 
A cancelled workflow has completed if each of its final tasks are not @@ -306,5 +301,5 @@ def close(self): """Close the connection to the graph database.""" @abstractmethod - def export_graphml(self): + def export_graphml(self, workflow_id): """Export a BEE workflow as a graphml.""" diff --git a/beeflow/common/gdb/generate_graph.py b/beeflow/common/gdb/generate_graph.py index 2208e5e51..b58055850 100644 --- a/beeflow/common/gdb/generate_graph.py +++ b/beeflow/common/gdb/generate_graph.py @@ -73,10 +73,14 @@ def add_nodes_to_dot(graph, dot): def get_node_label_and_color(label, attributes, label_to_color): """Return the appropriate node label and color based on node type.""" + if label == ":Task": + task_name = attributes.get('name', label) + task_state = attributes.get('state', '') + + return f"{task_name}\n({task_state})", label_to_color.get(label, 'gray') label_to_attribute = { ":Workflow": "Workflow", ":Output": attributes.get('glob', label), - ":Metadata": attributes.get('state', label), ":Task": attributes.get('name', label), ":Input": attributes.get('source', label), ":Hint": attributes.get('class', label), diff --git a/beeflow/common/gdb/neo4j_cypher.py b/beeflow/common/gdb/neo4j_cypher.py index 1789bdf24..a07af66dd 100644 --- a/beeflow/common/gdb/neo4j_cypher.py +++ b/beeflow/common/gdb/neo4j_cypher.py @@ -5,7 +5,7 @@ log = bee_logging.setup(__name__) failed_task_states = ['FAILED', 'SUBMIT_FAIL', 'BUILD_FAIL', 'DEP_FAIL', 'TIMEOUT', 'CANCELLED'] -final_task_states = ['COMPLETED'] + failed_task_states +final_task_states = ['COMPLETED', 'RESTARTED'] + failed_task_states def create_bee_node(tx): """Create a BEE node in the Neo4j database. @@ -121,14 +121,16 @@ def create_task(tx, task): "SET t.stdout = $stdout " "SET t.stderr = $stderr " "SET t.reqs = $reqs " - "SET t.hints = $hints") + "SET t.hints = $hints " + "SET t.state = $state " + "SET t.workdir = $workdir") # Unpack requirements, hints dictionaries into flat list reqs = len(task.requirements) > 0 hints = len(task.hints) > 0 tx.run(create_query, task_id=task.id, workflow_id=task.workflow_id, name=task.name, base_command=task.base_command, stdout=task.stdout, stderr=task.stderr, - reqs=reqs, hints=hints) + reqs=reqs, hints=hints, state=task.state, workdir=task.workdir) def create_task_hint_nodes(tx, task): @@ -200,7 +202,7 @@ def create_task_output_nodes(tx, task): value=output.value, glob=output.glob) -def create_task_metadata_node(tx, task, task_state): +def create_task_metadata_node(tx, task): """Create a task metadata node in the Neo4j database. The node holds metadata about a task's execution state. 
@@ -209,9 +211,10 @@ def create_task_metadata_node(tx, task, task_state): :type task: Task """ metadata_query = ("MATCH (t:Task {id: $task_id}) " - "CREATE (m:Metadata {state: $task_state})-[:DESCRIBES]->(t)") + "CREATE (m:Metadata)-[:DESCRIBES]->(t)") - tx.run(metadata_query, task_id=task.id, task_state=task_state) + tx.run(metadata_query, task_id=task.id) + set_task_metadata(tx, task.id, task.metadata) def add_dependencies(tx, task, old_task=None, restarted_task=False): @@ -453,8 +456,7 @@ def get_ready_tasks(tx, wf_id): :type workflow_id: str :rtype: neo4j.Result """ - get_ready_query = ("MATCH (:Metadata {state: 'READY'})-[:DESCRIBES]->" - "(t:Task {workflow_id: $wf_id}) RETURN t") + get_ready_query = "MATCH (t:Task {workflow_id: $wf_id, state: 'READY'}) RETURN t" return [rec['t'] for rec in tx.run(get_ready_query, wf_id=wf_id)] @@ -478,7 +480,7 @@ def get_task_state(tx, task_id): :type task: Task :rtype: str """ - state_query = "MATCH (m:Metadata)-[:DESCRIBES]->(:Task {id: $task_id}) RETURN m.state" + state_query = "MATCH (t:Task {id: $task_id}) RETURN t.state" return tx.run(state_query, task_id=task_id).single().value() @@ -491,8 +493,7 @@ def set_task_state(tx, task_id, state): :param state: the new task state :type state: str """ - state_query = ("MATCH (m:Metadata)-[:DESCRIBES]->(:Task {id: $task_id}) " - "SET m.state = $state") + state_query = "MATCH (t:Task {id: $task_id}) SET t.state = $state" tx.run(state_query, task_id=task_id, state=state) @@ -655,12 +656,12 @@ def copy_task_outputs(tx, task): "WHERE i.source = o.id AND o.value IS NOT NULL " "SET i.value = o.value") # Set any values to defaults if necessary - defaults_query = ("MATCH (m:Metadata)-[:DESCRIBES]->(:Task)<-[:DEPENDS_ON]-" + defaults_query = ("MATCH (pt:Task)<-[:DEPENDS_ON]-" "(t:Task)-[:DEPENDS_ON]->(:Task {id: $task_id}) " - "WITH m, t " + "WITH pt, t " "MATCH (t)<-[:INPUT_OF]-(i:Input) " - "WITH i, collect(m) AS mlist " - "WHERE all(m IN mlist WHERE m.state = 'COMPLETED') " + "WITH i, collect(pt) AS ptlist " + "WHERE all(pt IN ptlist WHERE pt.state = 'COMPLETED') " "AND i.value IS NULL AND i.default IS NOT NULL " "SET i.value = i.default") workflow_output_query = ("MATCH (:Workflow)<-[:OUTPUT_OF]-(wo:Output) " @@ -674,63 +675,17 @@ def copy_task_outputs(tx, task): tx.run(workflow_output_query, task_id=task.id) -def set_running_tasks_to_paused(tx): - """Set 'RUNNING' task states to 'PAUSED'.""" - set_paused_query = ("MATCH (m:Metadata {state: 'RUNNING'})-[:DESCRIBES]->(:Task) " - "SET m.state = 'PAUSED'") - - tx.run(set_paused_query) - - -def set_paused_tasks_to_running(tx): - """Set 'PAUSED' task states to 'RUNNING'.""" - set_running_query = ("MATCH (m:Metadata {state: 'PAUSED'})-[:DESCRIBES]->(:Task) " - "SET m.state = 'RUNNING'") - - tx.run(set_running_query) - - def set_runnable_tasks_to_ready(tx, wf_id): """Set task states to 'READY' if all required inputs have values.""" - set_runnable_ready_query = ("MATCH (m:Metadata)-[:DESCRIBES]->" - "(t:Task {workflow_id: $wf_id})<-[:INPUT_OF]-(i:Input) " - "WITH m, t, collect(i) AS ilist " - "WHERE m.state = 'WAITING' " + set_runnable_ready_query = ("MATCH (t:Task {workflow_id: $wf_id})<-[:INPUT_OF]-(i:Input) " + "WITH t, collect(i) AS ilist " + "WHERE t.state = 'WAITING' " "AND all(i IN ilist WHERE i.value IS NOT NULL) " - "SET m.state = 'READY'") + "SET t.state = 'READY'") tx.run(set_runnable_ready_query, wf_id=wf_id) -def reset_tasks_metadata(tx, wf_id): - """Reset the metadata for each of a workflow's tasks. 
- - :param wf_id: the workflow's ID - :type wf_id: str - """ - reset_metadata_query = ("MATCH (m:Metadata)-[:DESCRIBES]->(t:Task {workflow_id: $wf_id}) " - "DETACH DELETE m " - "WITH t " - "CREATE (:Metadata {state: 'WAITING'})-[:DESCRIBES]->(t)") - - tx.run(reset_metadata_query, wf_id=wf_id) - - -def reset_workflow_id(tx, old_id, new_id): - """Reset the workflow ID of the workflow using uuid4. - - :param old_id: the old workflow ID - :type old_id: str - :param new_id: the new workflow ID - :type new_id: str - """ - reset_workflow_id_query = ("MATCH (w:Workflow {id: $old_id}), (t:Task {workflow_id: $old_id}) " - "SET w.id = $new_id " - "SET t.workflow_id = $new_id") - - tx.run(reset_workflow_id_query, old_id=old_id, new_id=new_id) - - def final_tasks_completed(tx, wf_id): """Return true if each of a workflow's final Task nodes is in a finished state. @@ -739,9 +694,9 @@ def final_tasks_completed(tx, wf_id): :rtype: bool """ restart = "|RESTARTED_FROM" if get_workflow_by_id(tx, wf_id)['restart'] else "" - not_completed_query = ("MATCH (m:Metadata)-[:DESCRIBES]->(t:Task {workflow_id: $wf_id}) " + not_completed_query = ("MATCH (t:Task {workflow_id: $wf_id}) " f"WHERE NOT (t)<-[:DEPENDS_ON{restart}]-(:Task) " - f"AND NOT m.state IN {final_task_states} " + f"AND NOT t.state IN {final_task_states} " "RETURN t IS NOT NULL LIMIT 1") # False if at least one task is not finished @@ -756,9 +711,9 @@ def final_tasks_succeeded(tx, wf_id): :rtype: bool """ restart = "|RESTARTED_FROM" if get_workflow_by_id(tx, wf_id)['restart'] else "" - not_succeeded_query = ("MATCH (m:Metadata)-[:DESCRIBES]->(t:Task {workflow_id: $wf_id}) " + not_succeeded_query = ("MATCH (t:Task {workflow_id: $wf_id}) " f"WHERE NOT (t)<-[:DEPENDS_ON{restart}]-(:Task) " - "AND m.state <> 'COMPLETED' " + "AND t.state <> 'COMPLETED' " "RETURN t IS NOT NULL LIMIT 1") # False if at least one task with state not 'COMPLETED' @@ -773,9 +728,9 @@ def final_tasks_failed(tx, wf_id): :rtype: bool """ restart = "|RESTARTED_FROM" if get_workflow_by_id(tx, wf_id)['restart'] else "" - not_failed_query = ("MATCH (m:Metadata)-[:DESCRIBES]->(t:Task {workflow_id: $wf_id}) " + not_failed_query = ("MATCH (t:Task {workflow_id: $wf_id}) " f"WHERE NOT (t)<-[:DEPENDS_ON{restart}]-(:Task) " - f"AND NOT m.state IN {failed_task_states} " + f"AND NOT t.state IN {failed_task_states} " "RETURN t IS NOT NULL LIMIT 1") # False if at least one task is not failed @@ -793,9 +748,9 @@ def cancelled_final_tasks_completed(tx, wf_id): :rtype: bool """ restart = "|RESTARTED_FROM" if get_workflow_by_id(tx, wf_id)['restart'] else "" - active_states_query = ("MATCH (m:Metadata)-[:DESCRIBES]->(t:Task {workflow_id: $wf_id}) " + active_states_query = ("MATCH (t:Task {workflow_id: $wf_id}) " f"WHERE NOT (t)<-[:DEPENDS_ON{restart}]-(:Task) " - "AND m.state IN ['PENDING', 'RUNNING', 'COMPLETING'] " + "AND t.state IN ['SUBMIT', 'PENDING', 'RUNNING', 'COMPLETING'] " "RETURN t IS NOT NULL LIMIT 1") # False if at least one task is in 'PENDING', 'RUNNING', or 'COMPLETING' diff --git a/beeflow/common/gdb/neo4j_driver.py b/beeflow/common/gdb/neo4j_driver.py index dc3f5223f..4fd94d69d 100644 --- a/beeflow/common/gdb/neo4j_driver.py +++ b/beeflow/common/gdb/neo4j_driver.py @@ -150,22 +150,7 @@ def resume_workflow(self, workflow_id): tx.set_workflow_state, state="Running", wf_id=workflow_id ) - def reset_workflow(self, old_id, new_id): - """Reset the execution state of an entire workflow. - - Sets all task states to 'WAITING'. 
- Changes the workflow ID of the Workflow and Task nodes with new_id. - - :param new_id: the new workflow ID - :type new_id: str - """ - with self._driver.session() as session: - session.write_transaction(tx.reset_tasks_metadata, wf_id=old_id) - session.write_transaction( - tx.reset_workflow_id, old_id=old_id, new_id=new_id - ) - - def load_task(self, task, task_state): + def load_task(self, task): """Load a task into a workflow stored in the Neo4j database. Dependencies are automatically deduced and generated by Neo4j upon loading @@ -183,7 +168,7 @@ def load_task(self, task, task_state): session.write_transaction(tx.create_task_input_nodes, task=task) session.write_transaction(tx.create_task_output_nodes, task=task) session.write_transaction( - tx.create_task_metadata_node, task=task, task_state=task_state + tx.create_task_metadata_node, task=task ) session.write_transaction(tx.add_dependencies, task=task) @@ -215,8 +200,9 @@ def restart_task(self, old_task, new_task): session.write_transaction(tx.create_task_input_nodes, task=new_task) session.write_transaction(tx.create_task_output_nodes, task=new_task) session.write_transaction( - tx.create_task_metadata_node, task=new_task, task_state="WAITING" + tx.create_task_metadata_node, task=new_task ) + session.write_transaction(tx.set_task_state, task_id=new_task.id, state="WAITING") session.write_transaction( tx.add_dependencies, task=new_task, @@ -243,7 +229,7 @@ def get_task_by_id(self, task_id): task_record = self._read_transaction(tx.get_task_by_id, task_id=task_id) tuples = self._get_task_data_tuples([task_record]) return _reconstruct_task( - tuples[0][0], tuples[0][1], tuples[0][2], tuples[0][3], tuples[0][4] + tuples[0][0], tuples[0][1], tuples[0][2], tuples[0][3], tuples[0][4], tuples[0][5] ) def get_all_workflow_info(self): @@ -328,7 +314,7 @@ def get_workflow_tasks(self, workflow_id): task_records = self._read_transaction(tx.get_workflow_tasks, wf_id=workflow_id) tuples = self._get_task_data_tuples(task_records) return [ - _reconstruct_task(tup[0], tup[1], tup[2], tup[3], tup[4]) for tup in tuples + _reconstruct_task(tup[0], tup[1], tup[2], tup[3], tup[4], tup[5]) for tup in tuples ] def get_workflow_requirements_and_hints(self, workflow_id): @@ -380,7 +366,7 @@ def get_ready_tasks(self, workflow_id): task_records = self._read_transaction(tx.get_ready_tasks, wf_id=workflow_id) tuples = self._get_task_data_tuples(task_records) return [ - _reconstruct_task(tup[0], tup[1], tup[2], tup[3], tup[4]) for tup in tuples + _reconstruct_task(tup[0], tup[1], tup[2], tup[3], tup[4], tup[5]) for tup in tuples ] def get_dependent_tasks(self, task_id): @@ -393,7 +379,7 @@ def get_dependent_tasks(self, task_id): task_records = self._read_transaction(tx.get_dependent_tasks, task_id=task_id) tuples = self._get_task_data_tuples(task_records) return [ - _reconstruct_task(tup[0], tup[1], tup[2], tup[3], tup[4]) for tup in tuples + _reconstruct_task(tup[0], tup[1], tup[2], tup[3], tup[4], tup[5]) for tup in tuples ] def get_task_state(self, task_id): @@ -607,6 +593,10 @@ def _get_task_data_tuples(self, task_records): session.read_transaction(tx.get_task_outputs, task_id=rec["id"]) for rec in trecords ] + metadata_records = [ + session.read_transaction(tx.get_task_metadata, task_id=rec["id"]) + for rec in trecords + ] hints = [_reconstruct_hints(hint_record) for hint_record in hint_records] reqs = [_reconstruct_requirements(req_record) for req_record in req_records] @@ -616,8 +606,12 @@ def _get_task_data_tuples(self, task_records): outputs = [ 
_reconstruct_task_outputs(output_record) for output_record in output_records ] + metadata = [ + _reconstruct_metadata(metadata_record) + for metadata_record in metadata_records + ] - return list(zip(trecords, hints, reqs, inputs, outputs)) + return list(zip(trecords, hints, reqs, inputs, outputs, metadata)) def _read_transaction(self, tx_fun, **kwargs): """Run a Neo4j read transaction. @@ -785,7 +779,7 @@ def _reconstruct_workflow(workflow_record, hints, requirements, inputs, outputs) ) -def _reconstruct_task(task_record, hints, requirements, inputs, outputs): +def _reconstruct_task(task_record, hints, requirements, inputs, outputs, metadata): """Reconstruct a Task object by its record retrieved from Neo4j. :param task_record: the database record of the task @@ -807,10 +801,13 @@ def _reconstruct_task(task_record, hints, requirements, inputs, outputs): requirements=requirements, inputs=inputs, outputs=outputs, + metadata=metadata, stdout=task_record["stdout"], stderr=task_record["stderr"], workflow_id=task_record["workflow_id"], id=task_record["id"], + state=task_record["state"], + workdir=task_record["workdir"], ) @@ -823,4 +820,4 @@ def _reconstruct_metadata(metadata_record): :type keys: iterable of str :rtype: dict """ - return {key: val for key, val in metadata_record.items() if key != "state"} + return dict(metadata_record.items()) diff --git a/beeflow/common/gdb/sqlite3_driver.py b/beeflow/common/gdb/sqlite3_driver.py new file mode 100644 index 000000000..aea4422dc --- /dev/null +++ b/beeflow/common/gdb/sqlite3_driver.py @@ -0,0 +1,377 @@ +"""Graph database driver using SQLite as the backend.""" + +import os +from beeflow.common.gdb.gdb_driver import GraphDatabaseDriver +from beeflow.common.config_driver import BeeConfig as bc +from beeflow.common.db.gdb_db import SQL_GDB + +def db_path(): + """Return the SQL GDB database path.""" + bee_workdir = bc.get('DEFAULT', 'bee_workdir') + return os.path.join(bee_workdir, 'gdb_sql.db') + +def connect_db(): + """Connect to the SQL GDB database.""" + return SQL_GDB(db_path()) + + + +class SQLDriver(GraphDatabaseDriver): + """Graph database driver using SQLite as the backend.""" + + def __new__(cls): + if not hasattr(cls, 'instance'): + cls.instance = super(SQLDriver, cls).__new__(cls) + return cls.instance + + def connect(self): + """Connect to the graph database.""" + if not hasattr(self, 'db'): + self.db = connect_db() # pylint: disable=W0201 + + + def initialize_workflow(self, workflow): + """Begin construction of a workflow in the graph database. + + Should create the Workflow, Requirement, and Hint nodes in the graph database. + + :param workflow: the workflow description + :type workflow: Workflow + """ + self.db.create_workflow(workflow) + + + def execute_workflow(self, workflow_id): + """Begin execution of the stored workflow. + + Set the initial tasks' states to 'READY'. + """ + self.db.set_init_task_inputs(workflow_id) + self.db.set_runnable_tasks_to_ready(workflow_id) + + + def pause_workflow(self, workflow_id): + """Pause execution of a running workflow. + + Set workflow from state 'RUNNING' to 'PAUSED'. + """ + self.db.set_workflow_state(workflow_id, 'PAUSED') + + + def resume_workflow(self, workflow_id): + """Resume execution of a paused workflow. + + Set workflow state from 'PAUSED' to 'RUNNING'. + """ + self.db.set_workflow_state(workflow_id, 'RUNNING') + + + def load_task(self, task): + """Load a task into a stored workflow. 
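+
+        A usage sketch (hypothetical Task object built by the workflow manager):
+
+            driver = SQLDriver()
+            driver.connect()
+            driver.load_task(task)  # inserts task rows and records task_dep edges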
+ + Dependencies should be automatically deduced and generated by the graph database + upon loading each task by matching workflow inputs with new task inputs, + or task outputs with new task inputs. + + :param task: a workflow task + :type task: Task + """ + self.db.create_task(task) + self.db.add_dependencies(task) + + + def initialize_ready_tasks(self, workflow_id): + """Set runnable tasks to state 'READY'. + + Runnable tasks are tasks with all input dependencies fulfilled. + """ + self.db.set_runnable_tasks_to_ready(workflow_id) + + + def restart_task(self, old_task, new_task): + """Restart a failed task. + + Create a Task node for new_task with state 'RESTARTED' and an edge + to indicate that it is the child of the Task node of old_task. + + :param old_task: the failed task + :type old_task: Task + :param new_task: the new (restarted) task + :type new_task: Task + """ + self.db.create_task(new_task) + self.db.set_task_state(new_task.id, 'WAITING') + self.db.add_dependencies(new_task, old_task=old_task, restarted_task=True) + + + def finalize_task(self, task): + """Set task state to 'COMPLETED' and set inputs from source. + + :param task: the task to finalize + :type task: Task + """ + self.db.set_task_state(task.id, 'COMPLETED') + self.db.copy_task_outputs(task) + + + def get_task_by_id(self, task_id): + """Return a reconstructed Task object from the graph database by its ID. + + :param task_id: a task's ID + :type task_id: str + :rtype: Task + """ + return self.db.get_task(task_id) + + + def get_all_workflow_info(self): + """Return a list of all workflows in the graph database. + + :rtype: list of workflowinfo + """ + return self.db.get_all_workflow_info() + + + def get_workflow_description(self, workflow_id): + """Return a reconstructed Workflow object from the graph database. + + :rtype: Workflow + """ + return self.db.get_workflow(workflow_id) + + + def get_workflow_state(self, workflow_id): + """Return the current state of the workflow. + + :rtype: str + """ + return self.db.get_workflow_state(workflow_id) + + + def set_workflow_state(self, workflow_id, state): + """Set the state of the workflow. + + :param state: the new state of the workflow + :type state: str + """ + self.db.set_workflow_state(workflow_id, state) + + + def get_workflow_tasks(self, workflow_id): + """Return a list of all workflow tasks from the graph database. + + :rtype: list of Task + """ + return self.db.get_workflow_tasks(workflow_id) + + + def get_workflow_requirements_and_hints(self, workflow_id): + """Return all workflow requirements and hints from the graph database. + + Must return a tuple with the format (requirements, hints) + + :rtype: (list of Requirement, list of Hint) + """ + return self.db.get_workflow_requirements_and_hints(workflow_id) + + + def get_workflow_inputs_and_outputs(self, workflow_id): + """Return all workflow inputs and outputs from the graph database. + + Returns a tuple of (inputs, outputs). + + :rtype: (list of InputParameter, list of OutputParameter) + """ + return (self.db.get_workflow_inputs(workflow_id), + self.db.get_workflow_outputs(workflow_id)) + + + def get_ready_tasks(self, workflow_id): + """Return tasks with state 'READY' from the graph database. + + :rtype: list of Task + """ + return self.db.get_ready_tasks(workflow_id) + + + def get_dependent_tasks(self, task_id): + """Return the dependent tasks of a workflow task in the graph database. 
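+
+        Usage sketch (hypothetical ID):
+
+            children = driver.get_dependent_tasks('task-uuid')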
+ + :param task_id: the ID of the task whose dependents to retrieve + :type task_id: str + :rtype: list of Task + """ + return self.db.get_dependent_tasks(task_id) + + + def get_task_state(self, task_id): + """Return the state of a task in the graph database. + + :param task_id: the ID of the task whose status to retrieve + :type task_id: str + :rtype: str + """ + return self.db.get_task_state(task_id) + + + def set_task_state(self, task_id, state): + """Set the state of a task in the graph database. + + :param task_id: the ID of the task whose state to set + :type task_id: str + :param state: the new state + :type state: str + """ + self.db.set_task_state(task_id, state) + + + def get_task_metadata(self, task_id): + """Return the metadata of a task in the graph database. + + :param task_id: the ID of the task whose metadata to retrieve + :type task_id: str + :rtype: dict + """ + return self.db.get_task_metadata(task_id) + + + def set_task_metadata(self, task_id, metadata): + """Set the metadata of a task in the graph database. + + :param task_id: the ID of the task whose metadata to set + :type task_id: str + :param metadata: the job description metadata + :type metadata: dict + """ + self.db.set_task_metadata(task_id, metadata) + + + def get_task_input(self, task_id, input_id): + """Get a task input object. + + :param task_id: the ID of the task whose input to retrieve + :type task_id: str + :param input_id: the ID of the input + :type input_id: str + :rtype: StepInput + """ + return self.db.get_task_input(task_id, input_id) + + + def set_task_input(self, task_id, input_id, value): + """Set the value of a task input. + + :param task_id: the ID of the task whose input to set + :type task_id: str + :param input_id: the ID of the input + :type input_id: str + :param value: str or int or float + """ + self.db.set_task_input(task_id, input_id, value) + + + def get_task_output(self, task_id, output_id): + """Get a task output object. + + :param task_id: the ID of the task whose output to retrieve + :type task_id: str + :param output_id: the ID of the output + :type output_id: str + :rtype: StepOutput + """ + return self.db.get_task_output(task_id, output_id) + + + def set_task_output(self, task_id, output_id, value): + """Set the value of a task output. + + :param task_id: the ID of the task whose output to set + :type task_id: str + :param output_id: the ID of the output + :type output_id: str + :param value: the output value to set + :type value: str or int or float + """ + self.db.set_task_output(task_id, output_id, value) + + + def set_task_input_type(self, task_id, input_id, type_): + """Set the type of a task input. + + :param task_id: the ID of the task whose input type to set + :type task_id: str + :param input_id: the ID of the input + :type input_id: str + :param type_: the input type to set + :param type_: str + """ + self.db.set_task_input_type(task_id, input_id, type_) + + + def set_task_output_glob(self, task_id, output_id, glob): + """Set the glob of a task output. + + :param task_id: the ID of the task whose output glob to set + :type task_id: str + :param output_id: the ID of the output + :type output_id: str + :param glob: the output glob to set + :type glob: str + """ + self.db.set_task_output_glob(task_id, output_id, glob) + + + def workflow_completed(self, workflow_id): + """Determine if a workflow has completed. + + A workflow has completed if each of its final tasks has finished or failed. 
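+
+        Sketch of the completion check this enables (both methods are in this class):
+
+            if driver.workflow_completed(wf_id):
+                final_state = driver.get_workflow_final_state(wf_id)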
+
+        :rtype: bool
+        """
+        return self.db.final_tasks_completed(workflow_id)
+
+
+    def get_workflow_final_state(self, workflow_id):
+        """Get the final state of the workflow.
+
+        :rtype: Optional[str]
+        """
+        final_state = None
+        if self.db.final_tasks_succeeded(workflow_id):
+            # A fully successful workflow has no final state suffix;
+            # it is archived as plain 'Archived'.
+            final_state = None
+        elif self.db.final_tasks_failed(workflow_id):
+            final_state = 'Failed'
+        elif self.db.final_tasks_completed(workflow_id):
+            final_state = 'Partial-Fail'
+        else:
+            raise ValueError(f"Workflow with id {workflow_id} has not finished.")
+        return final_state
+
+
+    def cancelled_workflow_completed(self, workflow_id: str) -> bool:
+        """Determine if a cancelled workflow has completed.
+
+        A cancelled workflow has completed if none of its final tasks is in
+        state 'SUBMIT', 'PENDING', 'RUNNING', or 'COMPLETING'.
+
+        :rtype: bool
+        """
+        return self.db.cancelled_final_tasks_completed(workflow_id)
+
+
+    def remove_workflow(self, workflow_id):
+        """Remove a workflow from the graph database.
+
+        :param workflow_id: the ID of the workflow to remove
+        :type workflow_id: str
+        """
+        self.db.remove_workflow(workflow_id)
+
+
+    def close(self):
+        """Close the graph database connection."""
+        # No-op: bdb opens and closes a sqlite connection per operation.
+
+
+    def export_graphml(self, workflow_id):
+        """Export a BEE workflow as GraphML."""
+        raise NotImplementedError("GraphML export not implemented for SQLite GDB.")
diff --git a/beeflow/common/object_models.py b/beeflow/common/object_models.py
index dc7fe456c..8daefda76 100644
--- a/beeflow/common/object_models.py
+++ b/beeflow/common/object_models.py
@@ -192,6 +192,8 @@ class Task(BaseModel):
     workflow_id: str
     workdir: Optional[str | Path | os.PathLike] = None
     id: Optional[str] = None
+    state: Optional[str] = "WAITING"
+    metadata: Optional[dict] = {}
 
     @model_validator(mode="before")
     def generate_id_if_missing(cls, data):  # pylint: disable=E0213
diff --git a/beeflow/common/wf_interface.py b/beeflow/common/wf_interface.py
index 5599ee848..eba91a38e 100644
--- a/beeflow/common/wf_interface.py
+++ b/beeflow/common/wf_interface.py
@@ -73,7 +73,7 @@ def reset_workflow(self, workflow_id):
         self._workflow_id = workflow_id
         self._gdb_driver.set_workflow_state(self._workflow_id, 'SUBMITTED')
 
-    def add_task(self, task, task_state):
+    def add_task(self, task):
         """Add a new task to a BEE workflow.
 
-        :param task: the name of the file to which to redirect stderr
+        :param task: the task to add to the workflow
@@ -90,7 +90,7 @@ def add_task(self, task, task_state):
         task.hints = []
 
         # Load the new task into the graph database
-        self._gdb_driver.load_task(task, task_state)
+        self._gdb_driver.load_task(task)
 
     def restart_task(self, task, checkpoint_file):
         """Restart a failed BEE workflow task.
diff --git a/beeflow/tests/mocks.py b/beeflow/tests/mocks.py
index 2823223f5..e2df8379d 100644
--- a/beeflow/tests/mocks.py
+++ b/beeflow/tests/mocks.py
@@ -170,10 +170,10 @@ def reset_workflow(self, old_id, new_id):  # pylint: disable=W0613
             self.task_metadata[task_id] = {}
             self.task_states[task_id] = 'WAITING'
 
-    def load_task(self, task, task_state):
+    def load_task(self, task):
         """Load a task into a workflow in the graph database."""
         self.tasks[task.id] = task
-        self.task_states[task.id] = task_state
+        self.task_states[task.id] = task.state
         self.task_metadata[task.id] = {}
         self.inputs[task.id] = {}
         self.outputs[task.id] = {}
@@ -190,8 +190,7 @@ def initialize_ready_tasks(self, workflow_id):  # pylint: disable=W0613
 
     def restart_task(self, _old_task, new_task):
         """Create a new task from a failed task checkpoint restart enabled."""
-        task_state = "WAITING"
-        self.load_task(new_task, task_state)
+        self.load_task(new_task)
 
     def finalize_task(self, task):
         """Set a task's state to completed."""
diff --git a/beeflow/tests/test_sqlite_gdb.py b/beeflow/tests/test_sqlite_gdb.py
new file mode 100644
index 000000000..ac1e76691
--- /dev/null
+++ b/beeflow/tests/test_sqlite_gdb.py
@@ -0,0 +1,225 @@
+"""Tests for the SQL_GDB class in the gdb_db module."""
+
+from unittest.mock import call
+import json
+
+import pytest
+
+from beeflow.common.db import gdb_db
+
+
+@pytest.fixture
+def sql_gdb_instance(mocker, tmp_path):
+    """Create SQL_GDB instance with bdb.create_table/runscript patched."""
+    mocker.patch("beeflow.common.db.gdb_db.bdb.create_table")
+    mocker.patch("beeflow.common.db.gdb_db.bdb.runscript")
+    db_file = tmp_path / "test.db"
+    return gdb_db.SQL_GDB(str(db_file))
+
+
+@pytest.mark.parametrize("count, expected", [(0, True), (3, False)])
+def test_final_tasks_completed(mocker, sql_gdb_instance, count, expected):
+    """Regression test final_tasks_completed."""
+    getone = mocker.patch(
+        "beeflow.common.db.gdb_db.bdb.getone",
+        return_value=(count,),
+    )
+
+    wf_id = "WFID"
+    result = sql_gdb_instance.final_tasks_completed(wf_id)
+
+    placeholders = ",".join("?" for _ in gdb_db.final_task_states)
+    expected_query = f"""
+        SELECT COUNT(*)
+        FROM task
+        WHERE workflow_id = ?
+        AND state NOT IN ({placeholders});
+    """
+
+    getone.assert_called_once()
+    assert getone.mock_calls[0] == call(
+        sql_gdb_instance.db_file,
+        expected_query,
+        [wf_id, *gdb_db.final_task_states],
+    )
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    "completed, failed_count, expected",
+    [
+        (True, 0, True),    # all completed, none failed
+        (True, 2, False),   # completed but some failed
+        (False, 0, False),  # not completed yet
+    ],
+)
+def test_final_tasks_succeeded(mocker, sql_gdb_instance, completed, failed_count, expected):
+    """Regression test final_tasks_succeeded."""
+    mocker.patch.object(
+        gdb_db.SQL_GDB,
+        "final_tasks_completed",
+        return_value=completed,
+    )
+    getone = mocker.patch(
+        "beeflow.common.db.gdb_db.bdb.getone",
+        return_value=(failed_count,),
+    )
+
+    wf_id = "WFID"
+    result = sql_gdb_instance.final_tasks_succeeded(wf_id)
+
+    placeholders = ",".join("?" for _ in gdb_db.failed_task_states)
+    expected_query = f"""
+        SELECT COUNT(*)
+        FROM task
+        WHERE workflow_id = ?
+ AND state IN ({placeholders}); + """ + + getone.assert_called_once() + assert getone.mock_calls[0] == call( + sql_gdb_instance.db_file, + expected_query, + [wf_id, *gdb_db.failed_task_states], + ) + assert result == expected + + +@pytest.mark.parametrize("count, expected", [(0, True), (5, False)]) +def test_final_tasks_failed(mocker, sql_gdb_instance, count, expected): + """Regression test final_tasks_failed.""" + getone = mocker.patch( + "beeflow.common.db.gdb_db.bdb.getone", + return_value=(count,), + ) + + wf_id = "WFID" + result = sql_gdb_instance.final_tasks_failed(wf_id) + + placeholders = ",".join("?" for _ in gdb_db.failed_task_states) + ",?" + expected_query = f""" + SELECT COUNT(*) + FROM task + WHERE workflow_id = ? + AND state NOT IN ({placeholders}); + """ + + getone.assert_called_once() + assert getone.mock_calls[0] == call( + sql_gdb_instance.db_file, + expected_query, + [wf_id, *gdb_db.failed_task_states, "RESTARTED"], + ) + assert result == expected + + +@pytest.mark.parametrize("count, expected", [(0, True), (1, False)]) +def test_cancelled_final_tasks_completed(mocker, sql_gdb_instance, count, expected): + """Regression test cancelled_final_tasks_completed.""" + getone = mocker.patch( + "beeflow.common.db.gdb_db.bdb.getone", + return_value=(count,), + ) + + wf_id = "WFID" + result = sql_gdb_instance.cancelled_final_tasks_completed(wf_id) + + incomplete_states = ["SUBMIT", "PENDING", "RUNNING", "COMPLETING"] + placeholders = ",".join("?" for _ in incomplete_states) + expected_query = f""" + SELECT COUNT(*) + FROM task + WHERE workflow_id = ? + AND state IN ({placeholders}); + """ + + getone.assert_called_once() + assert getone.mock_calls[0] == call( + sql_gdb_instance.db_file, + expected_query, + [wf_id, *incomplete_states], + ) + assert result == expected + + +@pytest.mark.parametrize( + "db_value, expected", + [ + (None, {}), + ((None,), {}), + ((json.dumps({"a": 1}),), {"a": 1}), + ], +) +def test_get_task_metadata(mocker, sql_gdb_instance, db_value, expected): + """Regression test get_task_metadata.""" + getone = mocker.patch( + "beeflow.common.db.gdb_db.bdb.getone", + return_value=db_value, + ) + + task_id = "TASK1" + result = sql_gdb_instance.get_task_metadata(task_id) + + getone.assert_called_once_with( + sql_gdb_instance.db_file, + "SELECT metadata FROM task WHERE id=?", + [task_id], + ) + assert result == expected + + +def test_set_task_metadata_merges_existing(mocker, sql_gdb_instance): + """Regression test set_task_metadata merges prior and new metadata.""" + task_id = "TASK1" + mocker.patch.object( + gdb_db.SQL_GDB, + "get_task_metadata", + return_value={"a": 1, "b": 2}, + ) + run = mocker.patch("beeflow.common.db.gdb_db.bdb.run") + + sql_gdb_instance.set_task_metadata(task_id, {"b": 3, "c": 4}) + + # Extract the JSON written to the DB and verify merge result. + assert run.call_count == 1 + _, args, _ = run.mock_calls[0] + assert args[0] == sql_gdb_instance.db_file + assert args[1] == "UPDATE task SET metadata=? WHERE id=?" 
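+    # mock_calls entries unpack as (name, args, kwargs), so args[2] is the
+    # parameter sequence handed to bdb.run: (metadata_json, task_id).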
+ metadata_json, called_task_id = args[2] + assert called_task_id == task_id + merged = json.loads(metadata_json) + assert merged == {"a": 1, "b": 3, "c": 4} + + +def test_set_task_state(mocker, sql_gdb_instance): + """Regression test set_task_state issues correct UPDATE.""" + run = mocker.patch("beeflow.common.db.gdb_db.bdb.run") + + sql_gdb_instance.set_task_state("TASK1", "RUNNING") + + run.assert_called_once_with( + sql_gdb_instance.db_file, + """ + UPDATE task + SET state = :state + WHERE id = :task_id;""", + {"task_id": "TASK1", "state": "RUNNING"}, + ) + + +def test_set_workflow_state(mocker, sql_gdb_instance): + """Regression test set_workflow_state issues correct UPDATE.""" + run = mocker.patch("beeflow.common.db.gdb_db.bdb.run") + + sql_gdb_instance.set_workflow_state("WFID", "COMPLETED") + + run.assert_called_once_with( + sql_gdb_instance.db_file, + """ + UPDATE workflow + SET state = :state + WHERE id = :wf_id;""", + {"wf_id": "WFID", "state": "COMPLETED"}, + ) diff --git a/beeflow/tests/test_wf_interface.py b/beeflow/tests/test_wf_interface.py index f27fb4869..303936bc2 100644 --- a/beeflow/tests/test_wf_interface.py +++ b/beeflow/tests/test_wf_interface.py @@ -169,11 +169,11 @@ def test_add_task(self): outputs=outputs, stdout=stdout, stderr=stderr, - workflow_id=workflow_id) + workflow_id=workflow_id, + state="WAITING") - task_state = "WAITING" - self.wfi.add_task(task, task_state) + self.wfi.add_task(task) # Task object assertions self.assertEqual(task_name, task.name) @@ -227,11 +227,11 @@ def test_restart_task(self): outputs=outputs, stdout=stdout, stderr=stderr, - workflow_id=workflow_id) + workflow_id=workflow_id, + state="WAITING") - task_state = "WAITING" - self.wfi.add_task(task, task_state) + self.wfi.add_task(task) # Restart the task, should create a new Task new_task = self.wfi.restart_task(task, test_checkpoint_file) @@ -325,11 +325,12 @@ def test_get_task_by_id(self): outputs=outputs, stdout=stdout, stderr=stderr, - workflow_id=workflow_id) + workflow_id=workflow_id, + state="WAITING" + ) - task_state = "WAITING" - self.wfi.add_task(task, task_state) + self.wfi.add_task(task) self.assertEqual(task, self.wfi.get_task_by_id(task.id)) @@ -441,10 +442,9 @@ def test_get_task_state(self): inputs=[StepInput(id="test_input", type="File", value="input.txt", default="default.txt", source="test_input", prefix=None, position=None, value_from=None)], outputs=[StepOutput(id="test_task/output", type="File", value="output.txt", glob="output.txt")], - stdout=None, stderr=None, workflow_id=workflow_id) - task_state = "WAITING" + stdout=None, stderr=None, workflow_id=workflow_id, state="WAITING") - self.wfi.add_task(task, task_state) + self.wfi.add_task(task) # Should be WAITING self.assertEqual("WAITING", self.wfi.get_task_state(task.id)) @@ -462,9 +462,8 @@ def test_set_task_state(self): inputs=[StepInput(id="test_input", type="File", value="input.txt", default="default.txt", source="test_input", prefix=None, position=None, value_from=None)], outputs=[StepOutput(id="test_task/output", type="File", value="output.txt", glob="output.txt")], - stdout=None, stderr=None, workflow_id=workflow_id) - task_state = "WAITING" - self.wfi.add_task(task, task_state) + stdout=None, stderr=None, workflow_id=workflow_id, state="WAITING") + self.wfi.add_task(task) self.wfi.set_task_state(task.id, "RUNNING") @@ -484,9 +483,8 @@ def test_get_task_metadata(self): inputs=[StepInput(id="test_input", type="File", value="input.txt", default="default.txt", source="test_input", prefix=None, position=None, 
value_from=None)], outputs=[StepOutput(id="test_task/output", type="File", value="output.txt", glob="output.txt")], - stdout=None, stderr=None, workflow_id=workflow_id) - task_state = "WAITING" - self.wfi.add_task(task, task_state) + stdout=None, stderr=None, workflow_id=workflow_id, state="WAITING") + self.wfi.add_task(task) metadata = {"cluster": "fog", "crt": "charliecloud", "container_md5": "67df538c1b6893f4276d10b2af34ccfe", "job_id": 1337} @@ -506,9 +504,8 @@ def test_set_task_metadata(self): inputs=[StepInput(id="test_input", type="File", value="input.txt", default="default.txt", source="test_input", prefix=None, position=None, value_from=None)], outputs=[StepOutput(id="test_task/output", type="File", value="output.txt", glob="output.txt")], - stdout=None, stderr=None, workflow_id=workflow_id) - task_state = "WAITING" - self.wfi.add_task(task, task_state) + stdout=None, stderr=None, workflow_id=workflow_id, state="WAITING") + self.wfi.add_task(task) metadata = {"cluster": "fog", "crt": "charliecloud", "container_md5": "67df538c1b6893f4276d10b2af34ccfe", "job_id": 1337} @@ -533,9 +530,8 @@ def test_get_task_input(self): inputs=[StepInput(id="test_input", type="File", value="input.txt", default="default.txt", source="test_input", prefix=None, position=None, value_from=None)], outputs=[StepOutput(id="test_task/output", type="File", value="output.txt", glob="output.txt")], - stdout=None, stderr=None, workflow_id=workflow_id) - task_state = "WAITING" - self.wfi.add_task(task, task_state) + stdout=None, stderr=None, workflow_id=workflow_id, state="WAITING") + self.wfi.add_task(task) self.assertEqual(task.inputs[0], self.wfi.get_task_input(task.id, "test_input")) @@ -551,9 +547,8 @@ def test_set_task_input(self): name="test_task", base_command="ls", hints=None, requirements=None, inputs=[StepInput(id="test_input", type="File", value=None, default="default.txt", source="test_input", prefix=None, position=None, value_from=None)], outputs=[StepOutput(id="test_task/output", type="File", value="output.txt", glob="output.txt")], - stdout=None, stderr=None, workflow_id=workflow_id) - task_state = "WAITING" - self.wfi.add_task(task, task_state) + stdout=None, stderr=None, workflow_id=workflow_id, state="WAITING") + self.wfi.add_task(task) test_input = StepInput(id="test_input", type="File", value="input.txt", default="default.txt", source="test_input", prefix=None, position=None, value_from=None) @@ -572,9 +567,8 @@ def test_get_task_output(self): name="test_task", base_command="ls", hints=None, requirements=None, inputs=[StepInput(id="test_input", type="File", value="input.txt", default="default.txt", source="test_input", prefix=None, position=None, value_from=None)], outputs=[StepOutput(id="test_task/output", type="File", value="output.txt", glob="output.txt")], - stdout=None, stderr=None, workflow_id=workflow_id) - task_state = "WAITING" - self.wfi.add_task(task, task_state) + stdout=None, stderr=None, workflow_id=workflow_id, state="WAITING") + self.wfi.add_task(task) self.assertEqual(task.outputs[0], self.wfi.get_task_output(task.id, "test_task/output")) @@ -590,9 +584,8 @@ def test_set_task_output(self): name="test_task", base_command="ls", hints=None, requirements=None, inputs=[StepInput(id="test_input", type="File", value="input.txt", default="default.txt", source="test_input", prefix=None, position=None, value_from=None)], outputs=[StepOutput(id="test_task/output", type="File", value=None, glob="output.txt")], - stdout=None, stderr=None, workflow_id=workflow_id) - task_state = "WAITING" - 
self.wfi.add_task(task, task_state) + stdout=None, stderr=None, workflow_id=workflow_id, state="WAITING") + self.wfi.add_task(task) test_output = StepOutput(id="test_task/output", type="File", value="output.txt", glob="output.txt") self.wfi.set_task_output(task.id, "test_task/output", "output.txt") @@ -610,9 +603,8 @@ def test_workflow_completed(self): name="test_task", base_command="ls", hints=None, requirements=None, inputs=[StepInput(id="test_input", type="File", value="input.txt", default="default.txt", source="test_input", prefix=None, position=None, value_from=None)], outputs=[StepOutput(id="test_task/output", type="File", value="output.txt", glob="output.txt")], - stdout=None, stderr=None, workflow_id=workflow_id) - task_state = "WAITING" - self.wfi.add_task(task, task_state) + stdout=None, stderr=None, workflow_id=workflow_id, state="WAITING") + self.wfi.add_task(task) # Workflow not completed self.assertFalse(self.wfi.workflow_completed()) @@ -634,7 +626,7 @@ def _create_test_tasks(self, workflow_id): inputs=[StepInput(id="test_input", type="File", value=None, default=None, source="test_input", prefix="-l", position=None, value_from=None)], outputs=[StepOutput(id="prep/prep_output", type="File", value="prep_output.txt", glob="prep_output.txt")], stdout="prep_output.txt", stderr=None, - workflow_id=workflow_id), + workflow_id=workflow_id, state="WAITING"), Task( name="compute0", base_command="touch", hints=[Hint(class_="ResourceRequirement", params={"ramMax": 2048})], @@ -642,7 +634,7 @@ def _create_test_tasks(self, workflow_id): inputs=[StepInput(id="input_data", type="File", value=None, default=None, source="prep/prep_output", prefix=None, position=None, value_from=None)], outputs=[StepOutput(id="compute0/output", type="File", value="output0.txt", glob="output0.txt")], stdout="output0.txt", stderr=None, - workflow_id=workflow_id), + workflow_id=workflow_id, state="WAITING"), Task( name="compute1", base_command="find", hints=[Hint(class_="ResourceRequirement", params={"ramMax": 2048})], @@ -650,7 +642,7 @@ def _create_test_tasks(self, workflow_id): inputs=[StepInput(id="input_data", type="File", value=None, default=None, source="prep/prep_output", prefix=None, position=None, value_from=None)], outputs=[StepOutput(id="compute1/output", type="File", value="output1.txt", glob="output1.txt")], stdout="output1.txt", stderr=None, - workflow_id=workflow_id), + workflow_id=workflow_id, state="WAITING"), Task( name="compute2", base_command="grep", hints=[Hint(class_="ResourceRequirement", params={"ramMax": 2048})], @@ -658,7 +650,7 @@ def _create_test_tasks(self, workflow_id): inputs=[StepInput(id="input_data", type="File", value=None, default=None, source="prep/prep_output", prefix=None, position=None, value_from=None)], outputs=[StepOutput(id="compute2/output", type="File", value="output2.txt", glob="output2.txt")], stdout="output2.txt", stderr=None, - workflow_id=workflow_id), + workflow_id=workflow_id, state="WAITING"), Task( name="visualization", base_command="python", hints=[Hint(class_="ResourceRequirement", params={"ramMax": 2048})], @@ -668,12 +660,11 @@ def _create_test_tasks(self, workflow_id): StepInput(id="input2", type="File", value=None, default=None, source="compute2/output", prefix="-i", position=3, value_from=None)], outputs=[StepOutput(id="viz/output", type="File", value="viz_output.txt", glob="viz_output.txt")], stdout="viz_output.txt", stderr=None, - workflow_id=workflow_id) + workflow_id=workflow_id, state="WAITING") ] - task_state = "WAITING" for task in tasks: - 
self.wfi.add_task(task, task_state)
+            self.wfi.add_task(task)
 
         return tasks
diff --git a/beeflow/tests/test_wf_manager.py b/beeflow/tests/test_wf_manager.py
index 60d61c8b6..5d7be242e 100644
--- a/beeflow/tests/test_wf_manager.py
+++ b/beeflow/tests/test_wf_manager.py
@@ -159,11 +159,13 @@ def test_workflow_status(client, mocker, setup_teardown_workflow):
             workflow_id=WF_ID,
             name='task',
             base_command='',
+            state='RUNNING'
         ),
         '124': Task(
             id='124',
             workflow_id=WF_ID,
             name='task',
             base_command='',
+            state='WAITING'
         )}
     mockGDB.task_states = {
         '123': 'RUNNING',
diff --git a/beeflow/tests/test_wf_update.py b/beeflow/tests/test_wf_update.py
index 841aa3e18..bdbb68e28 100644
--- a/beeflow/tests/test_wf_update.py
+++ b/beeflow/tests/test_wf_update.py
@@ -22,7 +22,13 @@ def test_archive_workflow(tmpdir, mocker, test_function, expected_state):
     )
     mocker.patch(
         "beeflow.common.config_driver.BeeConfig.get",
-        return_value=str(tmpdir / "bee_archive_dir"),
+        # bc.get: report "neo4j" for ("graphdb", "type") and the temporary
+        # archive directory for any other option.
+        side_effect=lambda section, option, *a, **kw: (
+            "neo4j"
+            if (section, option) == ("graphdb", "type")
+            else str(tmpdir / "bee_archive_dir")
+        ),
     )
     mock_export_dag = mocker.patch("beeflow.wf_manager.resources.wf_utils.export_dag")
     mock_update_wf_status = mocker.patch(
diff --git a/beeflow/wf_manager/resources/wf_actions.py b/beeflow/wf_manager/resources/wf_actions.py
index d93fb3ca0..49695bba7 100644
--- a/beeflow/wf_manager/resources/wf_actions.py
+++ b/beeflow/wf_manager/resources/wf_actions.py
@@ -63,8 +63,7 @@ def get(wf_id):
     tasks = wfi.get_tasks()
     tasks_status = []
     for task in tasks:
-        state = wfi.get_task_state(task.id)
-        tasks_status.append((task.id, task.name, state, wfi.get_task_metadata(task.id)))
+        tasks_status.append((task.id, task.name, task.state, wfi.get_task_metadata(task.id)))
 
     return (
         WorkflowStatusResponse(
diff --git a/beeflow/wf_manager/resources/wf_list.py b/beeflow/wf_manager/resources/wf_list.py
index 584028ab9..54912907c 100644
--- a/beeflow/wf_manager/resources/wf_list.py
+++ b/beeflow/wf_manager/resources/wf_list.py
@@ -14,6 +14,7 @@
 from beeflow.common import log as bee_logging
 from beeflow.common.gdb.neo4j_driver import Neo4jDriver
+from beeflow.common.gdb import sqlite3_driver
 
 # from beeflow.common.wf_profiler import WorkflowProfiler
 
@@ -68,9 +69,15 @@ class WFList(Resource):
 
     def get(self):
        """Return list of workflows to client."""
-        db = connect_db(wfm_db, db_path)
-        wf_utils.connect_neo4j_driver(db.info.get_port("bolt"))
-        info = Neo4jDriver().get_all_workflow_info()
+        gdb = bc.get('graphdb', 'type').lower()
+        if gdb == 'sqlite3':
+            driver = sqlite3_driver.SQLDriver()
+            driver.connect()
+            info = driver.get_all_workflow_info()
+        else:
+            db = connect_db(wfm_db, db_path)
+            wf_utils.connect_neo4j_driver(db.info.get_port("bolt"))
+            info = Neo4jDriver().get_all_workflow_info()
 
         return ListWorkflowsResponse(workflow_info_list=info).model_dump(), 200
 
diff --git a/beeflow/wf_manager/resources/wf_update.py b/beeflow/wf_manager/resources/wf_update.py
index 73626ceae..ec3c49b70 100644
--- a/beeflow/wf_manager/resources/wf_update.py
+++ b/beeflow/wf_manager/resources/wf_update.py
@@ -34,11 +34,12 @@ def archive_workflow(wf_id, final_state=None):
     shutil.copyfile(os.path.expanduser("~") + '/.config/beeflow/bee.conf',
                     workflow_dir + '/' + 'bee.conf')
     # Archive Completed DAG
-    graphmls_dir = workflow_dir + "/graphmls"
-    os.makedirs(graphmls_dir, exist_ok=True)
-    dags_dir = workflow_dir + "/dags"
-    os.makedirs(dags_dir, exist_ok=True)
-    wf_utils.export_dag(wf_id, dags_dir, graphmls_dir, no_dag_dir=True)
+    if bc.get('graphdb', 'type').lower() == 'neo4j':
+        graphmls_dir = workflow_dir + "/graphmls"
+        os.makedirs(graphmls_dir, exist_ok=True)
+        dags_dir = workflow_dir + "/dags"
+        os.makedirs(dags_dir, exist_ok=True)
+        wf_utils.export_dag(wf_id, dags_dir, graphmls_dir, no_dag_dir=True)
 
     wf_state = f'Archived/{final_state}' if final_state is not None else 'Archived'
     wf_utils.update_wf_status(wf_id, wf_state)
@@ -90,7 +91,7 @@
             msg='Task states updated successfully',
         ).model_dump(), 200
 
-    def handle_metadata(self, state_update, task_id, wfi):
+    def handle_metadata(self, state_update, task_id, wfi, task_workdir):
         """Handle metadata for a task update."""
         bee_workdir = wf_utils.get_bee_workdir()
 
@@ -103,7 +104,6 @@
         wfi.set_task_metadata(task_id, old_metadata)
 
         task_name = wfi.get_task_by_id(task_id).name
-        task_workdir = old_metadata['workdir']
         task_dir = f'{task_workdir}/{task_name}-{task_id[:4]}'
 
         metadata_path = os.path.join(task_dir,'metadata.txt')
@@ -150,7 +150,7 @@
                     wfi.set_task_output(task.id, output.id, output.glob)
                 else:
                     wfi.set_task_output(task.id, output.id, "temp")
-            wf_utils.copy_task_output(task, wfi)
+            wf_utils.copy_task_output(task)
             tasks = wfi.finalize_task(task)
             if tasks and wf_state not in ('Paused', 'Cancelled'):
                 wf_utils.schedule_submit_tasks(state_update.wf_id, tasks)
@@ -179,6 +179,6 @@
         task = wfi.get_task_by_id(state_update.task_id)
         wfi.set_task_state(state_update.task_id, state_update.job_state)
 
-        self.handle_metadata(state_update, state_update.task_id, wfi)
+        self.handle_metadata(state_update, state_update.task_id, wfi, task.workdir)
         if not self.handle_checkpoint_restart(state_update, task, wfi):
             self.handle_state_change(state_update, task, wfi)
diff --git a/beeflow/wf_manager/resources/wf_utils.py b/beeflow/wf_manager/resources/wf_utils.py
index be018325c..c8b08642e 100644
--- a/beeflow/wf_manager/resources/wf_utils.py
+++ b/beeflow/wf_manager/resources/wf_utils.py
@@ -9,7 +9,7 @@
 
 from beeflow.common import log as bee_logging
 from beeflow.common.config_driver import BeeConfig as bc
-from beeflow.common.gdb import neo4j_driver
+from beeflow.common.gdb import neo4j_driver, sqlite3_driver
 from beeflow.common.gdb.generate_graph import generate_viz
 from beeflow.common.gdb.graphml_key_updater import update_graphml
 from beeflow.common.wf_interface import WorkflowInterface
@@ -141,15 +141,20 @@ def create_wf_namefile(wf_name, wf_id):
 
 def get_workflow_interface(wf_id):
     """Instantiate and return workflow interface object."""
-    db = connect_db(wfm_db, get_db_path())
+    # Wait for the GDB
     # bolt_port = db.info.get_bolt_port()
     # return get_workflow_interface_by_bolt_port(wf_id, bolt_port)
-    driver = neo4j_driver.Neo4jDriver()
-    bolt_port = db.info.get_port("bolt")
-    if bolt_port != -1:
-        connect_neo4j_driver(bolt_port)
+    if bc.get('graphdb', 'type').lower() == 'sqlite3':
+        driver = sqlite3_driver.SQLDriver()
+        driver.connect()
+    else:
+        db = connect_db(wfm_db, get_db_path())
+        driver = neo4j_driver.Neo4jDriver()
+        bolt_port = db.info.get_port("bolt")
+        if bolt_port != -1:
+            connect_neo4j_driver(bolt_port)
 
     wfi = WorkflowInterface(wf_id, driver)
     return wfi
 
 
@@ -201,9 +206,6 @@ def _resource(component, tag=""):
 def submit_tasks_tm(wf_id, tasks, allocation):  # pylint: disable=W0613
     """Submit a task to the task manager."""
     wfi = get_workflow_interface(wf_id)
-    for task in tasks:
-        metadata = wfi.get_task_metadata(task.id)
-        task.workdir = metadata["workdir"]
     # Serialize task with json
     names = [task.name for task in tasks]
     log.info("Submitted %s to Task Manager", names)
@@ -277,11 +279,8 @@ def setup_workflow(wf_id, wf_name, wf_dir, wf_workdir, no_start, workflow=None,
     log.info("Setting workflow metadata")
     create_wf_metadata(wf_id, wf_name)
     for task in tasks:
-        task_state = "" if no_start else "WAITING"
-        wfi.add_task(task, task_state)
-        metadata = wfi.get_task_metadata(task.id)
-        metadata["workdir"] = task.workdir
-        wfi.set_task_metadata(task.id, metadata)
+        task.state = "" if no_start else "WAITING"
+        wfi.add_task(task)
 
     if no_start:
         update_wf_status(wf_id, "No Start")
@@ -315,8 +314,7 @@ def start_workflow(wf_id):
     _, tasks = wfi.get_workflow()
     tasks.reverse()
     for task in tasks:
-        task_state = wfi.get_task_state(task.id)
-        if task_state == "":
+        if task.state == "":
             wfi.set_task_state(task.id, "WAITING")
     wfi.execute_workflow()
     tasks = wfi.get_ready_tasks()
@@ -325,7 +323,7 @@
     return True
 
 
-def copy_task_output(task, wfi):
+def copy_task_output(task):
     """Copies stdout, stderr, and metadata information to the task
     directory in the WF archive."""
     bee_workdir = get_bee_workdir()
@@ -333,7 +331,7 @@
     task_save_path = pathlib.Path(
        f"{bee_workdir}/workflows/{task.workflow_id}/{task.name}-{task.id[:4]}"
    )
-    task_workdir = wfi.get_task_metadata(task.id)["workdir"]
+    task_workdir = task.workdir
    task_metadata_path = pathlib.Path(f"{task_workdir}/{task.name}-{task.id[:4]}/"\
                                      f"metadata.txt")
    if task.stdout:
diff --git a/ci/bee_config.sh b/ci/bee_config.sh
index c443e5269..8af473c33 100755
--- a/ci/bee_config.sh
+++ b/ci/bee_config.sh
@@ -11,7 +11,8 @@ Flux)
 esac
 
 mkdir -p $(dirname $BEE_CONFIG)
-cat >> $BEE_CONFIG <<EOF
+cat > $BEE_CONFIG <<EOF