From 96a8e5353f5cfc10d41b92f92d8cbe63d4cdd1b6 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Mon, 12 May 2025 16:03:01 +0000 Subject: [PATCH 001/168] Initial commit --- .../databases/spanner_test.sdl | 11 + .../databases/spanner_test_lib.py | 137 ++++++ .../databases/spanner_utils.py | 463 ++++++++++++++++++ .../databases/spanner_utils_test.py | 322 ++++++++++++ 4 files changed, 933 insertions(+) create mode 100644 grr/server/grr_response_server/databases/spanner_test.sdl create mode 100644 grr/server/grr_response_server/databases/spanner_test_lib.py create mode 100644 grr/server/grr_response_server/databases/spanner_utils.py create mode 100644 grr/server/grr_response_server/databases/spanner_utils_test.py diff --git a/grr/server/grr_response_server/databases/spanner_test.sdl b/grr/server/grr_response_server/databases/spanner_test.sdl new file mode 100644 index 000000000..7737e1402 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_test.sdl @@ -0,0 +1,11 @@ +CREATE TABLE `Table`( + `Key` STRING(MAX) NOT NULL, + `Column` STRING(MAX), + `Time` TIMESTAMP OPTIONS (allow_commit_timestamp = true) +) PRIMARY KEY (`Key`); + +CREATE TABLE `Subtable`( + `Key` STRING(MAX) NOT NULL, + `Subkey` STRING(MAX), +) PRIMARY KEY (`Key`, `Subkey`), + INTERLEAVE IN PARENT `Table` ON DELETE CASCADE; \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_test_lib.py b/grr/server/grr_response_server/databases/spanner_test_lib.py new file mode 100644 index 000000000..4ee8fd174 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_test_lib.py @@ -0,0 +1,137 @@ +"""A library with utilities for testing the Spanner database implementation.""" +import os +import unittest + +from typing import Optional + +from absl.testing import absltest + +from google.cloud import spanner_v1 as spanner_lib +from google.cloud import spanner_admin_database_v1 +from google.cloud.spanner import Client, KeySet +from google.cloud.spanner_admin_database_v1.types import spanner_database_admin + +from grr_response_server.databases import spanner_utils + +OPERATION_TIMEOUT_SECONDS = 240 + +PROD_SCHEMA_SDL_PATH = "grr/server/grr_response_server/databases/spanner.sdl" +TEST_SCHEMA_SDL_PATH = "grr/server/grr_response_server/databases/spanner_test.sdl" + +def _GetEnvironOrSkip(key): + value = os.environ.get(key) + if value is None: + raise unittest.SkipTest("'%s' variable is not set" % key) + return value + +def _readSchemaFromFile(file_path): + """Reads DDL statements from a file.""" + with open(file_path, 'r') as f: + # Read the whole file and split by semicolon. + # Filter out any empty strings resulting from split. + ddl_statements = [stmt.strip() for stmt in f.read().split(';') if stmt.strip()] + return ddl_statements + +def Init(sdl_path: str) -> None: + """Initializes the Spanner testing environment. + + This must be called only once per test process. A `setUpModule` method is + a perfect place for it. 
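+
+  For example (a sketch based on the accompanying test modules; it assumes the
+  PROJECT_ID, SPANNER_TEST_INSTANCE and SPANNER_TEST_DATABASE environment
+  variables point at a dedicated test instance):
+
+    def setUpModule() -> None:
+      spanner_test_lib.Init(spanner_test_lib.TEST_SCHEMA_SDL_PATH)
+
+    def tearDownModule() -> None:
+      spanner_test_lib.TearDown()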
+ + """ + global _TEST_DB + + if _TEST_DB is not None: + raise AssertionError("Spanner test library already initialized") + + project_id = _GetEnvironOrSkip("PROJECT_ID") + instance_id = _GetEnvironOrSkip("SPANNER_TEST_INSTANCE") + database_id = _GetEnvironOrSkip("SPANNER_TEST_DATABASE") + + spanner_client = Client(project_id) + database_admin_api = spanner_client.database_admin_api + + ddl_statements = _readSchemaFromFile(sdl_path) + + request = spanner_database_admin.CreateDatabaseRequest( + parent=database_admin_api.instance_path(spanner_client.project, instance_id), + create_statement=f"CREATE DATABASE `{database_id}`", + extra_statements=ddl_statements + ) + + operation = database_admin_api.create_database(request=request) + + print("Waiting for operation to complete...") + database = operation.result(OPERATION_TIMEOUT_SECONDS) + + print( + "Created database {} on instance {}".format( + database.name, + database_admin_api.instance_path(spanner_client.project, instance_id), + ) + ) + + instance = spanner_client.instance(instance_id) + _TEST_DB = instance.database(database_id) + + +def TearDown() -> None: + """Tears down the Spanner testing environment. + + This must be called once per process after all the tests. A `tearDownModule` + is a perfect place for it. + """ + if _TEST_DB is not None: + # Create a client + _TEST_DB.drop() + + +def CreateTestDatabase() -> spanner_lib.database: + """Creates an empty test spanner database. + + Returns: + A PySpanner instance pointing to the created database. + """ + #if _TEST_DB is None: + # raise AssertionError("Spanner test database not initialized") + + db = spanner_utils.Database(_TEST_DB) + + query = """ + SELECT t.table_name + FROM information_schema.tables AS t + WHERE t.table_catalog = "" + AND t.table_schema = "" + ORDER BY t.table_name ASC + """ + + table_names = set() + for (table_name,) in db.Query(query): + table_names.add(table_name) + + query = """ + SELECT v.table_name + FROM information_schema.views AS v + WHERE v.table_catalog = "" + AND v.table_schema = "" + ORDER BY v.table_name ASC + """ + view_names = set() + for (view_name,) in db.Query(query): + view_names.add(view_name) + + # `table_names` is a superset of `view_names` (since the `VIEWS` table is, + # well, just a view to the `TABLES` table [1]). Since deleting from views + # makes no sense, we have to exclude them from the tables we want to clean. + table_names -= view_names + + keyset = KeySet(all_=True) + + with _TEST_DB.batch() as batch: + # Deletes sample data from all tables in the given database. 
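+    # `KeySet(all_=True)` matches every row, so each remaining (non-view)
+    # table is wiped completely, giving every test case a clean database.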
+ for table_name in table_names: + batch.delete(table_name, keyset) + + return _TEST_DB + +_TEST_DB: spanner_lib.database = None \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py new file mode 100644 index 000000000..53a33a08d --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -0,0 +1,463 @@ +"""Spanner-related helpers and other utilities.""" + +import contextlib +import datetime +import re +from typing import Any +from typing import Callable +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TypeVar + +from google.cloud import spanner_v1 as spanner_lib +from google.cloud.spanner import KeyRange, KeySet +from google.cloud.spanner_admin_database_v1.types import spanner_database_admin +from google.cloud.spanner_v1 import Mutation, param_types + +from grr_response_core.lib.util import collection +from grr_response_core.lib.util import iterator + +Row = Tuple[Any, ...] +Cursor = Iterator[Row] + +_T = TypeVar("_T") + +class Database: + """A wrapper around the PySpanner class. + + The wrapper is supposed to streamline the usage of Spanner database through + an abstraction that is much harder to misuse. The wrapper will run retryable + queries through a transaction runner handling all brittle logic for the user. + """ + + _PYSPANNER_PARAM_REGEX = re.compile(r"@p\d+") + + def __init__(self, pyspanner: spanner_lib.database) -> None: + super().__init__() + self._pyspanner = pyspanner + + def _parametrize(self, query: str, names: Iterable[str]) -> str: + match = self._PYSPANNER_PARAM_REGEX.search(query) + if match is not None: + raise ValueError(f"Query contains illegal sequence: {match.group(0)}") + + kwargs = {} + for name in names: + kwargs[name] = f"@{name}" + + return query.format(**kwargs) + + def _get_param_type(self, value): + """ + Infers the Google Cloud Spanner type from a Python value. + + Args: + value: The Python value whose Spanner type is to be inferred. + + Returns: + A google.cloud.spanner_v1.types.Type object, or None if the type + cannot be reliably inferred (e.g., for a standalone None value or + an empty list). + """ + if value is None: + # Cannot determine a specific Spanner type from a None value alone. + # This indicates that the type is ambiguous without further schema context. + return None + + py_type = type(value) + + if py_type is int: + return param_types.INT64 + elif py_type is float: + return param_types.FLOAT64 + elif py_type is str: + return param_types.STRING + elif py_type is bool: + return param_types.BOOL + elif py_type is bytes: + return param_types.BYTES + elif py_type is datetime.date: + return param_types.DATE + elif py_type is datetime.datetime: + # Note: Spanner TIMESTAMPs are stored in UTC. Ensure datetime objects + # are timezone-aware (UTC) when writing data. This function only maps the type. + return param_types.TIMESTAMP + elif py_type is decimal.Decimal: + return param_types.NUMERIC + else: + # Potentially raise an error for unsupported types or return None + # For a generic solution, raising an error for unknown types is often safer. 
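+      # Note that container values (such as Python lists destined for ARRAY
+      # columns) are not mapped here either; callers needing them would have
+      # to supply explicit Spanner types themselves.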
+ raise TypeError(f"Unsupported Python type: {py_type.__name__} for Spanner type conversion.") + + def Transact( + self, + func: Callable[["Transaction"], _T], + txn_tag: Optional[str] = None, + ) -> List[Any]: + + """Execute the given callback function in a Spanner transaction. + + Args: + func: A transaction function to execute. + txn_tag: Transaction tag to apply. + + Returns: + The result the transaction function returned. + """ + return self._pyspanner.run_in_transaction(func) + + def Mutate( + self, func: Callable[["Mutation"], None], txn_tag: Optional[str] = None + ) -> None: + """Execute the given callback function in a Spanner mutation. + + Args: + func: A mutation function to execute. + txn_tag: Optional[str] = None, + """ + + self.Transact(func, txn_tag=txn_tag) + + def Query(self, query: str, txn_tag: Optional[str] = None) -> Cursor: + """Queries Spanner database using the given query string. + + Args: + query: An SQL string. + txn_tag: Spanner transaction tag. + + Returns: + A cursor over the query results. + """ + with self._pyspanner.snapshot() as snapshot: + results = snapshot.execute_sql(query) + + return results + + def QuerySingle(self, query: str, txn_tag: Optional[str] = None) -> Row: + """Queries PySpanner for a single row using the given query string. + + Args: + query: An SQL string. + txn_tag: Spanner transaction tag. + + Returns: + A single row matching the query. + + Raises: + NotFound: If the query did not return any results. + ValueError: If the query yielded more than one result. + """ + return self.Query(query, txn_tag=txn_tag).one() + + def ParamQuery( + self, query: str, params: Mapping[str, Any], txn_tag: Optional[str] = None + ) -> Cursor: + """Queries PySpanner database using the given query string with params. + + The query string should specify parameters with the standard Python format + placeholder syntax [1]. Note that parameters inside string literals in the + query itself have to be escaped. + + Also, the query literal is not allowed to contain any '@p{idx}' strings + inside as that would lead to an incorrect behaviour when evaluating the + query. To prevent mistakes the function will raise an exception in such + cases. + + [1]: https://docs.python.org/3/library/stdtypes.html#str.format + + Args: + query: An SQL string with parameter placeholders. + params: A dictionary mapping parameter name to a value. + txn_tag: Spanner transaction tag. + + Returns: + A cursor over the query results. + + Raises: + ValueError: If the query contains disallowed sequences. + KeyError: If some parameter is not specified. + """ + names, values = collection.Unzip(params.items()) + query = self._parametrize(query, names) + + param_type = {} + for key, value in params.items(): + try: + param_type[key] = self._get_param_type(value) + except TypeError as e: + print(f"Warning for key '{key}': {e}. Setting type to None.") + param_type[key] = None # Or re-raise, or handle differently + + print("query: {}".format(query)) + print("params: {}".format(params)) + print("param_type: {}".format(param_type)) + + with self._pyspanner.snapshot() as snapshot: + results = snapshot.execute_sql( + query, + params=params, + param_types=param_type, + ) + + return results + + def ParamQuerySingle( + self, query: str, params: Mapping[str, Any], txn_tag: Optional[str] = None + ) -> Row: + """Queries the database for a single row using with a query with params. + + See documentation for `ParamQuery` to learn more about the syntax of query + parameters and other caveats. 
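+
+    A short usage sketch (the accompanying unit tests exercise exactly this
+    pattern):
+
+      row = db.ParamQuerySingle("SELECT {str}, {int}", {"str": "foo", "int": 42})
+      # row == ["foo", 42]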
+ + Args: + query: An SQL string with parameter placeholders. + params: A dictionary mapping parameter name to a value. + txn_tag: Spanner transaction tag. + + Returns: + A single result of running the query. + + Raises: + NotFound: If the query did not return any results. + ValueError: If the query yielded more than one result. + ValueError: If the query contains disallowed sequences. + KeyError: If some parameter is not specified. + """ + return self.ParamQuery(query, params, txn_tag=txn_tag).one() + + def ParamExecute( + self, query: str, params: Mapping[str, Any], txn_tag: Optional[str] = None + ) -> None: + """Executes the given query with parameters against a Spanner database. + + Args: + query: An SQL string with parameter placeholders. + params: A dictionary mapping parameter name to a value. + txn_tag: Spanner transaction tag. + + Returns: + Nothing. + + Raises: + ValueError: If the query contains disallowed sequences. + KeyError: If some parameter is not specified. + """ + names, values = collection.Unzip(params.items()) + query = self._parametrize(query, names) + + param_type = {} + for key, value in params.items(): + try: + param_type[key] = self._get_param_type(value) + except TypeError as e: + print(f"Warning for key '{key}': {e}. Setting type to None.") + param_type[key] = None # Or re-raise, or handle differently + + print("query: {}".format(query)) + print("params: {}".format(params)) + print("param_type: {}".format(param_type)) + + def param_execute(transaction): + row_ct = transaction.execute_update( + query, + params=params, + param_types=param_type, + ) + + print("{} record(s) updated.".format(row_ct)) + self._pyspanner.run_in_transaction(param_execute) + + def ExecutePartitioned( + self, query: str, txn_tag: Optional[str] = None + ) -> None: + """Executes the given query against a Spanner database. + + This is a more efficient variant of the `Execute` method, but it does not + guarantee atomicity. See the official documentation on partitioned updates + for more information [1]. + + [1]: go/spanner-partitioned-dml + + Args: + query: An SQL query string to execute. + txn_tag: Spanner transaction tag. + + Returns: + Nothing. + """ + query_options = None + if txn_tag is not None: + query_options = spanner_lib.QueryOptions() + query_options.SetTag(txn_tag) + + return self._pyspanner.execute_partitioned_dml(query) + + def Insert( + self, table: str, row: Mapping[str, Any], txn_tag: Optional[str] = None + ) -> None: + """Insert a row into the given table. + + Args: + table: A table into which the row is to be inserted. + row: A mapping from column names to column values of the row. + txn_tag: Spanner transaction tag. + + Returns: + Nothing. + """ + columns, values = collection.Unzip(row.items()) + + columns = list(columns) + values = list(values) + + with self._pyspanner.batch() as batch: + batch.insert( + table=table, + columns=columns, + values=[values] + ) + + def Update( + self, table: str, row: Mapping[str, Any], txn_tag: Optional[str] = None + ) -> None: + """Updates a row in the given table. + + Args: + table: A table in which the row is to be updated. + row: A mapping from column names to column values of the row. + txn_tag: Spanner transaction tag. + + Returns: + Nothing. 
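+
+    Note: as with the underlying Spanner `update` mutation, the row identified
+    by the key columns must already exist; updating a missing row surfaces a
+    NotFound error when the batch is committed.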
+ """ + columns, values = collection.Unzip(row.items()) + + columns = list(columns) + values = list(values) + + with self._pyspanner.batch() as batch: + batch.update( + table=table, + columns=columns, + values=[values] + ) + + def InsertOrUpdate( + self, table: str, row: Mapping[str, Any], txn_tag: Optional[str] = None + ) -> None: + """Insert or update a row into the given table within the transaction. + + Args: + table: A table into which the row is to be inserted. + row: A mapping from column names to column values of the row. + txn_tag: Spanner transaction tag. + + Returns: + Nothing. + """ + columns, values = collection.Unzip(row.items()) + + columns = list(columns) + values = list(values) + + with self._pyspanner.mutation_groups() as groups: + groups.group().insert_or_update( + table=table, + columns=columns, + values=[values] + ) + groups.batch_write() + + def Delete( + self, table: str, key: Sequence[Any], txn_tag: Optional[str] = None + ) -> None: + """Deletes a specified row from the given table. + + Args: + table: A table from which the row is to be deleted. + key: A sequence of values denoting the key of the row to delete. + txn_tag: Spanner transaction tag. + + Returns: + Nothing. + """ + keyset = KeySet(all_=True) + if key: + keyset = KeySet(keys=[key]) + with self._pyspanner.batch() as batch: + batch.delete(table, keyset) + + def DeleteWithPrefix(self, table: str, key_prefix: Sequence[Any]) -> None: + """Deletes a range of rows with common key prefix from the given table. + + Args: + table: A table from which rows are to be deleted. + key: A sequence of value denoting the prefix of the key of rows to delete. + + Returns: + Nothing. + """ + range = KeyRange(start_closed=key_prefix, end_closed=key_prefix) + keyset = KeySet(ranges=[range]) + + with self._pyspanner.batch() as batch: + batch.delete(table, keyset) + + def Read( + self, + table: str, + key: Sequence[Any], + cols: Sequence[str], + ) -> Mapping[str, Any]: + """Read a single row with the given key from the specified table. + + Args: + table: A name of the table to read from. + key: A key of the row to read. + cols: Columns of the row to read. + + Returns: + A mapping from columns to values of the read row. + """ + range = KeyRange(start_closed=key, end_closed=key) + keyset = KeySet(ranges=[range]) + with self._pyspanner.snapshot() as snapshot: + results = snapshot.read( + table=table, + columns=cols, + keyset=keyset + ) + + return results.one() + + def ReadSet( + self, + table: str, + rows: KeySet, + cols: Sequence[str], + ) -> Iterator[Mapping[str, Any]]: + """Read a set of rows from the specified table. + + Args: + table: A name of the table to read from. + rows: A set of keys specifying which rows to read. + cols: Columns of the row to read. + + Returns: + Mappings from columns to values of the rows read. 
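+
+    For example (mirroring the accompanying unit tests):
+
+      keyset = KeySet(keys=[["foo"], ["bar"]])
+      for (column,) in db.ReadSet(table="Table", rows=keyset, cols=("Column",)):
+        ...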
+ """ + with self._pyspanner.snapshot() as snapshot: + results = snapshot.read( + table=table, + columns=cols, + keyset=rows + ) + + return results \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_utils_test.py b/grr/server/grr_response_server/databases/spanner_utils_test.py new file mode 100644 index 000000000..a7bbd41c5 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_utils_test.py @@ -0,0 +1,322 @@ +import datetime +import time +from typing import Any +from typing import Iterator +from typing import List +from typing import Mapping +from unittest import mock + +from absl.testing import absltest + +from google.cloud import spanner as spanner_lib +from google.api_core.exceptions import NotFound + +from grr_response_core.lib.util import iterator + +from grr_response_server.databases import spanner_test_lib +from grr_response_server.databases import spanner_utils + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.TEST_SCHEMA_SDL_PATH) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + +class DatabaseTest(absltest.TestCase): + + def setUp(self): + super().setUp() + + pyspanner = spanner_test_lib.CreateTestDatabase() + self.db = spanner_utils.Database(pyspanner) + + ####################################### + # Transact Tests + ####################################### + def testTransactionTransactional(self): + + def TransactionWrite(txn) -> None: + txn.insert( + table="Table", + columns=("Key",), + values=[("foo",), ("bar",)] + ) + + def TransactionRead(txn) -> List[Any]: + result = list(txn.execute_sql("SELECT t.Key FROM Table AS t")) + return result + + self.db.Transact(TransactionWrite) + results = self.db.Transact(TransactionRead) + self.assertCountEqual(results, [["foo"], ["bar"]]) + + ####################################### + # Query Tests + ####################################### + def testQuerySimple(self): + results = list(self.db.Query("SELECT 'foo', 42")) + self.assertEqual(results, [["foo", 42]]) + + def testQueryWithPlaceholders(self): + results = list(self.db.Query("SELECT '{}', '@p0'")) + self.assertEqual(results, [["{}", "@p0"]]) + + ####################################### + # QuerySingle Tests + ####################################### + def testQuerySingle(self): + result = self.db.QuerySingle("SELECT 'foo', 42") + self.assertEqual(result, ["foo", 42]) + + def testQuerySingleEmpty(self): + with self.assertRaises(NotFound): + self.db.QuerySingle("SELECT 'foo', 42 FROM UNNEST([])") + + def testQuerySingleMultiple(self): + with self.assertRaises(ValueError): + self.db.QuerySingle("SELECT 'foo', 42 FROM UNNEST([1, 2])") + + ####################################### + # ParamQuery Tests + ####################################### + def testParamQuerySingleParam(self): + query = "SELECT {abc}" + params = {"abc": 1337} + + results = list(self.db.ParamQuery(query, params)) + self.assertEqual(results, [[1337,]]) + + def testParamQueryMultipleParams(self): + timestamp = datetime.datetime.now(datetime.timezone.utc) + + query = "SELECT {int}, {str}, {timestamp}" + params = {"int": 1337, "str": "quux", "timestamp": timestamp} + + results = list(self.db.ParamQuery(query, params)) + self.assertEqual(results, [[1337, "quux", timestamp]]) + + def testParamQueryMissingParams(self): + with self.assertRaisesRegex(KeyError, "bar"): + self.db.ParamQuery("SELECT {foo}, {bar}", {"foo": 42}) + + def testParamQueryExtraParams(self): + query = "SELECT 42, {foo}" + params = {"foo": "foo", "bar": "bar"} + + results 
= list(self.db.ParamQuery(query, params)) + self.assertEqual(results, [[42, "foo"]]) + + def testParamQueryIllegalSequence(self): + with self.assertRaisesRegex(ValueError, "@p1337"): + self.db.ParamQuery("SELECT @p1337", {}) + + def testParamQueryLegalSequence(self): + results = list(self.db.ParamQuery("SELECT '@p', '@q'", {})) + self.assertEqual(results, [["@p", "@q"]]) + + def testParamQueryBraceEscape(self): + results = list(self.db.ParamQuery("SELECT '{{foo}}'", {})) + self.assertEqual(results, [["{foo}",]]) + + ####################################### + # ParamExecute Tests + ####################################### + def testParamExecuteSingleParam(self): + query = """ + INSERT INTO Table(Key) + VALUES ({key}) + """ + params = {"key": "foo"} + + self.db.ParamExecute(query, params) + + ####################################### + # ParamQuerySingle Tests + ####################################### + def testParamQuerySingle(self): + query = "SELECT {str}, {int}" + params = {"str": "foo", "int": 42} + + result = self.db.ParamQuerySingle(query, params) + self.assertEqual(result, ["foo", 42]) + + def testParamQuerySingleEmpty(self): + query = "SELECT {str}, {int} FROM UNNEST([])" + params = {"str": "foo", "int": 42} + + with self.assertRaises(NotFound): + self.db.ParamQuerySingle(query, params) + + def testParamQuerySingleMultiple(self): + query = "SELECT {str}, {int} FROM UNNEST([1, 2])" + params = {"str": "foo", "int": 42} + + with self.assertRaises(ValueError): + self.db.ParamQuerySingle(query, params) + + ####################################### + # ExecutePartitioned Tests + ####################################### + def testExecutePartitioned(self): + self.db.Insert(table="Table", row={"Key": "foo"}) + self.db.Insert(table="Table", row={"Key": "bar"}) + self.db.Insert(table="Table", row={"Key": "baz"}) + + self.db.ExecutePartitioned("DELETE FROM Table AS t WHERE t.Key LIKE 'ba%'") + + results = list(self.db.Query("SELECT t.Key FROM Table AS t")) + self.assertLen(results, 1) + self.assertEqual(results[0], ["foo",]) + + ####################################### + # Insert Tests + ####################################### + def testInsert(self): + self.db.Insert(table="Table", row={"Key": "foo", "Column": "foo@x.com"}) + self.db.Insert(table="Table", row={"Key": "bar", "Column": "bar@x.com"}) + + results = list(self.db.Query("SELECT t.Column FROM Table AS t")) + self.assertCountEqual(results, [["foo@x.com",], ["bar@x.com",]]) + + ####################################### + # Update Tests + ####################################### + def testUpdate(self): + self.db.Insert(table="Table", row={"Key": "foo", "Column": "bar@y.com"}) + self.db.Update(table="Table", row={"Key": "foo", "Column": "qux@y.com"}) + + results = list(self.db.Query("SELECT t.Column FROM Table AS t")) + self.assertEqual(results, [["qux@y.com",]]) + + def testUpdateNotExisting(self): + with self.assertRaises(NotFound): + self.db.Update(table="Table", row={"Key": "foo", "Column": "x@y.com"}) + + ####################################### + # InsertOrUpdate Tests + ####################################### + def testInsertOrUpdate(self): + row = {"Key": "foo"} + + row["Column"] = "bar@example.com" + self.db.InsertOrUpdate(table="Table", row=row) + + row["Column"] = "baz@example.com" + self.db.InsertOrUpdate(table="Table", row=row) + + results = list(self.db.Query("SELECT t.Column FROM Table AS t")) + self.assertEqual(results, [["baz@example.com",]]) + + ####################################### + # Delete Tests + 
####################################### + def testDelete(self): + self.db.InsertOrUpdate(table="Table", row={"Key": "foo"}) + self.db.Delete(table="Table", key=("foo",)) + + results = list(self.db.Query("SELECT t.Key FROM Table AS t")) + self.assertEmpty(results) + + def testDeleteSingle(self): + self.db.Insert(table="Table", row={"Key": "foo"}) + self.db.InsertOrUpdate(table="Table", row={"Key": "bar"}) + self.db.Delete(table="Table", key=("foo",)) + + results = list(self.db.Query("SELECT t.Key FROM Table AS t")) + self.assertEqual(results, [["bar",]]) + + def testDeleteNotExisting(self): + # Should not raise. + self.db.Delete(table="Table", key=("foo",)) + + ####################################### + # DeleteWithPrefix Tests + ####################################### + def testDeleteWithPrefix(self): + self.db.Insert(table="Table", row={"Key": "foo"}) + self.db.Insert(table="Table", row={"Key": "quux"}) + + self.db.Insert(table="Subtable", row={"Key": "foo", "Subkey": "bar"}) + self.db.Insert(table="Subtable", row={"Key": "foo", "Subkey": "baz"}) + self.db.Insert(table="Subtable", row={"Key": "quux", "Subkey": "norf"}) + + self.db.DeleteWithPrefix(table="Subtable", key_prefix=["foo"]) + + results = list(self.db.Query("SELECT t.Key, t.Subkey FROM Subtable AS t")) + self.assertLen(results, 1) + self.assertEqual(results[0], ["quux", "norf"]) + + ####################################### + # Read Tests + ####################################### + def testReadSimple(self): + self.db.Insert(table="Table", row={"Key": "foo", "Column": "foo@x.com"}) + + result = self.db.Read(table="Table", key=("foo",), cols=("Column",)) + self.assertEqual(result, ['foo@x.com']) + + def testReadNotExisting(self): + with self.assertRaises(NotFound): + self.db.Read(table="Table", key=("foo",), cols=("Column",)) + + ####################################### + # ReadSet Tests + ####################################### + def testReadSetEmpty(self): + self.db.Insert(table="Table", row={"Key": "foo", "Column": "foo@x.com"}) + + rows = spanner_lib.KeySet() + results = list(self.db.ReadSet(table="Table", rows=rows, cols=("Column",))) + + self.assertEmpty(results) + + def testReadSetSimple(self): + self.db.Insert(table="Table", row={"Key": "foo", "Column": "foo@x.com"}) + self.db.Insert(table="Table", row={"Key": "bar", "Column": "bar@y.com"}) + self.db.Insert(table="Table", row={"Key": "baz", "Column": "baz@z.com"}) + + keyset = spanner_lib.KeySet(keys=[["foo"], ["bar"]]) + results = list(self.db.ReadSet(table="Table", rows=keyset, cols=("Column",))) + + self.assertIn(["foo@x.com"], results) + self.assertIn(["bar@y.com"], results) + self.assertNotIn(["baz@z.com"], results) + + ####################################### + # Mutate Tests + ####################################### + def testMutateSimple(self): + + def Mutation(txn) -> None: + txn.insert( + table="Table", + columns=("Key",), + values=[("foo",)] + ) + txn.insert( + table="Table", + columns=("Key",), + values=[("bar",)] + ) + + self.db.Mutate(Mutation) + + results = list(self.db.Query("SELECT t.Key FROM Table AS t")) + self.assertCountEqual(results, [["foo",], ["bar",]]) + + def testMutateException(self): + + def Mutation(txn) -> None: + txn.insert( + table="Table", + columns=("Key",), + values=[("foo",)] + ) + raise RuntimeError() + + with self.assertRaises(RuntimeError): + self.db.Mutate(Mutation) + +if __name__ == "__main__": + absltest.main() \ No newline at end of file From d2f0ffe13beb1f4917d536bf0d79b77824fec74b Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: 
Tue, 13 May 2025 18:52:22 +0000 Subject: [PATCH 002/168] First draft of GRR Datastore Spanner schema --- .gitignore | 2 + .../grr_response_server/databases/spanner.sdl | 670 ++++++++++++++++++ .../databases/spanner_setup.sh | 28 + .../databases/spanner_test_lib.py | 4 +- grr/server/setup.py | 1 + 5 files changed, 703 insertions(+), 2 deletions(-) create mode 100644 grr/server/grr_response_server/databases/spanner.sdl create mode 100755 grr/server/grr_response_server/databases/spanner_setup.sh diff --git a/.gitignore b/.gitignore index 66aacf16b..b9e317e2d 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,5 @@ grr/server/grr_response_server/gui/ui/.angular/ docker_config_files/*.pem compose.watch.yaml Dockerfile.client + +grr/server/grr_response_server/databases/spanner_grr.pb diff --git a/grr/server/grr_response_server/databases/spanner.sdl b/grr/server/grr_response_server/databases/spanner.sdl new file mode 100644 index 000000000..7ccbe9fc4 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner.sdl @@ -0,0 +1,670 @@ +CREATE PROTO BUNDLE ( + `google.protobuf.Any`, + `google.protobuf.Timestamp`, + `grr.ApprovalRequest`, + `grr.ApprovalRequest.ApprovalType`, + `grr.ClientLabel`, + `grr.GRRUser`, + `grr.GRRUser.UserType`, + `grr.GUISettings`, + `grr.GUISettings.UIMode`, + `grr.FleetspeakValidationInfo`, + `grr.FleetspeakValidationInfoTag`, + `grr.Password`, + `grr.PathInfo`, + `grr.PathInfo.PathType`, + `grr.UserNotification`, + `grr.UserNotification.State`, + `grr.UserNotification.Type`, + -- RDF magic types. + `grr.AttributedDict`, + `grr.BlobArray`, + `grr.Dict`, + `grr.DataBlob`, + `grr.DataBlob.CompressionType`, + `grr.EmbeddedRDFValue`, + `grr.KeyValue`, + -- Client snapshot types. + `grr.AmazonCloudInstance`, + `grr.ClientInformation`, + `grr.ClientCrash`, + `grr.ClientSnapshot`, + `grr.CloudInstance`, + `grr.CloudInstance.InstanceType`, + `grr.EdrAgent`, + `grr.Filesystem`, + `grr.GoogleCloudInstance`, + `grr.HardwareInfo`, + `grr.Interface`, + `grr.KnowledgeBase`, + `grr.NetworkAddress`, + `grr.NetworkAddress.Family`, + `grr.StartupInfo`, + `grr.StringMapEntry`, + `grr.UnixVolume`, + `grr.User`, + `grr.Volume`, + `grr.Volume.VolumeFileSystemFlagEnum`, + `grr.WindowsVolume`, + `grr.WindowsVolume.WindowsDriveTypeEnum`, + `grr.WindowsVolume.WindowsVolumeAttributeEnum`, + -- Notification reference types. + `grr.ApprovalRequestReference`, + `grr.ClientReference`, + `grr.CronJobReference`, + `grr.FlowLikeObjectReference`, + `grr.FlowLikeObjectReference.ObjectType`, + `grr.FlowReference`, + `grr.HuntReference`, + `grr.ObjectReference`, + `grr.ObjectReference.Type`, + `grr.VfsFileReference`, + -- Flow types. + `grr.CpuSeconds`, + `grr.Flow`, + `grr.Flow.FlowState`, + `grr.FlowIterator`, + `grr.FlowOutputPluginLogEntry`, + `grr.FlowOutputPluginLogEntry.LogEntryType`, + `grr.FlowProcessingRequest`, + `grr.FlowRequest`, + `grr.FlowResponse`, + `grr.FlowResultCount`, + `grr.FlowResultMetadata`, + `grr.FlowRunnerArgs`, + `grr.FlowStatus`, + `grr.FlowStatus.Status`, + `grr.GrrMessage`, + `grr.GrrMessage.AuthorizationState`, + `grr.GrrMessage.Type`, + `grr.GrrStatus`, + `grr.GrrStatus.ReturnedStatus`, + `grr.OutputPluginDescriptor`, + `grr.OutputPluginState`, + `grr.RequestState`, + -- Audit events types. + `grr.APIAuditEntry`, + `grr.APIAuditEntry.Code`, + -- File-related types. 
+ `grr.AuthenticodeSignedData`, + `grr.BlobImageChunkDescriptor`, + `grr.BlobImageDescriptor`, + `grr.BufferReference`, + `grr.Hash`, + `grr.FileFinderResult`, + `grr.PathSpec`, + `grr.PathSpec.ImplementationType`, + `grr.PathSpec.Options`, + `grr.PathSpec.PathType`, + `grr.PathSpec.tsk_fs_attr_type`, + `grr.StatEntry`, + `grr.StatEntry.ExtAttr`, + `grr.StatEntry.RegistryType`, + -- Foreman rules types. + `grr.ForemanClientRule`, + `grr.ForemanClientRule.Type`, + `grr.ForemanClientRuleSet`, + `grr.ForemanClientRuleSet.MatchMode`, + `grr.ForemanCondition`, + `grr.ForemanIntegerClientRule`, + `grr.ForemanIntegerClientRule.ForemanIntegerField`, + `grr.ForemanIntegerClientRule.Operator`, + `grr.ForemanLabelClientRule`, + `grr.ForemanLabelClientRule.MatchMode`, + `grr.ForemanOsClientRule`, + `grr.ForemanRegexClientRule`, + `grr.ForemanRegexClientRule.ForemanStringField`, + `grr.ForemanRule`, + `grr.ForemanRuleAction`, + -- Artifact types. + `grr.Artifact`, + `grr.ArtifactSource`, + `grr.ArtifactSource.SourceType`, + `grr.ArtifactDescriptor`, + `grr.ClientActionResult`, + -- Hunt types. + `grr.Hunt`, + `grr.Hunt.HuntState`, + `grr.Hunt.HuntStateReason`, + `grr.HuntArguments`, + `grr.HuntArguments.HuntType`, + `grr.HuntArgumentsStandard`, + `grr.HuntArgumentsVariable`, + `grr.VariableHuntFlowGroup`, + -- SignedBinary types. + `grr.BlobReference`, + `grr.BlobReferences`, + `grr.SignedBinaryID`, + `grr.SignedBinaryID.BinaryType`, + -- CronJobs types. + `grr.CronJob`, + `grr.CronJobAction`, + `grr.CronJobAction.ActionType`, + `grr.CronJobRun`, + `grr.CronJobRun.CronJobRunStatus`, + `grr.SystemCronAction`, + `grr.HuntCronAction`, + `grr.HuntRunnerArgs`, + -- Message handlers. + `grr.MessageHandlerRequest`, + -- Signed Command types. + `grr.Command`, + `grr.Command.EnvVar`, + `grr.SignedCommand`, + `grr.SignedCommand.OS`, + -- RRG types. 
+ `rrg.Log`, + `rrg.Log.Level`, + `rrg.fs.Path`, + `rrg.startup.Metadata`, + `rrg.startup.Startup`, + `rrg.startup.Version`, + `rrg.action.execute_signed_command.Command`, +); + +CREATE TABLE Labels ( + Label STRING(128) NOT NULL, +) PRIMARY KEY (Label); + +CREATE TABLE Clients ( + ClientId INT64 NOT NULL, + LastSnapshotTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), + LastStartupTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), + LastRRGStartupTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), + LastCrashTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), + FirstSeenTime TIMESTAMP, + LastPingTime TIMESTAMP, + LastForemanTime TIMESTAMP, + Certificate BYTES(MAX), + FleetspeakEnabled BOOL, + FleetspeakValidationInfo `grr.FleetspeakValidationInfo`, +) PRIMARY KEY (ClientId); + +CREATE TABLE ClientSnapshots( + ClientId INT64 NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Snapshot `grr.ClientSnapshot` NOT NULL, +) PRIMARY KEY (ClientId, CreationTime), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE TABLE ClientStartups( + ClientId INT64 NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Startup `grr.StartupInfo` NOT NULL, +) PRIMARY KEY (ClientId, CreationTime), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE TABLE ClientRRGStartups( + ClientId INT64 NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Startup `rrg.startup.Startup` NOT NULL, +) PRIMARY KEY (ClientId, CreationTime), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE TABLE ClientCrashes( + ClientId INT64 NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Crash `grr.ClientCrash` NOT NULL, +) PRIMARY KEY (ClientId, CreationTime), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE TABLE Users ( + Username STRING(256) NOT NULL, + Email STRING(256), + Password `grr.Password`, + Type `grr.GRRUser.UserType`, + CanaryMode BOOL, + UiMode `grr.GUISettings.UIMode`, +) PRIMARY KEY (Username); + +CREATE TABLE UserNotifications( + Username STRING(256) NOT NULL, + NotificationId INT64 NOT NULL, + Type `grr.UserNotification.Type` NOT NULL, + State `grr.UserNotification.State` NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Message STRING(MAX), + Reference `grr.ObjectReference` NOT NULL, +) PRIMARY KEY (Username, NotificationId), + INTERLEAVE IN PARENT Users ON DELETE CASCADE; + +CREATE TABLE ApprovalRequests( + Requestor STRING(256) NOT NULL, + ApprovalId INT64 NOT NULL, + SubjectClientId INT64, + SubjectHuntId INT64, + SubjectCronJobId STRING(100), + Reason STRING(MAX) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + ExpirationTime TIMESTAMP NOT NULL, + NotifiedUsers ARRAY NOT NULL, + CcEmails ARRAY NOT NULL, + CONSTRAINT fk_approval_request_requestor_username + FOREIGN KEY (Requestor) + REFERENCES Users(Username), + + -- TODO: Foreign keys on ARRAY columns are not supported, can we put an alternative in place? 
+ -- CONSTRAINT fk_approval_request_notified_users_usernames + -- FOREIGN KEY UNNEST(NotifiedUsers) (NotifiedUsers) + -- REFERENCES Users(Username), + + CONSTRAINT ck_subject_id_valid + CHECK ((IF(SubjectClientId IS NOT NULL, 1, 0) + + IF(SubjectHuntId IS NOT NULL, 1, 0) + + IF(SubjectCronJobId IS NOT NULL, 1, 0)) = 1), +) PRIMARY KEY (Requestor, ApprovalId); + +CREATE INDEX ApprovalRequestsByRequestor + ON ApprovalRequests(Requestor); + +CREATE INDEX ApprovalRequestsByRequestorSubjectClientId + ON ApprovalRequests(Requestor, SubjectClientId); + +CREATE INDEX ApprovalRequestsByRequestorSubjectHuntId + ON ApprovalRequests(Requestor, SubjectHuntId); + +CREATE INDEX ApprovalRequestsByRequestorSubjectCronJobId + ON ApprovalRequests(Requestor, SubjectCronJobId); + +CREATE TABLE ApprovalGrants( + Requestor STRING(256) NOT NULL, + ApprovalId INT64 NOT NULL, + Grantor STRING(256) NOT NULL, + GrantId INT64 NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + + CONSTRAINT fk_approval_grant_grantor_username + FOREIGN KEY (Grantor) + REFERENCES Users(Username), +) PRIMARY KEY (Requestor, ApprovalId, Grantor, GrantId), + INTERLEAVE IN PARENT ApprovalRequests ON DELETE CASCADE; + +CREATE INDEX ApprovalGrantsByGrantor + ON ApprovalGrants(Grantor); + +CREATE TABLE ClientLabels( + ClientId INT64 NOT NULL, + Owner STRING(256) NOT NULL, + Label STRING(128) NOT NULL, + + CONSTRAINT fk_client_label_owner_username + FOREIGN KEY (Owner) + REFERENCES Users(Username), + + CONSTRAINT fk_client_label_label_label + FOREIGN KEY (Label) + REFERENCES Labels(Label), +) PRIMARY KEY (ClientId, Owner, Label), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE TABLE ClientKeywords( + ClientId INT64 NOT NULL, + Keyword STRING(256) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + + CONSTRAINT ck_client_keyword_not_empty + CHECK (Keyword <> ''), +) PRIMARY KEY (ClientId, Keyword), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE INDEX ClientKeywordsByKeywordCreationTime + ON ClientKeywords(Keyword, CreationTime); + +CREATE TABLE Flows( + ClientId INT64 NOT NULL, + FlowId INT64 NOT NULL, + ParentFlowId INT64, + ParentHuntId INT64, + LongFlowId STRING(256) NOT NULL, + Creator STRING(256) NOT NULL, + Name STRING(256) NOT NULL, + State `grr.Flow.FlowState` NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + UpdateTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Crash `grr.ClientCrash`, + NextRequestToProcess INT64, + ProcessingWorker STRING(MAX), + ProcessingStartTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), + ProcessingEndTime TIMESTAMP, + Flow `grr.Flow` NOT NULL, + ReplyCount INT64 NOT NULL, + NetworkBytesSent INT64 NOT NULL, + UserCpuTimeUsed FLOAT64 NOT NULL, + SystemCpuTimeUsed FLOAT64 NOT NULL, + + CONSTRAINT fk_flow_client_id_parent_id_flow + FOREIGN KEY (ClientId, ParentFlowId) + REFERENCES Flows(ClientId, FlowId), +) PRIMARY KEY (ClientId, FlowId), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE INDEX FlowsByParentHuntIdFlowIdUpdateTime + ON Flows(ParentHuntId, FlowId, UpdateTime) + STORING (State); + +CREATE INDEX FlowsByParentHuntIdFlowIdState + ON Flows(ParentHuntId, FlowId, State) + STORING (ReplyCount, NetworkBytesSent, UserCpuTimeUsed, SystemCpuTimeUsed); + +CREATE TABLE FlowResults( + ClientId INT64 NOT NULL, + FlowId INT64 NOT NULL, + HuntId INT64, + CreationTime TIMESTAMP NOT NULL, + Payload `google.protobuf.Any`, + RdfType STRING(MAX), + Tag 
STRING(MAX), +) PRIMARY KEY (ClientId, FlowId, CreationTime), + INTERLEAVE IN PARENT Flows ON DELETE CASCADE; + +CREATE INDEX FlowResultsByHuntIdCreationTime + ON FlowResults(HuntId, CreationTime); +CREATE INDEX FlowResultsByHuntIdFlowIdCreationTime + ON FlowResults(HuntId, FlowId, CreationTime); +CREATE INDEX FlowResultsByHuntIdFlowIdRdfTypeTagCreationTime + ON FlowResults(HuntId, FlowId, RdfType, Tag, CreationTime); + +CREATE TABLE FlowErrors( + ClientId INT64 NOT NULL, + FlowId INT64 NOT NULL, + HuntId INT64 NOT NULL, + CreationTime TIMESTAMP NOT NULL, + Payload `google.protobuf.Any`, + RdfType STRING(MAX), + Tag STRING(MAX), +) PRIMARY KEY (ClientId, FlowId, CreationTime), + INTERLEAVE IN PARENT Flows ON DELETE CASCADE; + +CREATE INDEX FlowErrorsByHuntIdFlowIdCreationTime + ON FlowErrors(HuntId, FlowId, CreationTime); +CREATE INDEX FlowErrorsByHuntIdFlowIdRdfTypeTagCreationTime + ON FlowErrors(HuntId, FlowId, RdfType, Tag, CreationTime); + +CREATE TABLE FlowRequests( + ClientId INT64 NOT NULL, + FlowId INT64 NOT NULL, + RequestId INT64 NOT NULL, + NeedsProcessing BOOL, + ExpectedResponseCount INT64, + NextResponseId INT64, + CallbackState STRING(256), + Payload `grr.FlowRequest` NOT NULL, + StartTime TIMESTAMP, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), +) PRIMARY KEY (ClientId, FlowId, RequestId), + INTERLEAVE IN PARENT Flows ON DELETE CASCADE, + ROW DELETION POLICY (OLDER_THAN(CreationTime, INTERVAL 84 DAY)); + +CREATE TABLE FlowResponses( + ClientId INT64 NOT NULL, + FlowId INT64 NOT NULL, + RequestId INT64 NOT NULL, + ResponseId INT64 NOT NULL, + Response `grr.FlowResponse`, + Status `grr.FlowStatus`, + Iterator `grr.FlowIterator`, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + + CONSTRAINT ck_flow_response_has_payload + CHECK ((IF(Response IS NOT NULL, 1, 0) + + IF(Status IS NOT NULL, 1, 0) + + IF(Iterator IS NOT NULL, 1, 0)) = 1), +) PRIMARY KEY (ClientId, FlowId, RequestId, ResponseId), + INTERLEAVE IN PARENT FlowRequests ON DELETE CASCADE; + +CREATE TABLE FlowLogEntries( + ClientId INT64 NOT NULL, + FlowId INT64 NOT NULL, + HuntId INT64, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Message STRING(MAX) NOT NULL, +) PRIMARY KEY (ClientId, FlowId, CreationTime), + INTERLEAVE IN PARENT Flows ON DELETE CASCADE; + +CREATE INDEX FlowLogEntriesByHuntIdCreationTime + ON FlowLogEntries(HuntId, CreationTime); + +CREATE TABLE FlowRRGLogs( + ClientId INT64 NOT NULL, + FlowId INT64 NOT NULL, + RequestId INT64 NOT NULL, + ResponseId INT64 NOT NULL, + LogLevel `rrg.Log.Level` NOT NULL, + LogTime TIMESTAMP NOT NULL, + LogMessage STRING(MAX) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), +) PRIMARY KEY (ClientId, FlowId, RequestId, ResponseId), + INTERLEAVE IN PARENT Flows ON DELETE CASCADE; + +CREATE TABLE FlowOutputPluginLogEntries( + ClientId INT64 NOT NULL, + FlowId INT64 NOT NULL, + OutputPluginId INT64 NOT NULL, + HuntId INT64, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Type `grr.FlowOutputPluginLogEntry.LogEntryType` NOT NULL, + Message STRING(MAX) NOT NULL, +) PRIMARY KEY (ClientId, FlowId, OutputPluginId, CreationTime), + INTERLEAVE IN PARENT Flows ON DELETE CASCADE; + +CREATE INDEX FlowOutputPluginLogEntriesByHuntIdCreationTime + ON FlowOutputPluginLogEntries(HuntId, CreationTime); + +CREATE TABLE ScheduledFlows( + ClientId INT64 NOT NULL, + Creator STRING(256) NOT NULL, + ScheduledFlowId INT64 NOT NULL, + FlowName 
STRING(256) NOT NULL, + FlowArgs `google.protobuf.Any` NOT NULL, + RunnerArgs `grr.FlowRunnerArgs` NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Error STRING(MAX), + + CONSTRAINT fk_creator_users_username + FOREIGN KEY (Creator) + REFERENCES Users(Username), +) PRIMARY KEY (ClientId, Creator, ScheduledFlowId), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE INDEX ScheduledFlowsByCreator + ON ScheduledFlows(Creator); + +CREATE TABLE Paths( + ClientId INT64 NOT NULL, + Type `grr.PathInfo.PathType` NOT NULL, + Path BYTES(MAX) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + LastFileStatTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), + LastFileHashTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), + IsDir BOOL NOT NULL, + Depth INT64 NOT NULL, +) PRIMARY KEY (ClientId, Type, Path), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE INDEX PathsByClientIdTypePathDepth + ON Paths(ClientId, Type, Path, Depth); + +CREATE TABLE PathFileStats( + ClientId INT64 NOT NULL, + Type `grr.PathInfo.PathType` NOT NULL, + Path BYTES(MAX) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Stat `grr.StatEntry` NOT NULL, +) PRIMARY KEY (ClientId, Type, Path, CreationTime), + INTERLEAVE IN PARENT Paths ON DELETE CASCADE; + +CREATE TABLE PathFileHashes( + ClientId INT64 NOT NULL, + Type `grr.PathInfo.PathType` NOT NULL, + Path BYTES(MAX) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + FileHash `grr.Hash` NOT NULL, +) PRIMARY KEY (ClientId, Type, Path, CreationTime), + INTERLEAVE IN PARENT Paths ON DELETE CASCADE; + +CREATE TABLE HashBlobReferences( + -- 32 bytes is enough for SHA256 used for hash ids. + HashId BYTES(32) NOT NULL, + Offset INT64 NOT NULL, + -- 32 bytes is enough for SHA256 used for blob ids. 
+ BlobId BYTES(32) NOT NULL, + Size INT64 NOT NULL, + + CONSTRAINT ck_hash_id_valid CHECK (BYTE_LENGTH(HashId) = 32), + CONSTRAINT ck_blob_id_valid CHECK (BYTE_LENGTH(BlobId) = 32), +) PRIMARY KEY (HashId, Offset); + +CREATE TABLE ApiAuditEntry ( + Username STRING(256) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + HttpRequestPath String(MAX) NOT NULL, + RouterMethodName String(256) NOT NULL, + ResponseCode `grr.APIAuditEntry.Code` NOT NULL, +) PRIMARY KEY (Username, CreationTime); + +CREATE TABLE Artifacts ( + Name STRING(256) NOT NULL, + Platforms ARRAY NOT NULL, + Payload `grr.Artifact` NOT NULL, +) PRIMARY KEY (Name); + +CREATE TABLE YaraSignatureReferences( + BlobId BYTES(32) NOT NULL, + Creator STRING(256) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + + CONSTRAINT ck_yara_signature_reference_blob_id_valid + CHECK (BYTE_LENGTH(BlobId) = 32), + CONSTRAINT fk_yara_signature_reference_creator_username + FOREIGN KEY (Creator) + REFERENCES Users(Username), +) PRIMARY KEY (BlobId); + +CREATE TABLE Hunts( + HuntId INT64 NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + LastUpdateTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Creator STRING(256) NOT NULL, + DurationMicros INT64 NOT NULL, + Description STRING(MAX) NOT NULL, + ClientRate FLOAT64 NOT NULL, + ClientLimit INT64 NOT NULL, + State `grr.Hunt.HuntState` NOT NULL, + StateReason `grr.Hunt.HuntStateReason` NOT NULL, + StateComment STRING(MAX) NOT NULL, + InitStartTime TIMESTAMP, + LastStartTime TIMESTAMP, + ClientCountAtStartTime INT64 NOT NULL, + Hunt `grr.Hunt` NOT NULL, + + CONSTRAINT fk_hunt_creator_username + FOREIGN KEY (Creator) + REFERENCES Users(Username), +) PRIMARY KEY (HuntId); + +ALTER TABLE FlowLogEntries ADD CONSTRAINT fk_flow_log_entry_hunt_id_hunt + FOREIGN KEY (HuntId) + REFERENCES Hunts(HuntId); +ALTER TABLE FlowOutputPluginLogEntries ADD CONSTRAINT fk_flow_output_plugin_log_entry_hunt_id_hunt + FOREIGN KEY (HuntId) + REFERENCES Hunts(HuntId); +ALTER TABLE Flows ADD CONSTRAINT fk_flow_parent_hunt_id_hunt + FOREIGN KEY (ParentHuntId) + REFERENCES Hunts(HuntId); + +CREATE INDEX HuntsByCreationTime + ON Hunts(CreationTime DESC); + +CREATE INDEX HuntsByCreator ON Hunts(Creator); + +CREATE TABLE HuntOutputPlugins( + HuntId INT64 NOT NULL, + OutputPluginId INT64 NOT NULL, + Name STRING(256) NOT NULL, + Args `google.protobuf.Any`, + State `google.protobuf.Any` NOT NULL, +) PRIMARY KEY (HuntId, OutputPluginId), + INTERLEAVE IN PARENT Hunts ON DELETE CASCADE; + +CREATE TABLE ForemanRules( + HuntId INT64 NOT NULL, + ExpirationTime TIMESTAMP NOT NULL, + Payload `grr.ForemanCondition`, + + CONSTRAINT fk_foreman_rule_hunt_id_hunt + FOREIGN KEY (HuntId) + REFERENCES Hunts(HuntId), +) PRIMARY KEY (HuntId); + +CREATE INDEX ForemanRulesByExpirationTime + ON ForemanRules(ExpirationTime); + +CREATE TABLE SignedBinaries ( + Type `grr.SignedBinaryID.BinaryType` NOT NULL, + Path STRING(MAX) NOT NULL, + BlobReferences `grr.BlobReferences` NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), +) PRIMARY KEY (Type, Path); + +CREATE TABLE CronJobs ( + JobId STRING(256) NOT NULL, + Job `grr.CronJob` NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Enabled BOOL NOT NULL, + CurrentRunId STRING(256), + ForcedRunRequested BOOL, + LastRunStatus `grr.CronJobRun.CronJobRunStatus`, + LastRunTime TIMESTAMP, + State `grr.AttributedDict`, + LeaseEndTime 
TIMESTAMP, + LeaseOwner STRING(256), +) PRIMARY KEY (JobId); + +CREATE TABLE CronJobRuns ( + JobId STRING(256) NOT NULL, + RunId STRING(256) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + FinishTime TIMESTAMP, + Status `grr.CronJobRun.CronJobRunStatus` NOT NULL, + LogMessage STRING(MAX), + Backtrace STRING(MAX), + Payload `grr.CronJobRun` NOT NULL, +) PRIMARY KEY (JobId, RunId), + INTERLEAVE IN PARENT CronJobs ON DELETE CASCADE; + +ALTER TABLE CronJobs ADD CONSTRAINT fk_cron_job_run_job_id_current_run_id_cron_job_run + FOREIGN KEY (JobId, CurrentRunId) + REFERENCES CronJobRuns(JobId, RunId); + +ALTER TABLE ApprovalRequests + ADD CONSTRAINT fk_approval_request_subject_client_id_client + FOREIGN KEY (SubjectClientId) + REFERENCES Clients(ClientId); + +ALTER TABLE ApprovalRequests + ADD CONSTRAINT fk_approval_request_subject_hunt_id_hunt + FOREIGN KEY (SubjectHuntId) + REFERENCES Hunts(HuntId) + ON DELETE CASCADE; + +ALTER TABLE ApprovalRequests + ADD CONSTRAINT fk_approval_request_subject_cron_job_id_cron_job + FOREIGN KEY (SubjectCronJobId) + REFERENCES CronJobs(JobId) + ON DELETE CASCADE; + +CREATE TABLE BlobEncryptionKeys( + -- A unique identifier of the blob. + BlobId BYTES(32) NOT NULL, + -- A timestamp at which the association was created. + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + -- A name of the key (to retrieve the key from Keystore). + KeyName STRING(256) NOT NULL, +) PRIMARY KEY (BlobId, CreationTime); + +CREATE TABLE SignedCommands( + Id STRING(128) NOT NULL, + OperatingSystem `grr.SignedCommand.OS` NOT NULL, + Ed25519Signature BYTES(64) NOT NULL, + Command `rrg.action.execute_signed_command.Command` NOT NULL, +) PRIMARY KEY (Id, OperatingSystem); diff --git a/grr/server/grr_response_server/databases/spanner_setup.sh b/grr/server/grr_response_server/databases/spanner_setup.sh new file mode 100755 index 000000000..22a3a6406 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_setup.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# Script to prepare the GRR protobufs and database on Spanner + +echo "1/3 : Bundling GRR protos in spanner_grr.pb..." +if [ ! -f ./spanner_grr.pb ]; then + protoc -I=../../../proto -I=/usr/include --include_imports --descriptor_set_out=./spanner_grr.pb \ + ../../../proto/grr_response_proto/analysis.proto \ + ../../../proto/grr_response_proto/artifact.proto \ + ../../../proto/grr_response_proto/flows.proto \ + ../../../proto/grr_response_proto/hunts.proto \ + ../../../proto/grr_response_proto/jobs.proto \ + ../../../proto/grr_response_proto/knowledge_base.proto \ + ../../../proto/grr_response_proto/objects.proto \ + ../../../proto/grr_response_proto/output_plugin.proto \ + ../../../proto/grr_response_proto/signed_commands.proto \ + ../../../proto/grr_response_proto/sysinfo.proto \ + ../../../proto/grr_response_proto/user.proto \ + ../../../proto/grr_response_proto/rrg.proto \ + ../../../proto/grr_response_proto/rrg/fs.proto \ + ../../../proto/grr_response_proto/rrg/startup.proto \ + ../../../proto/grr_response_proto/rrg/action/execute_signed_command.proto +fi + +echo "2/3 : Creating GRR database on Spanner..." +gcloud spanner databases create ${SPANNER_GRR_DATABASE} --instance ${SPANNER_GRR_INSTANCE} + +echo "3/3 : Creating tables ..." 
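+# The DDL update references proto types from the PROTO BUNDLE in spanner.sdl,
+# so it needs the descriptor set produced in step 1 above.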
+gcloud spanner databases ddl update ${SPANNER_GRR_DATABASE} --instance=${SPANNER_GRR_INSTANCE} --ddl-file=spanner.sdl --proto-descriptors-file=spanner_grr.pb \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_test_lib.py b/grr/server/grr_response_server/databases/spanner_test_lib.py index 4ee8fd174..590b9c92b 100644 --- a/grr/server/grr_response_server/databases/spanner_test_lib.py +++ b/grr/server/grr_response_server/databases/spanner_test_lib.py @@ -45,8 +45,8 @@ def Init(sdl_path: str) -> None: raise AssertionError("Spanner test library already initialized") project_id = _GetEnvironOrSkip("PROJECT_ID") - instance_id = _GetEnvironOrSkip("SPANNER_TEST_INSTANCE") - database_id = _GetEnvironOrSkip("SPANNER_TEST_DATABASE") + instance_id = _GetEnvironOrSkip("SPANNER_GRR_INSTANCE") + database_id = _GetEnvironOrSkip("SPANNER_GRR_DATABASE") spanner_client = Client(project_id) database_admin_api = spanner_client.database_admin_api diff --git a/grr/server/setup.py b/grr/server/setup.py index 16a674003..207f11722 100644 --- a/grr/server/setup.py +++ b/grr/server/setup.py @@ -179,6 +179,7 @@ def make_release_tree(self, base_dir, files): install_requires=[ "google-api-python-client==1.12.11", "google-auth==2.23.3", + "google-cloud-spanner==3.54.0", "google-cloud-storage==2.13.0", "google-cloud-pubsub==2.18.4", "grr-api-client==%s" % VERSION.get("Version", "packagedepends"), From f5dd3c590902d03ccbf05f855059f9ea98d92b17 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Wed, 14 May 2025 19:21:05 +0000 Subject: [PATCH 003/168] Drop in datastore stubs --- .../grr_response_server/databases/spanner.py | 53 ++ .../databases/spanner_artifacts.py | 67 +++ .../databases/spanner_blob_keys.py | 38 ++ .../databases/spanner_blob_references.py | 39 ++ .../databases/spanner_clients.py | 316 ++++++++++ .../databases/spanner_cron_jobs.py | 209 +++++++ .../databases/spanner_events.py | 57 ++ .../databases/spanner_flows.py | 552 ++++++++++++++++++ .../databases/spanner_foreman_rules.py | 41 ++ .../databases/spanner_hunts.py | 297 ++++++++++ .../databases/spanner_paths.py | 118 ++++ .../databases/spanner_signed_binaries.py | 73 +++ .../databases/spanner_signed_commands.py | 53 ++ .../databases/spanner_test_lib.py | 36 +- .../databases/spanner_users.py | 153 +++++ .../databases/spanner_utils_test.py | 2 +- .../databases/spanner_yara.py | 55 ++ .../databases/spanner_yara_test.py | 23 + 18 files changed, 2179 insertions(+), 3 deletions(-) create mode 100644 grr/server/grr_response_server/databases/spanner.py create mode 100644 grr/server/grr_response_server/databases/spanner_artifacts.py create mode 100644 grr/server/grr_response_server/databases/spanner_blob_keys.py create mode 100644 grr/server/grr_response_server/databases/spanner_blob_references.py create mode 100644 grr/server/grr_response_server/databases/spanner_clients.py create mode 100644 grr/server/grr_response_server/databases/spanner_cron_jobs.py create mode 100644 grr/server/grr_response_server/databases/spanner_events.py create mode 100644 grr/server/grr_response_server/databases/spanner_flows.py create mode 100644 grr/server/grr_response_server/databases/spanner_foreman_rules.py create mode 100644 grr/server/grr_response_server/databases/spanner_hunts.py create mode 100644 grr/server/grr_response_server/databases/spanner_paths.py create mode 100644 grr/server/grr_response_server/databases/spanner_signed_binaries.py create mode 100644 grr/server/grr_response_server/databases/spanner_signed_commands.py create mode 
100644 grr/server/grr_response_server/databases/spanner_users.py create mode 100644 grr/server/grr_response_server/databases/spanner_yara.py create mode 100644 grr/server/grr_response_server/databases/spanner_yara_test.py diff --git a/grr/server/grr_response_server/databases/spanner.py b/grr/server/grr_response_server/databases/spanner.py new file mode 100644 index 000000000..3fc9d7cdf --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner.py @@ -0,0 +1,53 @@ +# Imports the Google Cloud Client Library. +from google.cloud import spanner + +from grr_response_server.databases import db as abstract_db +from grr_response_server.databases import spanner_artifacts +from grr_response_server.databases import spanner_blob_keys +from grr_response_server.databases import spanner_clients +from grr_response_server.databases import spanner_cronjobs +from grr_response_server.databases import spanner_events +from grr_response_server.databases import spanner_flows +from grr_response_server.databases import spanner_foreman_rules +from grr_response_server.databases import spanner_hunts +from grr_response_server.databases import spanner_paths +from grr_response_server.databases import spanner_signed_binaries +from grr_response_server.databases import spanner_signed_commands +from grr_response_server.databases import spanner_users +from grr_response_server.databases import spanner_utils +from grr_response_server.databases import spanner_yara +from grr_response_server.models import blobs as models_blobs +from grr_response_server.rdfvalues import objects as rdf_objects + +class SpannerDB( + spanner_artifacts.ArtifactsMixin, + spanner_blob_keys.BlobKeysMixin, + spanner_blob_references.BlobReferencesMixin, + spanner_clients.ClientsMixin, + spanner_signed_commands.SignedCommandsMixin, + spanner_cron_jobs.CronJobsMixin, + spanner_events.EventsMixin, + spanner_flows.FlowsMixin, + spanner_foreman_rules.ForemanRulesMixin, + spanner_hunts.HuntsMixin, + spanner_paths.PathsMixin, + spanner_signed_binaries.SignedBinariesMixin, + spanner_users.UsersMixin, + spanner_yara.YaraMixin, + abstract_db.Database, +): + """A Spanner implementation of the GRR database.""" + + def __init__(self, db: spanner_utils.Database) -> None: + """Initializes the database.""" + self.db = db + self._write_rows_batch_size = 10000 + + def Now(self) -> rdfvalue.RDFDatetime: + """Retrieves current time as reported by the database.""" + (timestamp,) = self.db.QuerySingle("SELECT CURRENT_TIMESTAMP()") + return rdfvalue.RDFDatetime.FromDatetime(timestamp) + + def MinTimestamp(self) -> rdfvalue.RDFDatetime: + """Returns minimal timestamp allowed by the DB.""" + return rdfvalue.RDFDatetime.FromSecondsSinceEpoch(0) \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_artifacts.py b/grr/server/grr_response_server/databases/spanner_artifacts.py new file mode 100644 index 000000000..9692a05ef --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_artifacts.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +"""A module with artifacts methods of the Spanner backend.""" + +from typing import Optional, Sequence + +from grr_response_proto import artifact_pb2 +from grr_response_server.databases import db +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils + + +class ArtifactsMixin: + """A Spanner database mixin with implementation of artifacts.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteArtifact(self, 
artifact: artifact_pb2.Artifact) -> None: + """Writes new artifact to the database. + + Args: + artifact: Artifact to be stored. + + Raises: + DuplicatedArtifactError: when the artifact already exists. + """ + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadArtifact(self, name: str) -> Optional[artifact_pb2.Artifact]: + """Looks up an artifact with given name from the database. + + Args: + name: Name of the artifact to be read. + + Returns: + The artifact object read from the database. + + Raises: + UnknownArtifactError: when the artifact does not exist. + """ + + return None + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadAllArtifacts(self) -> Sequence[artifact_pb2.Artifact]: + """Lists all artifacts that are stored in the database.""" + result = [] + + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteArtifact(self, name: str) -> None: + """Deletes an artifact with given name from the database. + + Args: + name: Name of the artifact to be deleted. + + Raises: + UnknownArtifactError when the artifact does not exist. + """ + diff --git a/grr/server/grr_response_server/databases/spanner_blob_keys.py b/grr/server/grr_response_server/databases/spanner_blob_keys.py new file mode 100644 index 000000000..3131237fe --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_blob_keys.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +"""Blob encryption key methods of Spanner database implementation.""" + +from typing import Collection, Dict, Optional + +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils +from grr_response_server.models import blobs as models_blobs + + +class BlobKeysMixin: + """A Spanner mixin with implementation of blob encryption keys methods.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteBlobEncryptionKeys( + self, + key_names: Dict[models_blobs.BlobID, str], + ) -> None: + """Associates the specified blobs with the given encryption keys.""" + # A special case for empty list of blob identifiers to avoid issues with an + # empty mutation. + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadBlobEncryptionKeys( + self, + blob_ids: Collection[models_blobs.BlobID], + ) -> Dict[models_blobs.BlobID, Optional[str]]: + """Retrieves encryption keys associated with blobs.""" + # A special case for empty list of blob identifiers to avoid syntax errors + # in the query below. 
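+    # One possible shape of this method (illustrative only; the exact query
+    # helper signature on `spanner_utils.Database` may differ): after the
+    # empty-input guard, key names could be read from the BlobEncryptionKeys
+    # table defined in spanner.sdl, letting rows with a later CreationTime win.
+    #
+    #   if not blob_ids:
+    #     return {}
+    #
+    #   query = """
+    #     SELECT k.BlobId, k.KeyName
+    #     FROM BlobEncryptionKeys AS k
+    #     WHERE k.BlobId IN UNNEST({blob_ids})
+    #     ORDER BY k.CreationTime ASC
+    #   """
+    #   params = {"blob_ids": [bytes(blob_id) for blob_id in blob_ids]}
+    #
+    #   result = {blob_id: None for blob_id in blob_ids}
+    #   for blob_id_bytes, key_name in self.db.Query(query, params):
+    #     result[models_blobs.BlobID(blob_id_bytes)] = key_name
+    #   return result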
+ + + return {} diff --git a/grr/server/grr_response_server/databases/spanner_blob_references.py b/grr/server/grr_response_server/databases/spanner_blob_references.py new file mode 100644 index 000000000..447a5bf5d --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_blob_references.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +"""A library with blob references methods of Spanner database implementation.""" + +from typing import Collection, Mapping, Optional + +from grr_response_proto import objects_pb2 +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils +from grr_response_server.rdfvalues import objects as rdf_objects + + +class BlobReferencesMixin: + """A Spanner database mixin with implementation of blob references methods.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteHashBlobReferences( + self, + references_by_hash: Mapping[ + rdf_objects.SHA256HashID, Collection[objects_pb2.BlobReference] + ], + ) -> None: + """Writes blob references for a given set of hashes.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHashBlobReferences( + self, hashes: Collection[rdf_objects.SHA256HashID] + ) -> Mapping[ + rdf_objects.SHA256HashID, Optional[Collection[objects_pb2.BlobReference]] + ]: + """Reads blob references of a given set of hashes.""" + + result = {} + + return result diff --git a/grr/server/grr_response_server/databases/spanner_clients.py b/grr/server/grr_response_server/databases/spanner_clients.py new file mode 100644 index 000000000..6958a6a09 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_clients.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python +"""A module with client methods of the Spanner database implementation.""" + +import datetime +import logging +import re +from typing import Collection, Iterator, Mapping, Optional, Sequence, Tuple + +from grr_response_core.lib import rdfvalue +from grr_response_core.lib.util import iterator +from grr_response_proto import jobs_pb2 +from grr_response_proto import objects_pb2 +# Aliasing the import since the name db clashes with the db annotation. +from grr_response_server.databases import db as db_lib +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils +from grr_response_server.models import clients as models_clients +from rrg.proto.rrg import startup_pb2 as rrg_startup_pb2 + + +class ClientsMixin: + """A Spanner database mixin with implementation of client methods.""" + + db: spanner_utils.Database + + # TODO(b/196379916): Implement client methods. + + @db_utils.CallLogged + @db_utils.CallAccounted + def MultiWriteClientMetadata( + self, + client_ids: Collection[str], + first_seen: Optional[rdfvalue.RDFDatetime] = None, + last_ping: Optional[rdfvalue.RDFDatetime] = None, + last_foreman: Optional[rdfvalue.RDFDatetime] = None, + fleetspeak_validation_info: Optional[Mapping[str, str]] = None, + ) -> None: + """Writes metadata about the clients.""" + # Early return to avoid generating empty mutation. 
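+    # One possible shape of this method (illustrative only): build one
+    # insert-or-update row per client. The column names used here
+    # (FirstSeenTime, LastPingTime, LastForemanTime) and the ClientId encoding
+    # are placeholders; the authoritative schema lives in spanner.sdl. A real
+    # implementation would likely batch the rows into a single mutation.
+    #
+    #   if not client_ids:
+    #     return
+    #
+    #   for client_id in client_ids:
+    #     row = {"ClientId": IntClientID(client_id)}
+    #     if first_seen is not None:
+    #       row["FirstSeenTime"] = first_seen.AsDatetime()
+    #     if last_ping is not None:
+    #       row["LastPingTime"] = last_ping.AsDatetime()
+    #     if last_foreman is not None:
+    #       row["LastForemanTime"] = last_foreman.AsDatetime()
+    #     self.db.InsertOrUpdate(
+    #         table="Clients", row=row, txn_tag="MultiWriteClientMetadata")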
+ + + @db_utils.CallLogged + @db_utils.CallAccounted + def MultiReadClientMetadata( + self, + client_ids: Collection[str], + ) -> Mapping[str, objects_pb2.ClientMetadata]: + """Reads ClientMetadata records for a list of clients.""" + result = {} + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def MultiAddClientLabels( + self, + client_ids: Collection[str], + owner: str, + labels: Collection[str], + ) -> None: + """Attaches user labels to the specified clients.""" + # Early return to avoid generating empty mutation. + + + @db_utils.CallLogged + @db_utils.CallAccounted + def MultiReadClientLabels( + self, + client_ids: Collection[str], + ) -> Mapping[str, Collection[objects_pb2.ClientLabel]]: + """Reads the user labels for a list of clients.""" + result = {client_id: [] for client_id in client_ids} + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def RemoveClientLabels( + self, + client_id: str, + owner: str, + labels: Collection[str], + ) -> None: + """Removes a list of user labels from a given client.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadAllClientLabels(self) -> Collection[str]: + """Lists all client labels known to the system.""" + result = [] + + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteClientSnapshot(self, snapshot: objects_pb2.ClientSnapshot) -> None: + """Writes new client snapshot.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def MultiReadClientSnapshot( + self, + client_ids: Collection[str], + ) -> Mapping[str, Optional[objects_pb2.ClientSnapshot]]: + """Reads the latest client snapshots for a list of clients.""" + # Unfortunately, Spanner has troubles with handling `UNNEST` expressions if + # the given array is empty, so we just handle such case separately. 
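+    # One possible shape (illustrative only): the separate handling mentioned
+    # above is just an early return, after which the remaining ids can be
+    # passed through UNNEST. Table and column names are placeholders for the
+    # real snapshot schema in spanner.sdl.
+    #
+    #   if not client_ids:
+    #     return {}
+    #
+    #   result = {client_id: None for client_id in client_ids}
+    #   params = {"ids": [IntClientID(client_id) for client_id in client_ids]}
+    #   ... SELECT ... FROM ClientSnapshots WHERE ClientId IN UNNEST({ids}) ...
+    #   return result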
+ + return {} + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadClientSnapshotHistory( + self, + client_id: str, + timerange: Optional[ + Tuple[Optional[rdfvalue.RDFDatetime], Optional[rdfvalue.RDFDatetime]] + ] = None, + ) -> Sequence[objects_pb2.ClientSnapshot]: + """Reads the full history for a particular client.""" + result = [] + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteClientStartupInfo( + self, + client_id: str, + startup: jobs_pb2.StartupInfo, + ) -> None: + """Writes a new client startup record.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteClientRRGStartup( + self, + client_id: str, + startup: rrg_startup_pb2.Startup, + ) -> None: + """Writes a new RRG startup entry to the database.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadClientRRGStartup( + self, + client_id: str, + ) -> Optional[rrg_startup_pb2.Startup]: + """Reads the latest RRG startup entry for the given client.""" + + return None + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadClientStartupInfo( + self, + client_id: str, + ) -> Optional[jobs_pb2.StartupInfo]: + """Reads the latest client startup record for a single client.""" + + return None + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteClientCrashInfo( + self, + client_id: str, + crash: jobs_pb2.ClientCrash, + ) -> None: + """Writes a new client crash record.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadClientCrashInfo( + self, + client_id: str, + ) -> Optional[jobs_pb2.ClientCrash]: + """Reads the latest client crash record for a single client.""" + + return None + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadClientCrashInfoHistory( + self, + client_id: str, + ) -> Sequence[jobs_pb2.ClientCrash]: + """Reads the full crash history for a particular client.""" + result = [] + + query = """ + SELECT cr.CreationTime, cr.Crash + FROM ClientCrashes AS cr + WHERE cr.ClientId = {client_id} + ORDER BY cr.CreationTime DESC + """ + return None + + # TODO(b/196379916): Investigate whether we need to batch this call or not. + @db_utils.CallLogged + @db_utils.CallAccounted + def MultiReadClientFullInfo( + self, + client_ids: Collection[str], + min_last_ping: Optional[rdfvalue.RDFDatetime] = None, + ) -> Mapping[str, objects_pb2.ClientFullInfo]: + """Reads full client information for a list of clients.""" + # Spanner is having issues with `UNNEST` on empty arrays so we exit early in + # such cases. + + result = {} + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadClientLastPings( + self, + min_last_ping: Optional[rdfvalue.RDFDatetime] = None, + max_last_ping: Optional[rdfvalue.RDFDatetime] = None, + batch_size: int = 0, + ) -> Iterator[Mapping[str, Optional[rdfvalue.RDFDatetime]]]: + """Yields dicts of last-ping timestamps for clients in the DB.""" + + + + def _ReadClientLastPingsBatch( + self, + count: int, + last_client_id: str, + min_last_ping_time: Optional[rdfvalue.RDFDatetime], + max_last_ping_time: Optional[rdfvalue.RDFDatetime], + ) -> Mapping[str, Optional[rdfvalue.RDFDatetime]]: + """Reads a single batch of last client last ping times. + + Args: + count: The number of entries to read in the batch. + last_client_id: The identifier of the last client of the previous batch. + min_last_ping_time: An (optional) lower bound on the last ping time value. + max_last_ping_time: An (optional) upper bound on the last ping time value. 
+ + Returns: + A mapping from client identifiers to client last ping times. + """ + result = {} + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteClient(self, client_id: str) -> None: + """Deletes a client with all associated metadata.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def MultiAddClientKeywords( + self, + client_ids: Collection[str], + keywords: Collection[str], + ) -> None: + """Associates the provided keywords with the specified clients.""" + # Early return to avoid generating empty mutation. + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ListClientsForKeywords( + self, + keywords: Collection[str], + start_time: Optional[rdfvalue.RDFDatetime] = None, + ) -> Mapping[str, Collection[str]]: + """Lists the clients associated with keywords.""" + + return None + + @db_utils.CallLogged + @db_utils.CallAccounted + def RemoveClientKeyword(self, client_id: str, keyword: str) -> None: + """Removes the association of a particular client to a keyword.""" + + + +def IntClientID(client_id: str) -> spanner_lib.UInt64: + """Converts a client identifier to its integer representation. + + This function wraps the value in PySpanner's `UInt64` wrapper. It is needed + because by default PySpanner assumes that integers are `Int64` and this can + cause conversion errors for large values. + + Args: + client_id: A client identifier to convert. + + Returns: + An integer representation of the given client identifier. + """ + return spanner_lib.UInt64(db_utils.ClientIDToInt(client_id)) + + +_EPOCH = datetime.datetime.fromtimestamp(0, datetime.timezone.utc) + +# TODO(b/196379916): The F1 implementation uses a single constant for all +# queries deleting a lot of data. We should probably follow this pattern here, +# this constant should be moved to a more appropriate module. Also, it might +# be worthwhile to use Spanner's partitioned DML feature [1]. +# +# pylint: disable=line-too-long +# [1]: https://g3doc.corp.google.com/spanner/g3doc/userguide/sqlv1/data-manipulation-language.md#a-note-about-locking +# pylint: enable=line-too-long +_DELETE_BATCH_SIZE = 5_000 diff --git a/grr/server/grr_response_server/databases/spanner_cron_jobs.py b/grr/server/grr_response_server/databases/spanner_cron_jobs.py new file mode 100644 index 000000000..127b7acd0 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_cron_jobs.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python +"""A module with cronjobs methods of the Spanner backend.""" + +import datetime +from typing import Any, Mapping, Optional, Sequence + +from grr_response_core.lib import rdfvalue +from grr_response_core.lib import utils +from grr_response_proto import flows_pb2 +from grr_response_proto import jobs_pb2 +from grr_response_server.databases import db +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils + +_UNCHANGED = db.Database.UNCHANGED + + +class CronJobsMixin: + """A Spanner database mixin with implementation of cronjobs.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteCronJob(self, cronjob: flows_pb2.CronJob) -> None: + """Writes a cronjob to the database. + + Args: + cronjob: A flows_pb2.CronJob object. + """ + # We currently expect to reuse `created_at` if set. + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadCronJobs( + self, cronjob_ids: Optional[Sequence[str]] = None + ) -> Sequence[flows_pb2.CronJob]: + """Reads all cronjobs from the database. 
+ + Args: + cronjob_ids: A list of cronjob ids to read. If not set, returns all cron + jobs in the database. + + Returns: + A list of flows_pb2.CronJob objects. + + Raises: + UnknownCronJobError: A cron job for at least one of the given ids + does not exist. + """ + + return None + + @db_utils.CallLogged + @db_utils.CallAccounted + def UpdateCronJob( # pytype: disable=annotation-type-mismatch + self, + cronjob_id: str, + last_run_status: Optional[ + "flows_pb2.CronJobRun.CronJobRunStatus" + ] = _UNCHANGED, + last_run_time: Optional[rdfvalue.RDFDatetime] = _UNCHANGED, + current_run_id: Optional[str] = _UNCHANGED, + state: Optional[jobs_pb2.AttributedDict] = _UNCHANGED, + forced_run_requested: Optional[bool] = _UNCHANGED, + ): + """Updates run information for an existing cron job. + + Args: + cronjob_id: The id of the cron job to update. + last_run_status: A CronJobRunStatus object. + last_run_time: The last time a run was started for this cron job. + current_run_id: The id of the currently active run. + state: The state dict for stateful cron jobs. + forced_run_requested: A boolean indicating if a forced run is pending for + this job. + + Raises: + UnknownCronJobError: A cron job with the given id does not exist. + """ + + + @db_utils.CallLogged + @db_utils.CallAccounted + def EnableCronJob(self, cronjob_id: str) -> None: + """Enables a cronjob. + + Args: + cronjob_id: The id of the cron job to enable. + + Raises: + UnknownCronJobError: A cron job with the given id does not exist. + """ + + + @db_utils.CallLogged + @db_utils.CallAccounted + def DisableCronJob(self, cronjob_id: str) -> None: + """Deletes a cronjob along with all its runs. + + Args: + cronjob_id: The id of the cron job to delete. + + Raises: + UnknownCronJobError: A cron job with the given id does not exist. + """ + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteCronJob(self, cronjob_id: str) -> None: + """Deletes a cronjob along with all its runs. + + Args: + cronjob_id: The id of the cron job to delete. + + Raises: + UnknownCronJobError: A cron job with the given id does not exist. + """ + + + @db_utils.CallLogged + @db_utils.CallAccounted + def LeaseCronJobs( + self, + cronjob_ids: Optional[Sequence[str]] = None, + lease_time: Optional[rdfvalue.Duration] = None, + ) -> Sequence[flows_pb2.CronJob]: + """Leases all available cron jobs. + + Args: + cronjob_ids: A list of cronjob ids that should be leased. If None, all + available cronjobs will be leased. + lease_time: rdfvalue.Duration indicating how long the lease should be + valid. + + Returns: + A list of cronjobs.CronJob objects that were leased. + """ + + return [] + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReturnLeasedCronJobs(self, jobs: Sequence[flows_pb2.CronJob]) -> None: + """Makes leased cron jobs available for leasing again. + + Args: + jobs: A list of leased cronjobs. + + Raises: + ValueError: If not all of the cronjobs are leased. + """ + + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteCronJobRun(self, run_object: flows_pb2.CronJobRun) -> None: + """Stores a cron job run object in the database. + + Args: + run_object: A flows_pb2.CronJobRun object to store. + """ + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadCronJobRuns(self, job_id: str) -> Sequence[flows_pb2.CronJobRun]: + """Reads all cron job runs for a given job id. + + Args: + job_id: Runs will be returned for the job with the given id. + + Returns: + A list of flows_pb2.CronJobRun objects. 
+ """ + + return [] + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadCronJobRun(self, job_id: str, run_id: str) -> flows_pb2.CronJobRun: + """Reads a single cron job run from the db. + + Args: + job_id: The job_id of the run to be read. + run_id: The run_id of the run to be read. + + Returns: + An flows_pb2.CronJobRun object. + """ + + return None + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteOldCronJobRuns(self, cutoff_timestamp: rdfvalue.RDFDatetime) -> int: + """Deletes cron job runs that are older than cutoff_timestamp. + + Args: + cutoff_timestamp: This method deletes all runs that were started before + cutoff_timestamp. + + Returns: + The number of deleted runs. + """ + + return 0 + diff --git a/grr/server/grr_response_server/databases/spanner_events.py b/grr/server/grr_response_server/databases/spanner_events.py new file mode 100644 index 000000000..843d04166 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_events.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python +"""A library with audit events methods of Spanner database implementation.""" + +from typing import Dict, List, Optional, Tuple + +from grr_response_core.lib import rdfvalue +from grr_response_proto import objects_pb2 +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils + + +class EventsMixin: + """A Spanner database mixin with implementation of audit events methods.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadAPIAuditEntries( + self, + username: Optional[str] = None, + router_method_names: Optional[List[str]] = None, + min_timestamp: Optional[rdfvalue.RDFDatetime] = None, + max_timestamp: Optional[rdfvalue.RDFDatetime] = None, + ) -> List[objects_pb2.APIAuditEntry]: + """Returns audit entries stored in the database.""" + query = """ + SELECT + a.Username, + a.CreationTime, + a.HttpRequestPath, + a.RouterMethodName, + a.ResponseCode + FROM ApiAuditEntry AS a + """ + + result = [] + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountAPIAuditEntriesByUserAndDay( + self, + min_timestamp: Optional[rdfvalue.RDFDatetime] = None, + max_timestamp: Optional[rdfvalue.RDFDatetime] = None, + ) -> Dict[Tuple[str, rdfvalue.RDFDatetime], int]: + """Returns audit entry counts grouped by user and calendar day.""" + + result = {} + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteAPIAuditEntry(self, entry: objects_pb2.APIAuditEntry): + """Writes an audit entry to the database.""" diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py new file mode 100644 index 000000000..657bf6e18 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -0,0 +1,552 @@ +#!/usr/bin/env python +"""A module with flow methods of the Spanner database implementation.""" + +import dataclasses +import datetime +import logging +from typing import Any, Callable, Collection, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union + +from grr_response_core.lib import rdfvalue +from grr_response_core.lib import utils +from grr_response_core.lib.util import collection +from grr_response_core.stats import metrics +from grr_response_proto import flows_pb2 +from grr_response_proto import jobs_pb2 +from grr_response_proto import objects_pb2 +from grr_response_server.databases import db +from grr_response_server.databases import db_utils +from 
grr_response_server.databases import spanner_clients +from grr_response_server.databases import spanner_utils +from grr_response_server.models import hunts as models_hunts +from rrg.proto import rrg_pb2 + +class FlowsMixin: + """A Spanner database mixin with implementation of flow methods.""" + + db: spanner_utils.Database + _write_rows_batch_size: int + + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowObject( + self, + flow_obj: flows_pb2.Flow, + allow_update: bool = True, + ) -> None: + """Writes a flow object to the database.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadFlowObject( + self, + client_id: str, + flow_id: str, + ) -> flows_pb2.Flow: + """Reads a flow object from the database.""" + + return None + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadAllFlowObjects( + self, + client_id: Optional[str] = None, + parent_flow_id: Optional[str] = None, + min_create_time: Optional[rdfvalue.RDFDatetime] = None, + max_create_time: Optional[rdfvalue.RDFDatetime] = None, + include_child_flows: bool = True, + not_created_by: Optional[Iterable[str]] = None, + ) -> Sequence[flows_pb2.Flow]: + """Returns all flow objects that meet the specified conditions.""" + result = [] + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def UpdateFlow( + self, + client_id: str, + flow_id: str, + flow_obj: Union[flows_pb2.Flow, _UNCHANGED_TYPE] = _UNCHANGED, + flow_state: Union[ + flows_pb2.Flow.FlowState.ValueType, _UNCHANGED_TYPE + ] = _UNCHANGED, + client_crash_info: Union[ + jobs_pb2.ClientCrash, _UNCHANGED_TYPE + ] = _UNCHANGED, + processing_on: Optional[Union[str, _UNCHANGED_TYPE]] = _UNCHANGED, + processing_since: Optional[ + Union[rdfvalue.RDFDatetime, _UNCHANGED_TYPE] + ] = _UNCHANGED, + processing_deadline: Optional[ + Union[rdfvalue.RDFDatetime, _UNCHANGED_TYPE] + ] = _UNCHANGED, + ) -> None: + """Updates flow objects in the database.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowResults(self, results: Sequence[flows_pb2.FlowResult]) -> None: + """Writes flow results for a given flow.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowErrors(self, errors: Sequence[flows_pb2.FlowError]) -> None: + """Writes flow errors for a given flow.""" + + + + def ReadFlowResults( + self, + client_id: str, + flow_id: str, + offset: int, + count: int, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + with_substring: Optional[str] = None, + ) -> Sequence[flows_pb2.FlowResult]: + """Reads flow results of a given flow using given query options.""" + results = [] + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadFlowErrors( + self, + client_id: str, + flow_id: str, + offset: int, + count: int, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + ) -> Sequence[flows_pb2.FlowError]: + """Reads flow errors of a given flow using given query options.""" + + return [] + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountFlowResults( + self, + client_id: str, + flow_id: str, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + ) -> int: + """Counts flow results of a given flow using given query options.""" + + return 0 + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountFlowErrors( + self, + client_id: str, + flow_id: str, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + ) -> int: + """Counts flow errors of a given flow using given query options.""" + + return 0 + + @db_utils.CallLogged 
+ @db_utils.CallAccounted + def CountFlowResultsByType( + self, client_id: str, flow_id: str + ) -> Mapping[str, int]: + """Returns counts of flow results grouped by result type.""" + + query = """ + SELECT r.RdfType, COUNT(*) + FROM FlowResults AS r + WHERE r.ClientId = {client_id} AND r.FlowId = {flow_id} + GROUP BY RdfType + """ + + result = {} + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountFlowErrorsByType( + self, client_id: str, flow_id: str + ) -> Mapping[str, int]: + """Returns counts of flow errors grouped by error type.""" + result = {} + + return result + + def _BuildFlowProcessingRequestWrites( + self, + mut: spanner_utils.Mutation, + requests: Iterable[flows_pb2.FlowProcessingRequest], + ) -> None: + """Builds db writes for a list of FlowProcessingRequests.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowProcessingRequests( + self, + requests: Sequence[flows_pb2.FlowProcessingRequest], + ) -> None: + """Writes a list of flow processing requests to the database.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadFlowProcessingRequests( + self, + ) -> Sequence[flows_pb2.FlowProcessingRequest]: + """Reads all flow processing requests from the database.""" + query = """ + SELECT t.Payload, t.CreationTime FROM FlowProcessingRequestsQueue AS t + """ + + results = [] + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def AckFlowProcessingRequests( + self, requests: Iterable[flows_pb2.FlowProcessingRequest] + ) -> None: + """Acknowledges and deletes flow processing requests.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteAllFlowProcessingRequests(self) -> None: + """Deletes all flow processing requests from the database.""" + query = """ + DELETE FROM FlowProcessingRequestsQueue WHERE true + """ + self.db.ParamExecute(query, {}, txn_tag="DeleteAllFlowProcessingRequests") + + def RegisterFlowProcessingHandler( + self, handler: Callable[[flows_pb2.FlowProcessingRequest], None] + ): + """Registers a handler to receive flow processing messages.""" + + + def UnregisterFlowProcessingHandler( + self, timeout: Optional[rdfvalue.Duration] = None + ) -> None: + """Unregisters any registered flow processing handler.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowRequests( + self, + requests: Collection[flows_pb2.FlowRequest], + ) -> None: + """Writes a list of flow requests to the database.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowResponses( + self, + responses: Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ) -> None: + """Writes Flow ressages and updates corresponding requests.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteAllFlowRequestsAndResponses( + self, + client_id: str, + flow_id: str, + ) -> None: + """Deletes all requests and responses for a given flow from the database.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadAllFlowRequestsAndResponses( + self, + client_id: str, + flow_id: str, + ) -> Iterable[ + Tuple[ + flows_pb2.FlowRequest, + Dict[ + int, + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ] + ]: + """Reads all requests and responses for a given flow from the database.""" + + ret = [] + + return ret + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteFlowRequests( + self, + requests: Sequence[flows_pb2.FlowRequest], + ) -> None: + """Deletes a list of 
flow requests from the database.""" + + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadFlowRequests( + self, + client_id: str, + flow_id: str, + ) -> Dict[ + int, + Tuple[ + flows_pb2.FlowRequest, + List[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ], + ]: + """Reads all requests for a flow that can be processed by the worker.""" + + return {} + + @db_utils.CallLogged + @db_utils.CallAccounted + def UpdateIncrementalFlowRequests( + self, + client_id: str, + flow_id: str, + next_response_id_updates: Mapping[int, int], + ) -> None: + """Updates next response ids of given requests.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowLogEntry(self, entry: flows_pb2.FlowLogEntry) -> None: + """Writes a single flow log entry to the database.""" + + def ReadFlowLogEntries( + self, + client_id: str, + flow_id: str, + offset: int, + count: int, + with_substring: Optional[str] = None, + ) -> Sequence[flows_pb2.FlowLogEntry]: + """Reads flow log entries of a given flow using given query options.""" + results = [] + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountFlowLogEntries(self, client_id: str, flow_id: str) -> int: + """Returns number of flow log entries of a given flow.""" + + return 0 + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowRRGLogs( + self, + client_id: str, + flow_id: str, + request_id: int, + logs: Mapping[int, rrg_pb2.Log], + ) -> None: + """Writes new log entries for a particular action request.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadFlowRRGLogs( + self, + client_id: str, + flow_id: str, + offset: int, + count: int, + ) -> Sequence[rrg_pb2.Log]: + """Reads log entries logged by actions issued by a particular flow.""" + + return [] + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowOutputPluginLogEntry( + self, + entry: flows_pb2.FlowOutputPluginLogEntry, + ) -> None: + """Writes a single output plugin log entry to the database. + + Args: + entry: An output plugin flow entry to write. 
+ """ + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadFlowOutputPluginLogEntries( + self, + client_id: str, + flow_id: str, + output_plugin_id: str, + offset: int, + count: int, + with_type: Optional[ + flows_pb2.FlowOutputPluginLogEntry.LogEntryType.ValueType + ] = None, + ) -> Sequence[flows_pb2.FlowOutputPluginLogEntry]: + """Reads flow output plugin log entries.""" + results = [] + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountFlowOutputPluginLogEntries( + self, + client_id: str, + flow_id: str, + output_plugin_id: str, + with_type: Optional[ + flows_pb2.FlowOutputPluginLogEntry.LogEntryType.ValueType + ] = None, + ) -> int: + """Returns the number of flow output plugin log entries of a given flow.""" + + return 0 + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteScheduledFlow( + self, + scheduled_flow: flows_pb2.ScheduledFlow, + ) -> None: + """Inserts or updates the ScheduledFlow in the database.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteScheduledFlow( + self, + client_id: str, + creator: str, + scheduled_flow_id: str, + ) -> None: + """Deletes the ScheduledFlow from the database.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ListScheduledFlows( + self, + client_id: str, + creator: str, + ) -> Sequence[flows_pb2.ScheduledFlow]: + """Lists all ScheduledFlows for the client and creator.""" + + results = [] + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteMessageHandlerRequests( + self, requests: Iterable[objects_pb2.MessageHandlerRequest] + ) -> None: + """Writes a list of message handler requests to the database.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadMessageHandlerRequests( + self, + ) -> Sequence[objects_pb2.MessageHandlerRequest]: + """Reads all message handler requests from the database.""" + + results = [] + return results + + def _BuildDeleteMessageHandlerRequestWrites( + self, + txn: spanner_utils.Transaction, + requests: Iterable[objects_pb2.MessageHandlerRequest], + ) -> None: + """Deletes given requests within a given transaction.""" + + + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteMessageHandlerRequests( + self, requests: Iterable[objects_pb2.MessageHandlerRequest] + ) -> None: + """Deletes a list of message handler requests from the database.""" + + + def RegisterMessageHandler( + self, + handler: Callable[[Sequence[objects_pb2.MessageHandlerRequest]], None], + lease_time: rdfvalue.Duration, + limit: int = 1000, + ) -> None: + """Registers a message handler to receive batches of messages.""" + + + def UnregisterMessageHandler( + self, timeout: Optional[rdfvalue.Duration] = None + ) -> None: + """Unregisters any registered message handler.""" + + + def _ReadHuntState( + self, txn: spanner_utils.Transaction, hunt_id: str + ) -> Optional[int]: + return None + + @db_utils.CallLogged + @db_utils.CallAccounted + def LeaseFlowForProcessing( + self, + client_id: str, + flow_id: str, + processing_time: rdfvalue.Duration, + ) -> flows_pb2.Flow: + """Marks a flow as being processed on this worker and returns it.""" + + return None + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReleaseProcessedFlow(self, flow_obj: flows_pb2.Flow) -> bool: + """Releases a flow that the worker was processing to the database.""" + + return False diff --git a/grr/server/grr_response_server/databases/spanner_foreman_rules.py b/grr/server/grr_response_server/databases/spanner_foreman_rules.py new file mode 100644 index 
000000000..a43e3825a
--- /dev/null
+++ b/grr/server/grr_response_server/databases/spanner_foreman_rules.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python
+"""A module with foreman rules methods of the Spanner backend."""
+
+from typing import Sequence
+
+from grr_response_core.lib import rdfvalue
+from grr_response_proto import jobs_pb2
+from grr_response_server.databases import db_utils
+from grr_response_server.databases import spanner_utils
+
+
+class ForemanRulesMixin:
+  """A Spanner database mixin with implementation of foreman rules."""
+
+  db: spanner_utils.Database
+
+  @db_utils.CallLogged
+  @db_utils.CallAccounted
+  def WriteForemanRule(self, rule: jobs_pb2.ForemanCondition) -> None:
+    """Writes a foreman rule to the database."""
+
+
+  @db_utils.CallLogged
+  @db_utils.CallAccounted
+  def RemoveForemanRule(self, hunt_id: str) -> None:
+    """Removes a foreman rule from the database."""
+
+
+  @db_utils.CallLogged
+  @db_utils.CallAccounted
+  def ReadAllForemanRules(self) -> Sequence[jobs_pb2.ForemanCondition]:
+    """Reads all foreman rules from the database."""
+    result = []
+
+    return result
+
+  @db_utils.CallLogged
+  @db_utils.CallAccounted
+  def RemoveExpiredForemanRules(self) -> None:
+    """Removes all expired foreman rules from the database."""
+
diff --git a/grr/server/grr_response_server/databases/spanner_hunts.py b/grr/server/grr_response_server/databases/spanner_hunts.py
new file mode 100644
index 000000000..0b0df4c6d
--- /dev/null
+++ b/grr/server/grr_response_server/databases/spanner_hunts.py
@@ -0,0 +1,297 @@
+#!/usr/bin/env python
+"""A module with hunt methods of the Spanner database implementation."""
+
+from typing import AbstractSet, Callable, Collection, Iterable, List, Mapping, Optional, Sequence
+
+from google.protobuf import any_pb2
+from grr_response_core.lib import rdfvalue
+from grr_response_core.lib.rdfvalues import client as rdf_client
+from grr_response_proto import flows_pb2
+from grr_response_proto import hunts_pb2
+from grr_response_proto import jobs_pb2
+from grr_response_proto import output_plugin_pb2
+from grr_response_server.databases import db as abstract_db
+from grr_response_server.databases import db_utils
+from grr_response_server.databases import spanner_flows
+from grr_response_server.databases import spanner_utils
+from grr_response_server.models import hunts as models_hunts
+
+
+class HuntsMixin:
+  """A Spanner database mixin with implementation of hunt methods."""
+
+  db: spanner_utils.Database
+  _write_rows_batch_size: int
+
+  @db_utils.CallLogged
+  @db_utils.CallAccounted
+  def WriteHuntObject(self, hunt_obj: hunts_pb2.Hunt):
+    """Writes a hunt object to the database."""
+
+
+
+  @db_utils.CallLogged
+  @db_utils.CallAccounted
+  def UpdateHuntObject(
+      self,
+      hunt_id: str,
+      duration: Optional[rdfvalue.Duration] = None,
+      client_rate: Optional[int] = None,
+      client_limit: Optional[int] = None,
+      hunt_state: Optional[hunts_pb2.Hunt.HuntState.ValueType] = None,
+      hunt_state_reason: Optional[
+          hunts_pb2.Hunt.HuntStateReason.ValueType
+      ] = None,
+      hunt_state_comment: Optional[str] = None,
+      start_time: Optional[rdfvalue.RDFDatetime] = None,
+      num_clients_at_start_time: Optional[int] = None,
+  ):
+    """Updates the hunt object."""
+
+
+  @db_utils.CallLogged
+  @db_utils.CallAccounted
+  def DeleteHuntObject(self, hunt_id: str) -> None:
+    """Deletes a hunt object with a given id."""
+
+
+  @db_utils.CallLogged
+  @db_utils.CallAccounted
+  def ReadHuntObject(self, hunt_id: str) -> hunts_pb2.Hunt:
+    """Reads a hunt object from the database."""
+
+    return None
+
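+    # One possible shape for ReadHuntObject above (illustrative only): a
+    # single-row read against the Hunts table, mirroring the Read/NotFound
+    # pattern used by the YARA mixin. The payload column name, the hunt-id
+    # encoding and the NotFound import are assumptions, not the final
+    # implementation.
+    #
+    #   try:
+    #     (hunt_bytes,) = self.db.Read(
+    #         table="Hunts", key=(hunt_id,), cols=("Hunt",))
+    #   except NotFound as error:
+    #     raise abstract_db.UnknownHuntError(hunt_id) from error
+    #
+    #   hunt = hunts_pb2.Hunt()
+    #   hunt.ParseFromString(hunt_bytes)
+    #   return hunt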
@db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntObjects( + self, + offset: int, + count: int, + with_creator: Optional[str] = None, + created_after: Optional[rdfvalue.RDFDatetime] = None, + with_description_match: Optional[str] = None, + created_by: Optional[AbstractSet[str]] = None, + not_created_by: Optional[AbstractSet[str]] = None, + with_states: Optional[Collection[hunts_pb2.Hunt.HuntState]] = None, + ) -> List[hunts_pb2.Hunt]: + """Reads hunt objects from the database.""" + + result = [] + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def ListHuntObjects( + self, + offset: int, + count: int, + with_creator: Optional[str] = None, + created_after: Optional[rdfvalue.RDFDatetime] = None, + with_description_match: Optional[str] = None, + created_by: Optional[Iterable[str]] = None, + not_created_by: Optional[Iterable[str]] = None, + with_states: Optional[ + Iterable[hunts_pb2.Hunt.HuntState.ValueType] + ] = None, + ) -> Iterable[hunts_pb2.HuntMetadata]: + """Reads metadata for hunt objects from the database.""" + + result = [] + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntResults( + self, + hunt_id: str, + offset: int, + count: int, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + with_substring: Optional[str] = None, + with_timestamp: Optional[rdfvalue.RDFDatetime] = None, + ) -> Iterable[flows_pb2.FlowResult]: + """Reads hunt results of a given hunt using given query options.""" + + results = [] + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountHuntResults( + self, + hunt_id: str, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + ) -> int: + """Counts hunt results of a given hunt using given query options.""" + + return 0 + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountHuntResultsByType(self, hunt_id: str) -> Mapping[str, int]: + """Returns counts of items in hunt results grouped by type.""" + + result = {} + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntLogEntries( + self, + hunt_id: str, + offset: int, + count: int, + with_substring: Optional[str] = None, + ) -> Sequence[flows_pb2.FlowLogEntry]: + """Reads hunt log entries of a given hunt using given query options.""" + + results = [] + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountHuntLogEntries(self, hunt_id: str) -> int: + """Returns number of hunt log entries of a given hunt.""" + + return 0 + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntOutputPluginLogEntries( + self, + hunt_id: str, + output_plugin_id: str, + offset: int, + count: int, + with_type: Optional[str] = None, + ) -> Sequence[flows_pb2.FlowOutputPluginLogEntry]: + """Reads hunt output plugin log entries.""" + + results = [] + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountHuntOutputPluginLogEntries( + self, + hunt_id: str, + output_plugin_id: str, + with_type: Optional[str] = None, + ) -> int: + """Returns number of hunt output plugin log entries of a given hunt.""" + + return 0 + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntOutputPluginsStates( + self, hunt_id: str + ) -> List[output_plugin_pb2.OutputPluginState]: + """Reads all hunt output plugins states of a given hunt.""" + + results = [] + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteHuntOutputPluginsStates( + self, + hunt_id: str, + states: Collection[output_plugin_pb2.OutputPluginState], 
+ ) -> None: + """Writes hunt output plugin states for a given hunt.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def UpdateHuntOutputPluginState( + self, + hunt_id: str, + state_index: int, + update_fn: Callable[[jobs_pb2.AttributedDict], jobs_pb2.AttributedDict], + ) -> None: + """Updates hunt output plugin state for a given output plugin.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntFlows( # pytype: disable=annotation-type-mismatch + self, + hunt_id: str, + offset: int, + count: int, + filter_condition: abstract_db.HuntFlowsCondition = abstract_db.HuntFlowsCondition.UNSET, + ) -> Sequence[flows_pb2.Flow]: + """Reads hunt flows matching given conditions.""" + + results = [] + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntFlowErrors( + self, + hunt_id: str, + offset: int, + count: int, + ) -> Mapping[str, abstract_db.FlowErrorInfo]: + """Returns errors for flows of the given hunt.""" + results = {} + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountHuntFlows( # pytype: disable=annotation-type-mismatch + self, + hunt_id: str, + filter_condition: Optional[ + abstract_db.HuntFlowsCondition + ] = abstract_db.HuntFlowsCondition.UNSET, + ) -> int: + """Counts hunt flows matching given conditions.""" + + return 0 + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntFlowsStatesAndTimestamps( + self, + hunt_id: str, + ) -> Sequence[abstract_db.FlowStateAndTimestamps]: + """Reads hunt flows states and timestamps.""" + + results = [] + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntsCounters( + self, + hunt_ids: Collection[str], + ) -> Mapping[str, abstract_db.HuntCounters]: + """Reads hunt counters for several of hunt ids.""" + + return None + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntClientResourcesStats( + self, + hunt_id: str, + ) -> jobs_pb2.ClientResourcesStats: + """Read hunt client resources stats.""" + + return None diff --git a/grr/server/grr_response_server/databases/spanner_paths.py b/grr/server/grr_response_server/databases/spanner_paths.py new file mode 100644 index 000000000..eca20a458 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_paths.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python +"""A module with path methods of the Spanner database implementation.""" + +from typing import Collection, Dict, Iterable, Optional, Sequence + +from grr_response_core.lib import rdfvalue +from grr_response_core.lib.util import iterator +from grr_response_proto import objects_pb2 +from grr_response_server.databases import db as abstract_db +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_clients +from grr_response_server.databases import spanner_utils +from grr_response_server.models import paths as models_paths + + +class PathsMixin: + """A Spanner database mixin with implementation of path methods.""" + + db: spanner_utils.Database + + # TODO(b/196379916): Implement path methods. 
+ + @db_utils.CallLogged + @db_utils.CallAccounted + def WritePathInfos( + self, + client_id: str, + path_infos: Iterable[objects_pb2.PathInfo], + ) -> None: + """Writes a collection of path records for a client.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadPathInfos( + self, + client_id: str, + path_type: objects_pb2.PathInfo.PathType, + components_list: Collection[Sequence[str]], + ) -> dict[tuple[str, ...], Optional[objects_pb2.PathInfo]]: + """Retrieves path info records for given paths.""" + + return {} + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadPathInfo( + self, + client_id: str, + path_type: objects_pb2.PathInfo.PathType, + components: Sequence[str], + timestamp: Optional[rdfvalue.RDFDatetime] = None, + ) -> objects_pb2.PathInfo: + """Retrieves a path info record for a given path.""" + + return None + + @db_utils.CallLogged + @db_utils.CallAccounted + def ListDescendantPathInfos( + self, + client_id: str, + path_type: objects_pb2.PathInfo.PathType, + components: Sequence[str], + timestamp: Optional[rdfvalue.RDFDatetime] = None, + max_depth: Optional[int] = None, + ) -> Sequence[objects_pb2.PathInfo]: + """Lists path info records that correspond to descendants of given path.""" + + return [] + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadPathInfosHistories( + self, + client_id: str, + path_type: objects_pb2.PathInfo.PathType, + components_list: Collection[Sequence[str]], + cutoff: Optional[rdfvalue.RDFDatetime] = None, + ) -> dict[tuple[str, ...], Sequence[objects_pb2.PathInfo]]: + """Reads a collection of hash and stat entries for given paths.""" + + + results = {tuple(components): [] for components in components_list} + + params = { + "client_id": spanner_clients.IntClientID(client_id), + "type": int(path_type), + "paths": list(map(EncodePathComponents, components_list)), + } + + if cutoff is not None: + stat_query += " AND s.CreationTime <= {cutoff}" + hash_query += " AND h.CreationTime <= {cutoff}" + params["cutoff"] = cutoff.AsDatetime() + + query = f""" + WITH s AS ({stat_query}), + h AS ({hash_query}) + SELECT s.Path, s.CreationTime, s.Stat, + h.Path, h.CreationTime, h.Hash + FROM s FULL JOIN h ON s.Path = h.Path + """ + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadLatestPathInfosWithHashBlobReferences( + self, + client_paths: Collection[abstract_db.ClientPath], + max_timestamp: Optional[rdfvalue.RDFDatetime] = None, + ) -> Dict[abstract_db.ClientPath, Optional[objects_pb2.PathInfo]]: + """Returns path info with corresponding hash blob references.""" + # Early return in case of empty client paths to avoid issues with syntax er- + # rors due to empty clause list. 
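+    # One possible shape (illustrative only): the early return mentioned
+    # above, keyed by the ClientPath objects the caller passed in.
+    #
+    #   if not client_paths:
+    #     return {}
+    #
+    #   result = {client_path: None for client_path in client_paths}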
+ + return {} diff --git a/grr/server/grr_response_server/databases/spanner_signed_binaries.py b/grr/server/grr_response_server/databases/spanner_signed_binaries.py new file mode 100644 index 000000000..d697e4449 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_signed_binaries.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python +"""A module with signed binaries methods of the Spanner backend.""" + +from typing import Sequence, Tuple + +from grr_response_core.lib import rdfvalue +from grr_response_core.lib.util import iterator +from grr_response_proto import objects_pb2 +from grr_response_server.databases import db +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils + + +class SignedBinariesMixin: + """A Spanner database mixin with implementation of signed binaries.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteSignedBinaryReferences( + self, + binary_id: objects_pb2.SignedBinaryID, + references: objects_pb2.BlobReferences, + ) -> None: + """Writes blob references for a signed binary to the DB. + + Args: + binary_id: Signed binary id for the binary. + references: Blob references for the given binary. + """ + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadSignedBinaryReferences( + self, binary_id: objects_pb2.SignedBinaryID + ) -> Tuple[objects_pb2.BlobReferences, rdfvalue.RDFDatetime]: + """Reads blob references for the signed binary with the given id. + + Args: + binary_id: Signed binary id for the binary. + + Returns: + A tuple of the signed binary's rdf_objects.BlobReferences and an + RDFDatetime representing the time when the references were written to the + DB. + """ + + return None, None + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadIDsForAllSignedBinaries(self) -> Sequence[objects_pb2.SignedBinaryID]: + """Returns ids for all signed binaries in the DB.""" + results = [] + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteSignedBinaryReferences( + self, + binary_id: objects_pb2.SignedBinaryID, + ) -> None: + """Deletes blob references for the given signed binary from the DB. + + Does nothing if no entry with the given id exists in the DB. + + Args: + binary_id: An id of the signed binary to delete. 
+ """ + diff --git a/grr/server/grr_response_server/databases/spanner_signed_commands.py b/grr/server/grr_response_server/databases/spanner_signed_commands.py new file mode 100644 index 000000000..57804b21d --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_signed_commands.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python +"""A module with signed command methods of the Spanner database implementation.""" + +from typing import Sequence + +from grr_response_core.lib.util import iterator +from grr_response_proto import signed_commands_pb2 +from grr_response_server.databases import db +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils + + +class SignedCommandsMixin: + """A Spanner database mixin with implementation of signed command methods.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteSignedCommands( + self, + signed_commands: Sequence[signed_commands_pb2.SignedCommand], + ) -> None: + """Writes a signed command to the database.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadSignedCommand( + self, + id_: str, + operating_system: signed_commands_pb2.SignedCommand.OS, + ) -> signed_commands_pb2.SignedCommand: + """Reads signed command from the database.""" + + return None + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadSignedCommands( + self, + ) -> Sequence[signed_commands_pb2.SignedCommand]: + """Reads signed command from the database.""" + + return None + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteAllSignedCommands( + self, + ) -> None: + """Deletes all signed command from the database.""" + diff --git a/grr/server/grr_response_server/databases/spanner_test_lib.py b/grr/server/grr_response_server/databases/spanner_test_lib.py index 590b9c92b..4ca66bdbc 100644 --- a/grr/server/grr_response_server/databases/spanner_test_lib.py +++ b/grr/server/grr_response_server/databases/spanner_test_lib.py @@ -12,12 +12,16 @@ from google.cloud.spanner_admin_database_v1.types import spanner_database_admin from grr_response_server.databases import spanner_utils +from grr_response_server.databases import db as abstract_db +from grr_response_server.databases import spanner as spanner_db OPERATION_TIMEOUT_SECONDS = 240 PROD_SCHEMA_SDL_PATH = "grr/server/grr_response_server/databases/spanner.sdl" TEST_SCHEMA_SDL_PATH = "grr/server/grr_response_server/databases/spanner_test.sdl" +PROTO_DESCRIPTOR_PATH = "grr/server/grr_response_server/databases/spanner_grr.pb" + def _GetEnvironOrSkip(key): value = os.environ.get(key) if value is None: @@ -32,7 +36,13 @@ def _readSchemaFromFile(file_path): ddl_statements = [stmt.strip() for stmt in f.read().split(';') if stmt.strip()] return ddl_statements -def Init(sdl_path: str) -> None: +def _readProtoDescriptorFromFile(): + """Reads DDL statements from a file.""" + with open(PROTO_DESCRIPTOR_PATH, 'rb') as f: + proto_descriptors = f.read() + return proto_descriptors + +def Init(sdl_path: str, proto_bundle: bool) -> None: """Initializes the Spanner testing environment. This must be called only once per test process. 
A `setUpModule` method is @@ -53,10 +63,16 @@ def Init(sdl_path: str) -> None: ddl_statements = _readSchemaFromFile(sdl_path) + proto_descriptors = bytes() + + if proto_bundle: + proto_descriptors = _readProtoDescriptorFromFile() + request = spanner_database_admin.CreateDatabaseRequest( parent=database_admin_api.instance_path(spanner_client.project, instance_id), create_statement=f"CREATE DATABASE `{database_id}`", - extra_statements=ddl_statements + extra_statements=ddl_statements, + proto_descriptors=proto_descriptors ) operation = database_admin_api.create_database(request=request) @@ -86,6 +102,22 @@ def TearDown() -> None: _TEST_DB.drop() +class TestCase(absltest.TestCase): + """A base test case class for Spanner tests. + + This class takes care of setting up a clean database for every test method. It + is intended to be used with database test suite mixins. + """ + + def setUp(self): + super().setUp() + + self.raw_db = spanner_utils.Database(CreateTestDatabase()) + + db = spanner_db.SpannerDB(self.raw_db) + self.db = abstract_db.DatabaseValidationWrapper(db) + + def CreateTestDatabase() -> spanner_lib.database: """Creates an empty test spanner database. diff --git a/grr/server/grr_response_server/databases/spanner_users.py b/grr/server/grr_response_server/databases/spanner_users.py new file mode 100644 index 000000000..6d293b2c8 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_users.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python +"""A library with user methods of Spanner database implementation.""" + +import datetime +import logging +from typing import Optional, Sequence, Tuple + +from grr_response_core.lib import rdfvalue +from grr_response_core.lib.util import iterator +from grr_response_core.lib.util import random +from grr_response_proto import jobs_pb2 +from grr_response_proto import objects_pb2 +from grr_response_proto import user_pb2 +from grr_response_server.databases import db as abstract_db +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils + + +class UsersMixin: + """A Spanner database mixin with implementation of user methods.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteGRRUser( + self, + username: str, + email: Optional[str] = None, + password: Optional[jobs_pb2.Password] = None, + user_type: Optional["objects_pb2.GRRUser.UserType"] = None, + canary_mode: Optional[bool] = None, + ui_mode: Optional["user_pb2.GUISettings.UIMode"] = None, + ) -> None: + """Writes user object for a user with a given name.""" + + + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteGRRUser(self, username: str) -> None: + """Deletes the user and all related metadata with the given username.""" + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadGRRUser(self, username: str) -> objects_pb2.GRRUser: + """Reads a user object corresponding to a given name.""" + + return None + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadGRRUsers( + self, + offset: int = 0, + count: Optional[int] = None, + ) -> Sequence[objects_pb2.GRRUser]: + """Reads GRR users with optional pagination, sorted by username.""" + + return [] + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountGRRUsers(self) -> int: + """Returns the total count of GRR users.""" + + return 0 + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteApprovalRequest(self, request: objects_pb2.ApprovalRequest) -> str: + """Writes an approval request object.""" + + return "" + + 
@db_utils.CallLogged + @db_utils.CallAccounted + def ReadApprovalRequest( + self, + username: str, + approval_id: str, + ) -> objects_pb2.ApprovalRequest: + """Reads an approval request object with a given id.""" + + return None + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadApprovalRequests( + self, + username: str, + typ: "objects_pb2.ApprovalRequest.ApprovalType", + subject_id: Optional[str] = None, + include_expired: Optional[bool] = False, + ) -> Sequence[objects_pb2.ApprovalRequest]: + """Reads approval requests of a given type for a given user.""" + requests = [] + + return requests + + @db_utils.CallLogged + @db_utils.CallAccounted + def GrantApproval( + self, + requestor_username: str, + approval_id: str, + grantor_username: str, + ) -> None: + """Grants approval for a given request using given username.""" + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteUserNotification( + self, + notification: objects_pb2.UserNotification, + ) -> None: + """Writes a notification for a given user.""" + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadUserNotifications( + self, + username: str, + state: Optional["objects_pb2.UserNotification.State"] = None, + timerange: Optional[ + Tuple[Optional[rdfvalue.RDFDatetime], Optional[rdfvalue.RDFDatetime]] + ] = None, + ) -> Sequence[objects_pb2.UserNotification]: + """Reads notifications scheduled for a user within a given timerange.""" + + return [] + + @db_utils.CallLogged + @db_utils.CallAccounted + def UpdateUserNotifications( + self, + username: str, + timestamps: Sequence[rdfvalue.RDFDatetime], + state: Optional["objects_pb2.UserNotification.State"] = None, + ): + """Updates existing user notification objects.""" + + + + +def RDFDatetime(time: datetime.datetime) -> rdfvalue.RDFDatetime: + return rdfvalue.RDFDatetime.FromDatetime(time) + + +_APPROVAL_TYPE_CLIENT = objects_pb2.ApprovalRequest.APPROVAL_TYPE_CLIENT +_APPROVAL_TYPE_HUNT = objects_pb2.ApprovalRequest.APPROVAL_TYPE_HUNT +_APPROVAL_TYPE_CRON_JOB = objects_pb2.ApprovalRequest.APPROVAL_TYPE_CRON_JOB diff --git a/grr/server/grr_response_server/databases/spanner_utils_test.py b/grr/server/grr_response_server/databases/spanner_utils_test.py index a7bbd41c5..a2bd3fafd 100644 --- a/grr/server/grr_response_server/databases/spanner_utils_test.py +++ b/grr/server/grr_response_server/databases/spanner_utils_test.py @@ -17,7 +17,7 @@ from grr_response_server.databases import spanner_utils def setUpModule() -> None: - spanner_test_lib.Init(spanner_test_lib.TEST_SCHEMA_SDL_PATH) + spanner_test_lib.Init(spanner_test_lib.TEST_SCHEMA_SDL_PATH, False) def tearDownModule() -> None: diff --git a/grr/server/grr_response_server/databases/spanner_yara.py b/grr/server/grr_response_server/databases/spanner_yara.py new file mode 100644 index 000000000..bfe8d7277 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_yara.py @@ -0,0 +1,55 @@ +"""A module with YARA methods of the Spanner database implementation.""" + +from google.cloud import spanner as spanner_lib +from grr_response_server.databases import db +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils +from grr_response_server.models import blobs as models_blobs + + +class YaraMixin: + """A Spanner database mixin with implementation of YARA methods.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteYaraSignatureReference( + self, + blob_id: models_blobs.BlobID, + username: str, + ) -> None: + """Marks 
the specified blob id as a YARA signature.""" + row = { + "BlobId": bytes(blob_id), + "Creator": username, + "CreationTime": spanner_lib.CommitTimestamp(), + } + + try: + self.db.InsertOrUpdate( + table="YaraSignatureReferences", + row=row, + txn_tag="WriteYaraSignatureReference", + ) + except NotFound as error: + if "fk_yara_signature_reference_creator_username" in str(error): + raise db.UnknownGRRUserError(username) from error + else: + raise + + @db_utils.CallLogged + @db_utils.CallAccounted + def VerifyYaraSignatureReference( + self, + blob_id: models_blobs.BlobID, + ) -> bool: + """Verifies whether the specified blob is a YARA signature.""" + key = (bytes(blob_id),) + + try: + self.db.Read(table="YaraSignatureReferences", key=key, cols=()) + except NotFound: + return False + + return True \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_yara_test.py b/grr/server/grr_response_server/databases/spanner_yara_test.py new file mode 100644 index 000000000..0bf0ddb38 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_yara_test.py @@ -0,0 +1,23 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_yara_test_lib +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseYaraTest( + db_yara_test_lib.DatabaseTestYaraMixin, spanner_test_lib.TestCase +): + # Test methods are defined in the base mixin class. + pass + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file From 8286b9bff2baa49e7d6a35638b750447687b8da0 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Thu, 15 May 2025 13:36:11 +0000 Subject: [PATCH 004/168] Yara table tests passing --- .../grr_response_server/databases/spanner.py | 4 +- .../grr_response_server/databases/spanner.sdl | 3 +- .../databases/spanner_clients.py | 6 +-- .../databases/spanner_flows.py | 38 ++++++++----------- .../databases/spanner_hunts.py | 6 ++- .../databases/spanner_users.py | 17 +++++++++ .../databases/spanner_utils.py | 6 ++- .../databases/spanner_yara.py | 13 ++++--- 8 files changed, 57 insertions(+), 36 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner.py b/grr/server/grr_response_server/databases/spanner.py index 3fc9d7cdf..b8b7fa97e 100644 --- a/grr/server/grr_response_server/databases/spanner.py +++ b/grr/server/grr_response_server/databases/spanner.py @@ -1,11 +1,13 @@ # Imports the Google Cloud Client Library. 
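 # Illustrative wiring sketch, not part of this patch: a SpannerDB instance is
 # built from a google-cloud-spanner database handle wrapped in
 # spanner_utils.Database. The project, instance and database names below are
 # placeholders.
 #
 #   client = spanner.Client("my-project")
 #   database = client.instance("my-instance").database("grr")
 #   grr_db = SpannerDB(spanner_utils.Database(database))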
from google.cloud import spanner +from grr_response_core.lib import rdfvalue from grr_response_server.databases import db as abstract_db from grr_response_server.databases import spanner_artifacts from grr_response_server.databases import spanner_blob_keys +from grr_response_server.databases import spanner_blob_references from grr_response_server.databases import spanner_clients -from grr_response_server.databases import spanner_cronjobs +from grr_response_server.databases import spanner_cron_jobs from grr_response_server.databases import spanner_events from grr_response_server.databases import spanner_flows from grr_response_server.databases import spanner_foreman_rules diff --git a/grr/server/grr_response_server/databases/spanner.sdl b/grr/server/grr_response_server/databases/spanner.sdl index 7ccbe9fc4..469b9c158 100644 --- a/grr/server/grr_response_server/databases/spanner.sdl +++ b/grr/server/grr_response_server/databases/spanner.sdl @@ -537,7 +537,8 @@ CREATE TABLE YaraSignatureReferences( CHECK (BYTE_LENGTH(BlobId) = 32), CONSTRAINT fk_yara_signature_reference_creator_username FOREIGN KEY (Creator) - REFERENCES Users(Username), + REFERENCES Users(Username) + ENFORCED ) PRIMARY KEY (BlobId); CREATE TABLE Hunts( diff --git a/grr/server/grr_response_server/databases/spanner_clients.py b/grr/server/grr_response_server/databases/spanner_clients.py index 6958a6a09..a67e9e8d0 100644 --- a/grr/server/grr_response_server/databases/spanner_clients.py +++ b/grr/server/grr_response_server/databases/spanner_clients.py @@ -15,7 +15,7 @@ from grr_response_server.databases import db_utils from grr_response_server.databases import spanner_utils from grr_response_server.models import clients as models_clients -from rrg.proto.rrg import startup_pb2 as rrg_startup_pb2 +from grr_response_proto.rrg import startup_pb2 as rrg_startup_pb2 class ClientsMixin: @@ -287,7 +287,7 @@ def RemoveClientKeyword(self, client_id: str, keyword: str) -> None: -def IntClientID(client_id: str) -> spanner_lib.UInt64: +def IntClientID(client_id: str) -> int: """Converts a client identifier to its integer representation. This function wraps the value in PySpanner's `UInt64` wrapper. It is needed @@ -300,7 +300,7 @@ def IntClientID(client_id: str) -> spanner_lib.UInt64: Returns: An integer representation of the given client identifier. 
""" - return spanner_lib.UInt64(db_utils.ClientIDToInt(client_id)) + return db_utils.ClientIDToInt(client_id) _EPOCH = datetime.datetime.fromtimestamp(0, datetime.timezone.utc) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 657bf6e18..4641827a9 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -18,7 +18,7 @@ from grr_response_server.databases import spanner_clients from grr_response_server.databases import spanner_utils from grr_response_server.models import hunts as models_hunts -from rrg.proto import rrg_pb2 +from grr_response_proto import rrg_pb2 class FlowsMixin: """A Spanner database mixin with implementation of flow methods.""" @@ -70,20 +70,24 @@ def UpdateFlow( self, client_id: str, flow_id: str, - flow_obj: Union[flows_pb2.Flow, _UNCHANGED_TYPE] = _UNCHANGED, + flow_obj: Union[ + flows_pb2.Flow, db.Database.UNCHANGED_TYPE + ] = db.Database.UNCHANGED, flow_state: Union[ - flows_pb2.Flow.FlowState.ValueType, _UNCHANGED_TYPE - ] = _UNCHANGED, + flows_pb2.Flow.FlowState.ValueType, db.Database.UNCHANGED_TYPE + ] = db.Database.UNCHANGED, client_crash_info: Union[ - jobs_pb2.ClientCrash, _UNCHANGED_TYPE - ] = _UNCHANGED, - processing_on: Optional[Union[str, _UNCHANGED_TYPE]] = _UNCHANGED, + jobs_pb2.ClientCrash, db.Database.UNCHANGED_TYPE + ] = db.Database.UNCHANGED, + processing_on: Optional[ + Union[str, db.Database.UNCHANGED_TYPE] + ] = db.Database.UNCHANGED, processing_since: Optional[ - Union[rdfvalue.RDFDatetime, _UNCHANGED_TYPE] - ] = _UNCHANGED, + Union[rdfvalue.RDFDatetime, db.Database.UNCHANGED_TYPE] + ] = db.Database.UNCHANGED, processing_deadline: Optional[ - Union[rdfvalue.RDFDatetime, _UNCHANGED_TYPE] - ] = _UNCHANGED, + Union[rdfvalue.RDFDatetime, db.Database.UNCHANGED_TYPE] + ] = db.Database.UNCHANGED, ) -> None: """Updates flow objects in the database.""" @@ -495,13 +499,6 @@ def ReadMessageHandlerRequests( results = [] return results - def _BuildDeleteMessageHandlerRequestWrites( - self, - txn: spanner_utils.Transaction, - requests: Iterable[objects_pb2.MessageHandlerRequest], - ) -> None: - """Deletes given requests within a given transaction.""" - @db_utils.CallLogged @@ -527,11 +524,6 @@ def UnregisterMessageHandler( """Unregisters any registered message handler.""" - def _ReadHuntState( - self, txn: spanner_utils.Transaction, hunt_id: str - ) -> Optional[int]: - return None - @db_utils.CallLogged @db_utils.CallAccounted def LeaseFlowForProcessing( diff --git a/grr/server/grr_response_server/databases/spanner_hunts.py b/grr/server/grr_response_server/databases/spanner_hunts.py index 0b0df4c6d..f4645c32c 100644 --- a/grr/server/grr_response_server/databases/spanner_hunts.py +++ b/grr/server/grr_response_server/databases/spanner_hunts.py @@ -13,7 +13,7 @@ from grr_response_server.databases import db as abstract_db from grr_response_server.databases import db_utils from grr_response_server.databases import spanner_flows -from rr_response_server.databases import spanner_utils +from grr_response_server.databases import spanner_utils from grr_response_server.models import hunts as models_hunts @@ -73,7 +73,9 @@ def ReadHuntObjects( with_description_match: Optional[str] = None, created_by: Optional[AbstractSet[str]] = None, not_created_by: Optional[AbstractSet[str]] = None, - with_states: Optional[Collection[hunts_pb2.Hunt.HuntState]] = None, + with_states: Optional[ + 
Collection[hunts_pb2.Hunt.HuntState.ValueType] + ] = None, ) -> List[hunts_pb2.Hunt]: """Reads hunt objects from the database.""" diff --git a/grr/server/grr_response_server/databases/spanner_users.py b/grr/server/grr_response_server/databases/spanner_users.py index 6d293b2c8..864e1a7a9 100644 --- a/grr/server/grr_response_server/databases/spanner_users.py +++ b/grr/server/grr_response_server/databases/spanner_users.py @@ -33,7 +33,24 @@ def WriteGRRUser( ui_mode: Optional["user_pb2.GUISettings.UIMode"] = None, ) -> None: """Writes user object for a user with a given name.""" + row = {"Username": username} + if email is not None: + row["Email"] = email + + if password is not None: + row["Password"] = password + + if user_type is not None: + row["Type"] = int(user_type) + + if ui_mode is not None: + row["UiMode"] = int(ui_mode) + + if canary_mode is not None: + row["CanaryMode"] = canary_mode + + self.db.InsertOrUpdate(table="Users", row=row, txn_tag="WriteGRRUser") @db_utils.CallLogged diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 53a33a08d..493309c1a 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -21,6 +21,8 @@ from google.cloud.spanner_admin_database_v1.types import spanner_database_admin from google.cloud.spanner_v1 import Mutation, param_types +from google.rpc.code_pb2 import OK + from grr_response_core.lib.util import collection from grr_response_core.lib.util import iterator @@ -373,7 +375,9 @@ def InsertOrUpdate( columns=columns, values=[values] ) - groups.batch_write() + for response in groups.batch_write(): + if response.status.code != OK: + raise Exception(response.status.message) def Delete( self, table: str, key: Sequence[Any], txn_tag: Optional[str] = None diff --git a/grr/server/grr_response_server/databases/spanner_yara.py b/grr/server/grr_response_server/databases/spanner_yara.py index bfe8d7277..167667595 100644 --- a/grr/server/grr_response_server/databases/spanner_yara.py +++ b/grr/server/grr_response_server/databases/spanner_yara.py @@ -1,4 +1,7 @@ """A module with YARA methods of the Spanner database implementation.""" +import base64 + +from google.api_core.exceptions import NotFound from google.cloud import spanner as spanner_lib from grr_response_server.databases import db @@ -21,9 +24,9 @@ def WriteYaraSignatureReference( ) -> None: """Marks the specified blob id as a YARA signature.""" row = { - "BlobId": bytes(blob_id), + "BlobId": base64.b64encode(bytes(blob_id)), "Creator": username, - "CreationTime": spanner_lib.CommitTimestamp(), + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, } try: @@ -32,7 +35,7 @@ def WriteYaraSignatureReference( row=row, txn_tag="WriteYaraSignatureReference", ) - except NotFound as error: + except Exception as error: if "fk_yara_signature_reference_creator_username" in str(error): raise db.UnknownGRRUserError(username) from error else: @@ -45,10 +48,10 @@ def VerifyYaraSignatureReference( blob_id: models_blobs.BlobID, ) -> bool: """Verifies whether the specified blob is a YARA signature.""" - key = (bytes(blob_id),) + key = (base64.b64encode(bytes(blob_id)),) try: - self.db.Read(table="YaraSignatureReferences", key=key, cols=()) + self.db.Read(table="YaraSignatureReferences", key=key, cols=("BlobId",)) except NotFound: return False From b94a0261940f9abb2ce3ab8c6b3585df45fef41d Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Fri, 16 May 2025 14:54:05 +0000 
Subject: [PATCH 005/168] Adds blob_keys table --- .../databases/spanner_blob_keys.py | 43 ++++++++++++++++++- .../databases/spanner_blob_keys_test.py | 22 ++++++++++ .../databases/spanner_utils.py | 2 + 3 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 grr/server/grr_response_server/databases/spanner_blob_keys_test.py diff --git a/grr/server/grr_response_server/databases/spanner_blob_keys.py b/grr/server/grr_response_server/databases/spanner_blob_keys.py index 3131237fe..9b13a971b 100644 --- a/grr/server/grr_response_server/databases/spanner_blob_keys.py +++ b/grr/server/grr_response_server/databases/spanner_blob_keys.py @@ -1,8 +1,10 @@ #!/usr/bin/env python """Blob encryption key methods of Spanner database implementation.""" +import base64 from typing import Collection, Dict, Optional +from google.cloud import spanner as spanner_lib from grr_response_server.databases import db_utils from grr_response_server.databases import spanner_utils from grr_response_server.models import blobs as models_blobs @@ -22,6 +24,18 @@ def WriteBlobEncryptionKeys( """Associates the specified blobs with the given encryption keys.""" # A special case for empty list of blob identifiers to avoid issues with an # empty mutation. + if not key_names: + return + + def Mutation(mut) -> None: + for blob_id, key_name in key_names.items(): + mut.insert( + table="BlobEncryptionKeys", + columns=("BlobId", "CreationTime", "KeyName"), + values=[((base64.b64encode(bytes(blob_id))), spanner_lib.COMMIT_TIMESTAMP, key_name)] + ) + + self.db.Mutate(Mutation, txn_tag="WriteBlobEncryptionKeys") @db_utils.CallLogged @@ -33,6 +47,33 @@ def ReadBlobEncryptionKeys( """Retrieves encryption keys associated with blobs.""" # A special case for empty list of blob identifiers to avoid syntax errors # in the query below. + if not blob_ids: + return {} + + param_placeholders = ", ".join([f"{{blobId{i}}}" for i in range(len(blob_ids))]) + + params = {} + for i, blob_id_bytes in enumerate(blob_ids): + param_name = f"blobId{i}" + params[param_name] = base64.b64encode(bytes(blob_id_bytes)) + + query = f""" + SELECT k.BlobId, k.KeyName + FROM BlobEncryptionKeys AS k + INNER JOIN (SELECT k.BlobId, MAX(k.CreationTime) AS MaxCreationTime + FROM BlobEncryptionKeys AS k + WHERE k.BlobId IN ({param_placeholders}) + GROUP BY k.BlobId) AS last_k + ON k.BlobId = last_k.BlobId + AND k.CreationTime = last_k.MaxCreationTime + """ + + results = {blob_id: None for blob_id in blob_ids} + for blob_id_bytes, key_name in self.db.ParamQuery( + query, params, txn_tag="ReadBlobEncryptionKeys" + ): + blob_id = models_blobs.BlobID(base64.b64decode(blob_id_bytes)) + results[blob_id] = key_name - return {} + return results diff --git a/grr/server/grr_response_server/databases/spanner_blob_keys_test.py b/grr/server/grr_response_server/databases/spanner_blob_keys_test.py new file mode 100644 index 000000000..65935bf85 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_blob_keys_test.py @@ -0,0 +1,22 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_blob_keys_test_lib +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseBlobKeysTest( + db_blob_keys_test_lib.DatabaseTestBlobKeysMixin, spanner_test_lib.TestCase +): + pass # Test methods are defined in the base mixin class. 
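# Illustrative usage sketch, not part of this patch: the blob encryption key
# methods exercised by the test mixin above form a simple round trip. `db`
# stands for any database object exposing the abstract interface; the key
# names are invented for the example.
#
#   import os
#   from grr_response_server.models import blobs as models_blobs
#
#   known = models_blobs.BlobID(os.urandom(32))
#   unknown = models_blobs.BlobID(os.urandom(32))
#
#   db.WriteBlobEncryptionKeys({known: "key-v1"})
#   db.WriteBlobEncryptionKeys({known: "key-v2"})
#
#   keys = db.ReadBlobEncryptionKeys([known, unknown])
#   assert keys[known] == "key-v2"  # MAX(CreationTime) picks the latest key.
#   assert keys[unknown] is None    # Ids that were never written map to None.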
+ + +if __name__ == "__main__": + absltest.main() \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 493309c1a..1dfc197f9 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -2,7 +2,9 @@ import contextlib import datetime +import decimal import re + from typing import Any from typing import Callable from typing import Generic From 3ad2f76adef31e7cfec0e2c6b683effb6b6f6d2a Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Mon, 19 May 2025 11:31:55 +0000 Subject: [PATCH 006/168] Adds Artifacts table --- .../databases/spanner_artifacts.py | 44 ++++++++++++++++++- .../databases/spanner_artifacts_test.py | 22 ++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 grr/server/grr_response_server/databases/spanner_artifacts_test.py diff --git a/grr/server/grr_response_server/databases/spanner_artifacts.py b/grr/server/grr_response_server/databases/spanner_artifacts.py index 9692a05ef..ff37eeb4f 100644 --- a/grr/server/grr_response_server/databases/spanner_artifacts.py +++ b/grr/server/grr_response_server/databases/spanner_artifacts.py @@ -3,6 +3,9 @@ from typing import Optional, Sequence +from google.api_core.exceptions import AlreadyExists, NotFound +from google.cloud import spanner as spanner_lib + from grr_response_proto import artifact_pb2 from grr_response_server.databases import db from grr_response_server.databases import db_utils @@ -25,6 +28,16 @@ def WriteArtifact(self, artifact: artifact_pb2.Artifact) -> None: Raises: DuplicatedArtifactError: when the artifact already exists. """ + name = str(artifact.name) + row = { + "Name": name, + "Platforms": list(artifact.supported_os), + "Payload": artifact, + } + try: + self.db.Insert(table="Artifacts", row=row, txn_tag="WriteArtifact") + except AlreadyExists as error: + raise db.DuplicatedArtifactError(name) from error @db_utils.CallLogged @@ -41,8 +54,15 @@ def ReadArtifact(self, name: str) -> Optional[artifact_pb2.Artifact]: Raises: UnknownArtifactError: when the artifact does not exist. """ + try: + row = self.db.Read("Artifacts", key=[name], cols=("Platforms", "Payload")) + except NotFound as error: + raise db.UnknownArtifactError(name) from error - return None + artifact = artifact_pb2.Artifact.FromString(row[1]) + artifact.name = name + artifact.supported_os[:] = row[0] + return artifact @db_utils.CallLogged @db_utils.CallAccounted @@ -50,6 +70,15 @@ def ReadAllArtifacts(self) -> Sequence[artifact_pb2.Artifact]: """Lists all artifacts that are stored in the database.""" result = [] + query = """ + SELECT a.Name, a.Platforms, a.Payload + FROM Artifacts AS a + """ + for [name, supported_os, payload] in self.db.Query(query): + artifact = artifact_pb2.Artifact.FromString(payload) + artifact.name = name + artifact.supported_os[:] = supported_os + result.append(artifact) return result @@ -64,4 +93,17 @@ def DeleteArtifact(self, name: str) -> None: Raises: UnknownArtifactError when the artifact does not exist. """ + def Transaction(txn) -> None: + # Spanner does not raise if we attept to delete a non-existing row so + # we check it exists ourselves. 
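      # The existence check and the delete below run inside one read-write
      # transaction (see self.db.Transact at the end of this method), so the
      # check cannot race with a concurrent delete; txn.read(...).one() raises
      # NotFound for a missing key, which is mapped to UnknownArtifactError.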
+ keyset = spanner_lib.KeySet(keys=[[name],]) + + try: + txn.read(table="Artifacts", columns=("Name",), keyset=keyset).one() + except NotFound as error: + raise db.UnknownArtifactError(name) from error + + txn.delete("Artifacts", keyset) + + self.db.Transact(Transaction, txn_tag="DeleteArtifact") diff --git a/grr/server/grr_response_server/databases/spanner_artifacts_test.py b/grr/server/grr_response_server/databases/spanner_artifacts_test.py new file mode 100644 index 000000000..12a6a9e4e --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_artifacts_test.py @@ -0,0 +1,22 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_artifacts_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseArtifactsTest( + db_artifacts_test.DatabaseTestArtifactsMixin, spanner_test_lib.TestCase +): + pass # Test methods are defined in the base mixin class. + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file From b0573a980170632a305c15383c086dbaf68458d9 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Tue, 20 May 2025 08:37:32 +0000 Subject: [PATCH 007/168] Adds HashBlobReferences table --- .../databases/db_blob_references_test.py | 5 ++- .../databases/spanner_blob_references.py | 45 ++++++++++++++++++- .../databases/spanner_blob_references_test.py | 23 ++++++++++ 3 files changed, 71 insertions(+), 2 deletions(-) create mode 100644 grr/server/grr_response_server/databases/spanner_blob_references_test.py diff --git a/grr/server/grr_response_server/databases/db_blob_references_test.py b/grr/server/grr_response_server/databases/db_blob_references_test.py index babf33c0d..20f85fed6 100644 --- a/grr/server/grr_response_server/databases/db_blob_references_test.py +++ b/grr/server/grr_response_server/databases/db_blob_references_test.py @@ -82,7 +82,10 @@ def testMultipleHashBlobReferencesCanBeWrittenAndReadBack(self): def testWriteHashBlobHandlesLargeAmountsOfData(self): hash_id_blob_refs = {} - for _ in range(50000): + # Limit to 16k records to stay within Spanner 80k mutation/commit limit + # https://cloud.google.com/spanner/quotas#limits-for + # 16k records * 5 columns = 80k mutations/commit + for _ in range(16000): hash_id = rdf_objects.SHA256HashID(os.urandom(32)) blob_ref = objects_pb2.BlobReference() diff --git a/grr/server/grr_response_server/databases/spanner_blob_references.py b/grr/server/grr_response_server/databases/spanner_blob_references.py index 447a5bf5d..6fb1aa422 100644 --- a/grr/server/grr_response_server/databases/spanner_blob_references.py +++ b/grr/server/grr_response_server/databases/spanner_blob_references.py @@ -1,8 +1,11 @@ #!/usr/bin/env python """A library with blob references methods of Spanner database implementation.""" +import base64 from typing import Collection, Mapping, Optional +from google.cloud import spanner as spanner_lib + from grr_response_proto import objects_pb2 from grr_response_server.databases import db_utils from grr_response_server.databases import spanner_utils @@ -23,6 +26,22 @@ def WriteHashBlobReferences( ], ) -> None: """Writes blob references for a given set of hashes.""" + def Mutation(mut) -> None: + for hash_id, refs in references_by_hash.items(): + hash_id_b64 = base64.b64encode(bytes(hash_id.AsBytes())) + key_range = spanner_lib.KeyRange(start_closed=[hash_id_b64,], 
end_closed=[hash_id_b64,]) + keyset = spanner_lib.KeySet(ranges=[key_range]) + # Make sure we delete any of the previously existing blob references. + mut.delete("HashBlobReferences", keyset) + + for ref in refs: + mut.insert( + table="HashBlobReferences", + columns=("HashId", "BlobId", "Offset", "Size",), + values=[(hash_id_b64, base64.b64encode(bytes(ref.blob_id)), ref.offset, ref.size,)], + ) + + self.db.Mutate(Mutation, txn_tag="WriteHashBlobReferences") @db_utils.CallLogged @@ -33,7 +52,31 @@ def ReadHashBlobReferences( rdf_objects.SHA256HashID, Optional[Collection[objects_pb2.BlobReference]] ]: """Reads blob references of a given set of hashes.""" - result = {} + key_ranges = [] + + for h in hashes: + hash_id_b64 = base64.b64encode(bytes(h.AsBytes())) + key_ranges.append(spanner_lib.KeyRange(start_closed=[hash_id_b64,], end_closed=[hash_id_b64,])) + result[h] = [] + + rows = spanner_lib.KeySet(ranges=key_ranges) + + hashes_left = set(hashes) + for row in self.db.ReadSet( + table="HashBlobReferences", rows=rows, cols=("HashId", "BlobId", "Offset", "Size") + ): + hash_id = rdf_objects.SHA256HashID(base64.b64decode(row[0])) + + blob_ref = objects_pb2.BlobReference() + blob_ref.blob_id = base64.b64decode(row[1]) + blob_ref.offset = row[2] + blob_ref.size = row[3] + + result[hash_id].append(blob_ref) + hashes_left.discard(hash_id) + + for h in hashes_left: + result[h] = None return result diff --git a/grr/server/grr_response_server/databases/spanner_blob_references_test.py b/grr/server/grr_response_server/databases/spanner_blob_references_test.py new file mode 100644 index 000000000..c4130afef --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_blob_references_test.py @@ -0,0 +1,23 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_blob_references_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseBlobReferencesTest( + db_blob_references_test.DatabaseTestBlobReferencesMixin, + spanner_test_lib.TestCase, +): + pass # Test methods are defined in the base mixin class. 
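# Illustrative usage sketch, not part of this patch: WriteHashBlobReferences
# replaces whatever was previously stored for a hash (note the delete issued
# before the inserts in the mixin above), and hashes that were never written
# read back as None. All names below are invented for the example; `db` is any
# database object exposing the abstract interface.
#
#   import os
#   from grr_response_proto import objects_pb2
#   from grr_response_server.rdfvalues import objects as rdf_objects
#
#   written = rdf_objects.SHA256HashID(os.urandom(32))
#   missing = rdf_objects.SHA256HashID(os.urandom(32))
#   ref_a = objects_pb2.BlobReference(offset=0, size=42, blob_id=os.urandom(32))
#   ref_b = objects_pb2.BlobReference(offset=42, size=7, blob_id=os.urandom(32))
#
#   db.WriteHashBlobReferences({written: [ref_a]})
#   db.WriteHashBlobReferences({written: [ref_a, ref_b]})  # Replaces, not appends.
#
#   refs = db.ReadHashBlobReferences([written, missing])
#   assert len(refs[written]) == 2  # ref_a and ref_b; order is not guaranteed.
#   assert refs[missing] is None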
+ + +if __name__ == "__main__": + absltest.main() \ No newline at end of file From b7e3f4f033f014d921efd541a57d81113a38751e Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Tue, 20 May 2025 11:53:00 +0000 Subject: [PATCH 008/168] Adds SignedBinaries table --- .../databases/spanner_artifacts.py | 3 +- .../databases/spanner_foreman_rules.py | 33 +++++++++++- .../databases/spanner_foreman_rules_test.py | 23 +++++++++ .../databases/spanner_signed_binaries.py | 51 ++++++++++++++++++- .../databases/spanner_signed_binaries_test.py | 23 +++++++++ 5 files changed, 128 insertions(+), 5 deletions(-) create mode 100644 grr/server/grr_response_server/databases/spanner_foreman_rules_test.py create mode 100644 grr/server/grr_response_server/databases/spanner_signed_binaries_test.py diff --git a/grr/server/grr_response_server/databases/spanner_artifacts.py b/grr/server/grr_response_server/databases/spanner_artifacts.py index ff37eeb4f..fd3906319 100644 --- a/grr/server/grr_response_server/databases/spanner_artifacts.py +++ b/grr/server/grr_response_server/databases/spanner_artifacts.py @@ -105,5 +105,4 @@ def Transaction(txn) -> None: txn.delete("Artifacts", keyset) - self.db.Transact(Transaction, txn_tag="DeleteArtifact") - + self.db.Transact(Transaction, txn_tag="DeleteArtifact") \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_foreman_rules.py b/grr/server/grr_response_server/databases/spanner_foreman_rules.py index a43e3825a..f21f96eac 100644 --- a/grr/server/grr_response_server/databases/spanner_foreman_rules.py +++ b/grr/server/grr_response_server/databases/spanner_foreman_rules.py @@ -18,12 +18,28 @@ class ForemanRulesMixin: @db_utils.CallAccounted def WriteForemanRule(self, rule: jobs_pb2.ForemanCondition) -> None: """Writes a foreman rule to the database.""" - + hunt_id_int = db_utils.HuntIDToInt(rule.hunt_id) + row = { + "HuntId": hunt_id_int, + "ExpirationTime": ( + rdfvalue.RDFDatetime() + .FromMicrosecondsSinceEpoch(rule.expiration_time) + .AsDatetime() + ), + "Payload": rule, + } + self.db.InsertOrUpdate( + table="ForemanRules", row=row, txn_tag="WriteForemanRule" + ) @db_utils.CallLogged @db_utils.CallAccounted def RemoveForemanRule(self, hunt_id: str) -> None: """Removes a foreman rule from the database.""" + hunt_id_int = db_utils.HuntIDToInt(hunt_id) + self.db.Delete( + table="ForemanRules", key=(hunt_id_int), txn_tag="RemoveForemanRule" + ) @db_utils.CallLogged @@ -32,10 +48,25 @@ def ReadAllForemanRules(self) -> Sequence[jobs_pb2.ForemanCondition]: """Reads all foreman rules from the database.""" result = [] + query = """ + SELECT fr.Payload + FROM ForemanRules AS fr + """ + for [payload] in self.db.Query(query, txn_tag="ReadAllForemanRules"): + rule = jobs_pb2.ForemanCondition() + rule.ParseFromString(payload) + result.append(rule) + return result @db_utils.CallLogged @db_utils.CallAccounted def RemoveExpiredForemanRules(self) -> None: """Removes all expired foreman rules from the database.""" + query = """ + DELETE + FROM ForemanRules@{{FORCE_INDEX=ForemanRulesByExpirationTime}} AS fr + WHERE fr.ExpirationTime < CURRENT_TIMESTAMP() + """ + self.db.ParamExecute(query, {}, txn_tag="RemoveExpiredForemanRules") diff --git a/grr/server/grr_response_server/databases/spanner_foreman_rules_test.py b/grr/server/grr_response_server/databases/spanner_foreman_rules_test.py new file mode 100644 index 000000000..8c6153fc0 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_foreman_rules_test.py @@ -0,0 +1,23 @@ +from absl.testing import 
absltest + +from grr_response_server.databases import db_foreman_rules_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseForemanRulesTest( + db_foreman_rules_test.DatabaseTestForemanRulesMixin, + spanner_test_lib.TestCase, +): + pass # Test methods are defined in the base mixin class. + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_signed_binaries.py b/grr/server/grr_response_server/databases/spanner_signed_binaries.py index d697e4449..1deb02075 100644 --- a/grr/server/grr_response_server/databases/spanner_signed_binaries.py +++ b/grr/server/grr_response_server/databases/spanner_signed_binaries.py @@ -3,6 +3,9 @@ from typing import Sequence, Tuple +from google.api_core.exceptions import NotFound +from google.cloud import spanner as spanner_lib + from grr_response_core.lib import rdfvalue from grr_response_core.lib.util import iterator from grr_response_proto import objects_pb2 @@ -29,6 +32,15 @@ def WriteSignedBinaryReferences( binary_id: Signed binary id for the binary. references: Blob references for the given binary. """ + row = { + "Type": int(binary_id.binary_type), + "Path": binary_id.path, + "BlobReferences": references, + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, + } + self.db.InsertOrUpdate( + table="SignedBinaries", row=row, txn_tag="WriteSignedBinaryReferences" + ) @db_utils.CallLogged @@ -46,8 +58,23 @@ def ReadSignedBinaryReferences( RDFDatetime representing the time when the references were written to the DB. """ + binary_type = int(binary_id.binary_type) + + try: + row = self.db.Read( + table="SignedBinaries", + key=(binary_type, binary_id.path), + cols=("BlobReferences", "CreationTime"), + ) + except NotFound as error: + raise db.UnknownSignedBinaryError(binary_id) from error - return None, None + raw_references = row[0] + creation_time = row[1] + + references = objects_pb2.BlobReferences() + references.ParseFromString(raw_references) + return references, rdfvalue.RDFDatetime.FromDatetime(creation_time) @db_utils.CallLogged @db_utils.CallAccounted @@ -55,6 +82,19 @@ def ReadIDsForAllSignedBinaries(self) -> Sequence[objects_pb2.SignedBinaryID]: """Returns ids for all signed binaries in the DB.""" results = [] + query = """ + SELECT sb.Type, sb.Path + FROM SignedBinaries as sb + """ + + for [binary_type, binary_path] in self.db.Query( + query, txn_tag="ReadIDsForAllSignedBinaries" + ): + binary_id = objects_pb2.SignedBinaryID( + binary_type=binary_type, path=binary_path + ) + results.append(binary_id) + return results @db_utils.CallLogged @@ -70,4 +110,11 @@ def DeleteSignedBinaryReferences( Args: binary_id: An id of the signed binary to delete. 
""" - + def Mutation(mut: spanner_utils.Mutation) -> None: + mut.delete("SignedBinaries", spanner_lib.KeySet(keys=[[binary_id.binary_type, binary_id.path]]) + ) + + try: + self.db.Mutate(Mutation, txn_tag="DeleteSignedBinaryReferences") + except NotFound as error: + raise db.UnknownSignedBinaryError(binary_id) from error \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_signed_binaries_test.py b/grr/server/grr_response_server/databases/spanner_signed_binaries_test.py new file mode 100644 index 000000000..871918ff2 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_signed_binaries_test.py @@ -0,0 +1,23 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_signed_binaries_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseSignedBinariesTest( + db_signed_binaries_test.DatabaseTestSignedBinariesMixin, + spanner_test_lib.TestCase, +): + pass # Test methods are defined in the base mixin class. + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file From d21378d5cf4665952887f3af4db918a15a2dfa58 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Tue, 20 May 2025 15:53:39 +0000 Subject: [PATCH 009/168] Adds SignedCommands table --- .../grr_response_server/databases/spanner.sdl | 2 +- .../databases/spanner_signed_commands.py | 89 ++++++++++++++++++- .../databases/spanner_signed_commands_test.py | 23 +++++ 3 files changed, 111 insertions(+), 3 deletions(-) create mode 100644 grr/server/grr_response_server/databases/spanner_signed_commands_test.py diff --git a/grr/server/grr_response_server/databases/spanner.sdl b/grr/server/grr_response_server/databases/spanner.sdl index 469b9c158..7d57e529b 100644 --- a/grr/server/grr_response_server/databases/spanner.sdl +++ b/grr/server/grr_response_server/databases/spanner.sdl @@ -667,5 +667,5 @@ CREATE TABLE SignedCommands( Id STRING(128) NOT NULL, OperatingSystem `grr.SignedCommand.OS` NOT NULL, Ed25519Signature BYTES(64) NOT NULL, - Command `rrg.action.execute_signed_command.Command` NOT NULL, + Command BYTES(MAX) NOT NULL, ) PRIMARY KEY (Id, OperatingSystem); diff --git a/grr/server/grr_response_server/databases/spanner_signed_commands.py b/grr/server/grr_response_server/databases/spanner_signed_commands.py index 57804b21d..bd1c7a06a 100644 --- a/grr/server/grr_response_server/databases/spanner_signed_commands.py +++ b/grr/server/grr_response_server/databases/spanner_signed_commands.py @@ -1,8 +1,12 @@ #!/usr/bin/env python """A module with signed command methods of the Spanner database implementation.""" +import base64 from typing import Sequence +from google.api_core.exceptions import AlreadyExists, InvalidArgument, NotFound +from google.cloud import spanner as spanner_lib + from grr_response_core.lib.util import iterator from grr_response_proto import signed_commands_pb2 from grr_response_server.databases import db @@ -22,6 +26,25 @@ def WriteSignedCommands( signed_commands: Sequence[signed_commands_pb2.SignedCommand], ) -> None: """Writes a signed command to the database.""" + def Mutation(mut) -> None: + for signed_command in signed_commands: + mut.insert( + table="SignedCommands", + columns=("Id", "OperatingSystem", "Ed25519Signature", "Command"), + values=[( + signed_command.id, + signed_command.operating_system, + 
base64.b64encode(bytes(signed_command.ed25519_signature)), + base64.b64encode(bytes(signed_command.command)) + )] + ) + + try: + self.db.Mutate(Mutation, txn_tag="WriteSignedCommand") + except AlreadyExists as e: + raise db.AtLeastOneDuplicatedSignedCommandError(signed_commands) from e + except InvalidArgument as e: + raise db.AtLeastOneDuplicatedSignedCommandError(signed_commands) from e @db_utils.CallLogged @@ -32,8 +55,35 @@ def ReadSignedCommand( operating_system: signed_commands_pb2.SignedCommand.OS, ) -> signed_commands_pb2.SignedCommand: """Reads signed command from the database.""" + params = {} + query = """ + SELECT + c.Id, c.Ed25519Signature, c.Command + FROM SignedCommands AS c + WHERE + c.Id = {id} + AND + c.OperatingSystem = {operating_system} + """ + params["id"] = id_ + params["operating_system"] = operating_system + + try: + ( + id_, + ed25519_signature, + command_bytes, + ) = self.db.ParamQuerySingle(query, params, txn_tag="ReadSignedCommand") + except NotFound as ex: + raise db.NotFoundError() from ex + + signed_command = signed_commands_pb2.SignedCommand() + signed_command.id = id_ + signed_command.operating_system = operating_system + signed_command.ed25519_signature = base64.b64decode(ed25519_signature) + signed_command.command = base64.b64decode(command_bytes) - return None + return signed_command @db_utils.CallLogged @db_utils.CallAccounted @@ -41,8 +91,35 @@ def ReadSignedCommands( self, ) -> Sequence[signed_commands_pb2.SignedCommand]: """Reads signed command from the database.""" + query = """ + SELECT + c.Id, c.OperatingSystem, c.Ed25519Signature, c.Command + FROM + SignedCommands AS c + """ + query = """ + SELECT + c.Id, c.OperatingSystem, c.Ed25519Signature, c.Command + FROM + SignedCommands AS c + """ + signed_commands = [] + for ( + command_id, + operating_system, + signature, + command_bytes, + ) in self.db.Query(query, txn_tag="ReadSignedCommand"): - return None + signed_command = signed_commands_pb2.SignedCommand() + signed_command.id = command_id + signed_command.operating_system = operating_system + signed_command.ed25519_signature = base64.b64decode(signature) + signed_command.command = base64.b64decode(command_bytes) + + signed_commands.append(signed_command) + + return signed_commands @db_utils.CallLogged @db_utils.CallAccounted @@ -50,4 +127,12 @@ def DeleteAllSignedCommands( self, ) -> None: """Deletes all signed command from the database.""" + to_delete = self.ReadSignedCommands() + if not to_delete: + return + + def Mutation(mut) -> None: + for command in to_delete: + mut.delete("SignedCommands", spanner_lib.KeySet(keys=[[command.id, int(command.operating_system)]])) + self.db.Mutate(Mutation, txn_tag="DeleteAllSignedCommands") \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_signed_commands_test.py b/grr/server/grr_response_server/databases/spanner_signed_commands_test.py new file mode 100644 index 000000000..15ff6a872 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_signed_commands_test.py @@ -0,0 +1,23 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_signed_commands_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseSignedCommandsTest( + db_signed_commands_test.DatabaseTestSignedCommandsMixin, + spanner_test_lib.TestCase, +): + """Spanner signed 
commands tests.""" + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file From 5d76dca0b917763e1c8e4d90579f1f24cdb7e6e5 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Wed, 21 May 2025 14:10:55 +0000 Subject: [PATCH 010/168] Adds ApiAuditEntry table --- .../databases/spanner_events.py | 89 +++++++++++++++++++ .../databases/spanner_events_test.py | 22 +++++ 2 files changed, 111 insertions(+) create mode 100644 grr/server/grr_response_server/databases/spanner_events_test.py diff --git a/grr/server/grr_response_server/databases/spanner_events.py b/grr/server/grr_response_server/databases/spanner_events.py index 843d04166..6975ddb30 100644 --- a/grr/server/grr_response_server/databases/spanner_events.py +++ b/grr/server/grr_response_server/databases/spanner_events.py @@ -3,6 +3,8 @@ from typing import Dict, List, Optional, Tuple +from google.cloud import spanner as spanner_lib + from grr_response_core.lib import rdfvalue from grr_response_proto import objects_pb2 from grr_response_server.databases import db_utils @@ -34,7 +36,50 @@ def ReadAPIAuditEntries( FROM ApiAuditEntry AS a """ + params = {} + conditions = [] + + if username is not None: + conditions.append("a.Username = {username}") + params["username"] = username + + if router_method_names: + param_placeholders = ", ".join([f"{{rmn{i}}}" for i in range(len(router_method_names))]) + for i, rmn in enumerate(router_method_names): + param_name = f"rmn{i}" + params[param_name] = rmn + conditions.append(f"""a.RouterMethodName IN ({param_placeholders})""") + + if min_timestamp is not None: + conditions.append("a.CreationTime >= {min_timestamp}") + params["min_timestamp"] = min_timestamp.AsDatetime() + + if max_timestamp is not None: + conditions.append("a.CreationTime <= {max_timestamp}") + params["max_timestamp"] = max_timestamp.AsDatetime() + + if conditions: + query += " WHERE " + " AND ".join(conditions) + result = [] + for ( + username, + ts, + http_request_path, + router_method_name, + response_code, + ) in self.db.ParamQuery(query, params, txn_tag="ReadAPIAuditEntries"): + result.append( + objects_pb2.APIAuditEntry( + username=username, + timestamp=rdfvalue.RDFDatetime.FromDatetime( + ts + ).AsMicrosecondsSinceEpoch(), + http_request_path=http_request_path, + router_method_name=router_method_name, + response_code=response_code, + ) + ) return result @@ -46,8 +91,35 @@ def CountAPIAuditEntriesByUserAndDay( max_timestamp: Optional[rdfvalue.RDFDatetime] = None, ) -> Dict[Tuple[str, rdfvalue.RDFDatetime], int]: """Returns audit entry counts grouped by user and calendar day.""" + query = """ + SELECT + a.Username, + TIMESTAMP_TRUNC(a.CreationTime, DAY, "UTC") AS day, + COUNT(*) + FROM APIAuditEntry AS a + """ + + params = {} + conditions = [] + + if min_timestamp is not None: + conditions.append("a.CreationTime >= {min_timestamp}") + params["min_timestamp"] = min_timestamp.AsDatetime() + + if max_timestamp is not None: + conditions.append("a.CreationTime <= {max_timestamp}") + params["max_timestamp"] = max_timestamp.AsDatetime() + + if conditions: + query += " WHERE " + " AND ".join(conditions) + + query += " GROUP BY a.Username, day" result = {} + for username, day, count in self.db.ParamQuery( + query, params, txn_tag="CountAPIAuditEntriesByUserAndDay" + ): + result[(username, rdfvalue.RDFDatetime.FromDatetime(day))] = count return result @@ -55,3 +127,20 @@ def CountAPIAuditEntriesByUserAndDay( @db_utils.CallAccounted def WriteAPIAuditEntry(self, entry: objects_pb2.APIAuditEntry): """Writes an audit entry to 
the database.""" + row = { + "HttpRequestPath": entry.http_request_path, + "RouterMethodName": entry.router_method_name, + "Username": entry.username, + "ResponseCode": entry.response_code, + } + + if not entry.HasField("timestamp"): + row["CreationTime"] = spanner_lib.COMMIT_TIMESTAMP + else: + row["CreationTime"] = rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + entry.timestamp + ).AsDatetime() + + self.db.InsertOrUpdate( + table="ApiAuditEntry", row=row, txn_tag="WriteAPIAuditEntry" + ) diff --git a/grr/server/grr_response_server/databases/spanner_events_test.py b/grr/server/grr_response_server/databases/spanner_events_test.py new file mode 100644 index 000000000..44d5cd10c --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_events_test.py @@ -0,0 +1,22 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_events_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseEventsTest( + db_events_test.DatabaseTestEventsMixin, spanner_test_lib.TestCase +): + pass # Test methods are defined in the base mixin class. + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file From 67d2a518c0abb4d0ab8e2c78a7453da63dafa489 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Wed, 28 May 2025 17:00:49 +0000 Subject: [PATCH 011/168] EoD 20250528 --- .../grr_response_server/databases/spanner.py | 19 +- .../databases/spanner_flows.py | 2054 ++++++++++++++++- .../databases/spanner_flows_large_test.py | 53 + .../databases/spanner_flows_test.py | 29 + .../databases/spanner_users.py | 405 +++- .../databases/spanner_users_test.py | 22 + .../databases/spanner_utils.py | 52 +- 7 files changed, 2570 insertions(+), 64 deletions(-) create mode 100644 grr/server/grr_response_server/databases/spanner_flows_large_test.py create mode 100644 grr/server/grr_response_server/databases/spanner_flows_test.py create mode 100644 grr/server/grr_response_server/databases/spanner_users_test.py diff --git a/grr/server/grr_response_server/databases/spanner.py b/grr/server/grr_response_server/databases/spanner.py index b8b7fa97e..a2265304f 100644 --- a/grr/server/grr_response_server/databases/spanner.py +++ b/grr/server/grr_response_server/databases/spanner.py @@ -1,8 +1,8 @@ # Imports the Google Cloud Client Library. -from google.cloud import spanner +from google.cloud.spanner import Client from grr_response_core.lib import rdfvalue -from grr_response_server.databases import db as abstract_db +from grr_response_server.databases import db as db_module from grr_response_server.databases import spanner_artifacts from grr_response_server.databases import spanner_blob_keys from grr_response_server.databases import spanner_blob_references @@ -36,7 +36,7 @@ class SpannerDB( spanner_signed_binaries.SignedBinariesMixin, spanner_users.UsersMixin, spanner_yara.YaraMixin, - abstract_db.Database, + db_module.Database, ): """A Spanner implementation of the GRR database.""" @@ -45,6 +45,19 @@ def __init__(self, db: spanner_utils.Database) -> None: self.db = db self._write_rows_batch_size = 10000 + @classmethod + def FromConfig(cls) -> "Database": + """Creates a GRR database instance for Spanner path specified in the config. + + Returns: + A GRR database instance. 
+ """ + spanner_client = Client(onfig.CONFIG["ProjectID"]) + spanner_instance = spanner_client.instance(config.CONFIG["Spanner.instance"]) + spanner_database = spanner_instance.database(config.CONFIG["Spanner.database"]) + + return cls(spanner_utils.Database(spanner_database)) + def Now(self) -> rdfvalue.RDFDatetime: """Retrieves current time as reported by the database.""" (timestamp,) = self.db.QuerySingle("SELECT CURRENT_TIMESTAMP()") diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 4641827a9..45019fe5c 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -6,6 +6,9 @@ import logging from typing import Any, Callable, Collection, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union +from google.api_core.exceptions import AlreadyExists, NotFound +from google.cloud import spanner as spanner_lib + from grr_response_core.lib import rdfvalue from grr_response_core.lib import utils from grr_response_core.lib.util import collection @@ -20,12 +23,230 @@ from grr_response_server.models import hunts as models_hunts from grr_response_proto import rrg_pb2 + +SPANNER_DELETE_FLOW_REQUESTS_FAILURES = metrics.Counter( + name="spanner_delete_flow_requests_failures" +) + + +_MESSAGE_HANDLER_MAX_KEEPALIVE_SECONDS = 300 +_MESSAGE_HANDLER_MAX_ACTIVE_CALLBACKS = 20 + + +@dataclasses.dataclass(frozen=True) +class _FlowKey: + """Unique key identifying a flow in helper methods.""" + + client_id: str + flow_id: str + + +@dataclasses.dataclass(frozen=True) +class _RequestKey: + """Unique key identifying a flow request in helper methods.""" + + client_id: str + flow_id: str + request_id: int + + +@dataclasses.dataclass(frozen=True) +class _ResponseKey: + """Unique key identifying a flow response in helper methods.""" + + client_id: str + flow_id: str + request_id: int + response_id: int + + +_UNCHANGED = db.Database.UNCHANGED +_UNCHANGED_TYPE = db.Database.UNCHANGED_TYPE + + +def _BuildReadFlowResultsErrorsConditions( + table_name: str, + client_id: str, + flow_id: str, + offset: int, + count: int, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + with_substring: Optional[str] = None, +) -> tuple[str, Mapping[str, Any]]: + """Builds query string and params for results/errors reading queries.""" + params = {} + + query = f""" + SELECT t.Payload, t.RdfType, t.CreationTime, t.Tag, t.HuntId + FROM {table_name} AS t + """ + + query += """ + WHERE t.ClientId = {client_id} AND t.FlowId = {flow_id} + """ + + params["client_id"] = spanner_clients.IntClientID(client_id) + params["flow_id"] = IntFlowID(flow_id) + + if with_tag is not None: + query += " AND t.Tag = {tag} " + params["tag"] = with_tag + + if with_type is not None: + query += " AND t.RdfType = {type}" + params["type"] = with_type + + if with_substring is not None: + query += """ + AND STRPOS(SAFE_CONVERT_BYTES_TO_STRING(t.Payload.value), {substring}) != 0 + """ + params["substring"] = with_substring + + query += """ + ORDER BY t.CreationTime ASC LIMIT {count} OFFSET {offset} + """ + params["offset"] = offset + params["count"] = count + + return query, params + + +def _BuildCountFlowResultsErrorsConditions( + table_name: str, + client_id: str, + flow_id: str, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, +) -> tuple[str, Mapping[str, Any]]: + """Builds query string and params for count flow results/errors queries.""" + params = {} + + 
query = f""" + SELECT COUNT(*) + FROM {table_name} AS t + """ + + query += """ + WHERE t.ClientId = {client_id} AND t.FlowId = {flow_id} + """ + + params["client_id"] = spanner_clients.IntClientID(client_id) + params["flow_id"] = IntFlowID(flow_id) + + if with_tag is not None: + query += " AND t.Tag = {tag} " + params["tag"] = with_tag + + if with_type is not None: + query += " AND t.RdfType = {type}" + params["type"] = with_type + + return query, params + + +_READ_FLOW_OBJECT_COLS = ( + "LongFlowId", + "ParentFlowId", + "ParentHuntId", + "Creator", + "Name", + "State", + "CreationTime", + "UpdateTime", + "Crash", + "ProcessingWorker", + "ProcessingStartTime", + "ProcessingEndTime", + "NextRequestToProcess", + "Flow", +) + + +def _ParseReadFlowObjectRow( + client_id: str, + flow_id: str, + row: Mapping[str, Any], +) -> flows_pb2.Flow: + """Parses a row fetched with _READ_FLOW_OBJECT_COLS.""" + result = flows_pb2.Flow() + result.ParseFromString(row["Flow"]) + + creation_time = rdfvalue.RDFDatetime.FromDatetime(row["CreationTime"]) + update_time = rdfvalue.RDFDatetime.FromDatetime(row["UpdateTime"]) + + # We treat column values as the source of truth for values, not the message + # in the database itself. At least this is what the F1 implementation does. + result.client_id = client_id + result.flow_id = flow_id + result.long_flow_id = row["LongFlowId"] + + if row["ParentFlowId"] is not None: + result.parent_flow_id = db_utils.IntToFlowID(row["ParentFlowId"]) + if row["ParentHuntId"] is not None: + result.parent_hunt_id = db_utils.IntToHuntID(row["ParentHuntId"]) + + if row["Name"] is not None: + result.flow_class_name = row["Name"] + if row["Creator"] is not None: + result.creator = row["Creator"] + if row["State"] not in [None, flows_pb2.Flow.FlowState.UNSET]: + result.flow_state = row["State"] + if row["NextRequestToProcess"]: + result.next_request_to_process = row["NextRequestToProcess"] + + result.create_time = int(creation_time) + result.last_update_time = int(update_time) + + if row["Crash"] is not None: + client_crash = jobs_pb2.ClientCrash() + client_crash.ParseFromString(row["Crash"]) + result.client_crash_info.CopyFrom(client_crash) + + result.ClearField("processing_on") + if row["ProcessingWorker"] is not None: + result.processing_on = row["ProcessingWorker"] + result.ClearField("processing_since") + if row["ProcessingStartTime"] is not None: + result.processing_since = int( + rdfvalue.RDFDatetime.FromDatetime(row["ProcessingStartTime"]) + ) + result.ClearField("processing_deadline") + if row["ProcessingEndTime"] is not None: + result.processing_deadline = int( + rdfvalue.RDFDatetime.FromDatetime(row["ProcessingEndTime"]) + ) + + return result + + class FlowsMixin: """A Spanner database mixin with implementation of flow methods.""" db: spanner_utils.Database _write_rows_batch_size: int + @property + def _flow_processing_request_receiver( + self, + ) -> Optional[spanner_lib.QueueReceiver]: + return getattr(self, "__flow_processing_request_receiver", None) + + @_flow_processing_request_receiver.setter + def _flow_processing_request_receiver( + self, value: Optional[spanner_lib.QueueReceiver] + ) -> None: + setattr(self, "__flow_processing_request_receiver", value) + + @property + def _message_handler_receiver(self) -> Optional[spanner_lib.QueueReceiver]: + return getattr(self, "__message_handler_receiver", None) + + @_message_handler_receiver.setter + def _message_handler_receiver( + self, value: Optional[spanner_lib.QueueReceiver] + ) -> None: + setattr(self, 
"__message_handler_receiver", value) @db_utils.CallLogged @db_utils.CallAccounted @@ -35,7 +256,67 @@ def WriteFlowObject( allow_update: bool = True, ) -> None: """Writes a flow object to the database.""" - + client_id = flow_obj.client_id + flow_id = flow_obj.flow_id + + row = { + "ClientId": spanner_clients.IntClientID(client_id), + "FlowId": IntFlowID(flow_id), + "LongFlowId": flow_obj.long_flow_id, + } + + if flow_obj.parent_flow_id: + row["ParentFlowId"] = IntFlowID(flow_obj.parent_flow_id) + if flow_obj.parent_hunt_id: + row["ParentHuntId"] = IntHuntID(flow_obj.parent_hunt_id) + + row["Creator"] = flow_obj.creator + row["Name"] = flow_obj.flow_class_name + row["State"] = int(flow_obj.flow_state) + row["NextRequestToProcess"] = flow_obj.next_request_to_process + + row["CreationTime"] = spanner_lib.CommitTimestamp() + row["UpdateTime"] = spanner_lib.CommitTimestamp() + + if flow_obj.HasField("client_crash_info"): + row["Crash"] = flow_obj.client_crash_info + + if flow_obj.HasField("processing_on"): + row["ProcessingWorker"] = flow_obj.processing_on + if flow_obj.HasField("processing_since"): + row["ProcessingStartTime"] = ( + rdfvalue.RDFDatetime() + .FromMicrosecondsSinceEpoch(flow_obj.processing_since) + .AsDatetime() + ) + if flow_obj.HasField("processing_deadline"): + row["ProcessingEndTime"] = ( + rdfvalue.RDFDatetime() + .FromMicrosecondsSinceEpoch(flow_obj.processing_deadline) + .AsDatetime() + ) + + row["Flow"] = flow_obj + + row["ReplyCount"] = spanner_lib.UInt64(flow_obj.num_replies_sent) + row["NetworkBytesSent"] = spanner_lib.UInt64(flow_obj.network_bytes_sent) + row["UserCpuTimeUsed"] = float(flow_obj.cpu_time_used.user_cpu_time) + row["SystemCpuTimeUsed"] = float(flow_obj.cpu_time_used.system_cpu_time) + + try: + if allow_update: + self.db.InsertOrUpdate( + table="Flows", row=row, txn_tag="WriteFlowObject_IOU" + ) + else: + self.db.Insert(table="Flows", row=row, txn_tag="WriteFlowObject_I") + except spanner_errors.AlreadyExistsError as error: + raise db.FlowExistsError(client_id, flow_id) from error + except spanner_errors.RowNotFoundError as error: + if "Parent row is missing: Clients" in str(error): + raise db.UnknownClientError(client_id) + else: + raise @db_utils.CallLogged @db_utils.CallAccounted @@ -45,8 +326,20 @@ def ReadFlowObject( flow_id: str, ) -> flows_pb2.Flow: """Reads a flow object from the database.""" + int_client_id = spanner_clients.IntClientID(client_id) + int_flow_id = IntFlowID(flow_id) - return None + try: + row = self.db.Read( + table="Flows", + key=(int_client_id, int_flow_id), + cols=_READ_FLOW_OBJECT_COLS, + ) + except NotFound as error: + raise db.UnknownFlowError(client_id, flow_id, cause=error) + + flow = _ParseReadFlowObjectRow(client_id, flow_id, row) + return flow @db_utils.CallLogged @db_utils.CallAccounted @@ -62,6 +355,75 @@ def ReadAllFlowObjects( """Returns all flow objects that meet the specified conditions.""" result = [] + query = """ + SELECT f.ClientId, f.FlowId, f.LongFlowId, + f.ParentFlowId, f.ParentHuntId, + f.Creator, f.Name, f.State, + f.CreationTime, f.UpdateTime, + f.Crash, f.NextRequestToProcess, + f.Flow + FROM Flows AS f + """ + params = {} + + conds = [] + + if client_id is not None: + params["client_id"] = spanner_clients.IntClientID(client_id) + conds.append("f.ClientId = {client_id}") + if parent_flow_id is not None: + params["parent_flow_id"] = IntFlowID(parent_flow_id) + conds.append("f.ParentFlowId = {parent_flow_id}") + if min_create_time is not None: + params["min_creation_time"] = 
min_create_time.AsDatetime() + conds.append("f.CreationTime >= {min_creation_time}") + if max_create_time is not None: + params["max_creation_time"] = max_create_time.AsDatetime() + conds.append("f.CreationTime <= {max_creation_time}") + if not include_child_flows: + conds.append("f.ParentFlowId IS NULL") + if not_created_by is not None: + params["not_created_by"] = spanner_lib.Array(str, not_created_by) + conds.append("f.Creator NOT IN UNNEST({not_created_by})") + + if conds: + query += f" WHERE {' AND '.join(conds)}" + + for row in self.db.ParamQuery(query, params, txn_tag="ReadAllFlowObjects"): + int_client_id, int_flow_id, long_flow_id, *row = row + int_parent_flow_id, int_parent_hunt_id, *row = row + creator, name, state, *row = row + creation_time, update_time, *row = row + crash_bytes, next_request_to_process, flow_bytes = row + + flow = flows_pb2.Flow() + flow.ParseFromString(flow_bytes) + flow.client_id = db_utils.IntToClientID(int_client_id) + flow.flow_id = db_utils.IntToFlowID(int_flow_id) + flow.long_flow_id = long_flow_id + flow.next_request_to_process = next_request_to_process + + if int_parent_flow_id is not None: + flow.parent_flow_id = db_utils.IntToFlowID(int_parent_flow_id) + if int_parent_hunt_id is not None: + flow.parent_hunt_id = db_utils.IntToHuntID(int_parent_hunt_id) + + flow.creator = creator + flow.flow_state = state + flow.flow_class_name = name + + flow.create_time = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + flow.last_update_time = rdfvalue.RDFDatetime.FromDatetime( + update_time + ).AsMicrosecondsSinceEpoch() + + if crash_bytes is not None: + flow.client_crash_info.ParseFromString(crash_bytes) + + result.append(flow) + return result @db_utils.CallLogged @@ -70,40 +432,98 @@ def UpdateFlow( self, client_id: str, flow_id: str, - flow_obj: Union[ - flows_pb2.Flow, db.Database.UNCHANGED_TYPE - ] = db.Database.UNCHANGED, + flow_obj: Union[flows_pb2.Flow, _UNCHANGED_TYPE] = _UNCHANGED, flow_state: Union[ - flows_pb2.Flow.FlowState.ValueType, db.Database.UNCHANGED_TYPE - ] = db.Database.UNCHANGED, + flows_pb2.Flow.FlowState.ValueType, _UNCHANGED_TYPE + ] = _UNCHANGED, client_crash_info: Union[ - jobs_pb2.ClientCrash, db.Database.UNCHANGED_TYPE - ] = db.Database.UNCHANGED, - processing_on: Optional[ - Union[str, db.Database.UNCHANGED_TYPE] - ] = db.Database.UNCHANGED, + jobs_pb2.ClientCrash, _UNCHANGED_TYPE + ] = _UNCHANGED, + processing_on: Optional[Union[str, _UNCHANGED_TYPE]] = _UNCHANGED, processing_since: Optional[ - Union[rdfvalue.RDFDatetime, db.Database.UNCHANGED_TYPE] - ] = db.Database.UNCHANGED, + Union[rdfvalue.RDFDatetime, _UNCHANGED_TYPE] + ] = _UNCHANGED, processing_deadline: Optional[ - Union[rdfvalue.RDFDatetime, db.Database.UNCHANGED_TYPE] - ] = db.Database.UNCHANGED, + Union[rdfvalue.RDFDatetime, _UNCHANGED_TYPE] + ] = _UNCHANGED, ) -> None: """Updates flow objects in the database.""" + row = { + "ClientId": spanner_clients.IntClientID(client_id), + "FlowId": IntFlowID(flow_id), + "UpdateTime": spanner_lib.CommitTimestamp(), + } + + if isinstance(flow_obj, flows_pb2.Flow): + row["Flow"] = flow_obj + row["State"] = int(flow_obj.flow_state) + row["ReplyCount"] = spanner_lib.UInt64(flow_obj.num_replies_sent) + row["NetworkBytesSent"] = spanner_lib.UInt64(flow_obj.network_bytes_sent) + row["UserCpuTimeUsed"] = float(flow_obj.cpu_time_used.user_cpu_time) + row["SystemCpuTimeUsed"] = float(flow_obj.cpu_time_used.system_cpu_time) + if isinstance(flow_state, flows_pb2.Flow.FlowState.ValueType): + row["State"] = 
int(flow_state) + if isinstance(client_crash_info, jobs_pb2.ClientCrash): + row["Crash"] = client_crash_info + if ( + isinstance(processing_on, str) and processing_on is not db.UNCHANGED + ) or processing_on is None: + row["ProcessingWorker"] = processing_on + if isinstance(processing_since, rdfvalue.RDFDatetime): + row["ProcessingStartTime"] = processing_since.AsDatetime() + if processing_since is None: + row["ProcessingStartTime"] = None + if isinstance(processing_deadline, rdfvalue.RDFDatetime): + row["ProcessingEndTime"] = processing_deadline.AsDatetime() + if processing_deadline is None: + row["ProcessingEndTime"] = None + + try: + self.db.Update(table="Flows", row=row, txn_tag="UpdateFlow") + except spanner_errors.RowNotFoundError as error: + raise db.UnknownFlowError(client_id, flow_id, cause=error) @db_utils.CallLogged @db_utils.CallAccounted def WriteFlowResults(self, results: Sequence[flows_pb2.FlowResult]) -> None: """Writes flow results for a given flow.""" + def Mutation(mut) -> None: + for r in results: + row = { + "ClientId": spanner_clients.IntClientID(r.client_id), + "FlowId": IntFlowID(r.flow_id), + "HuntId": IntHuntID(r.hunt_id) if r.hunt_id else 0, + "CreationTime": rdfvalue.RDFDatetime.Now().AsDatetime(), + "Tag": r.tag, + "RdfType": db_utils.TypeURLToRDFTypeName(r.payload.type_url), + "Payload": r.payload, + } + + mut.Insert("FlowResults", row) + + self.db.Mutate(Mutation, txn_tag="WriteFlowResults") @db_utils.CallLogged @db_utils.CallAccounted def WriteFlowErrors(self, errors: Sequence[flows_pb2.FlowError]) -> None: """Writes flow errors for a given flow.""" - + def Mutation(mut) -> None: + for r in errors: + row = { + "ClientId": spanner_clients.IntClientID(r.client_id), + "FlowId": IntFlowID(r.flow_id), + "HuntId": IntHuntID(r.hunt_id) if r.hunt_id else 0, + "CreationTime": rdfvalue.RDFDatetime.Now().AsDatetime(), + "Payload": r.payload, + "RdfType": db_utils.TypeURLToRDFTypeName(r.payload.type_url), + "Tag": r.tag, + } + mut.Insert("FlowErrors", row) + + self.db.Mutate(Mutation, txn_tag="WriteFlowErrors") def ReadFlowResults( self, @@ -116,7 +536,41 @@ def ReadFlowResults( with_substring: Optional[str] = None, ) -> Sequence[flows_pb2.FlowResult]: """Reads flow results of a given flow using given query options.""" + query, params = _BuildReadFlowResultsErrorsConditions( + "FlowResults", + client_id, + flow_id, + offset, + count, + with_tag, + with_type, + with_substring, + ) + results = [] + for ( + payload_bytes, + _, + creation_time, + tag, + hunt_id, + ) in self.db.ParamQuery(query, params, txn_tag="ReadFlowResults"): + result = flows_pb2.FlowResult( + client_id=client_id, + flow_id=flow_id, + timestamp=rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch(), + ) + result.payload.ParseFromString(payload_bytes) + + if hunt_id is not None: + result.hunt_id = db_utils.IntToHuntID(hunt_id) + + if tag is not None: + result.tag = tag + + results.append(result) return results @@ -132,8 +586,55 @@ def ReadFlowErrors( with_type: Optional[str] = None, ) -> Sequence[flows_pb2.FlowError]: """Reads flow errors of a given flow using given query options.""" - - return [] + query, params = _BuildReadFlowResultsErrorsConditions( + "FlowErrors", + client_id, + flow_id, + offset, + count, + with_tag, + with_type, + None, + ) + + errors = [] + for ( + payload_bytes, + payload_type, + creation_time, + tag, + hunt_id, + ) in self.db.ParamQuery(query, params, txn_tag="ReadFlowErrors"): + error = flows_pb2.FlowError( + client_id=client_id, + flow_id=flow_id, + 
timestamp=rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch(), + ) + + # TODO(b/309429206): for separation of concerns reasons, + # ReadFlowResults/ReadFlowErrors shouldn't do the payload type validation, + # they should be completely agnostic to what payloads get written/read + # to/from the database. Keeping this logic here temporarily + # to narrow the scope of the RDFProtoStruct->protos migration. + if payload_type in rdfvalue.RDFValue.classes: + error.payload.ParseFromString(payload_bytes) + else: + unrecognized = objects_pb2.SerializedValueOfUnrecognizedType( + type_name=payload_type, value=payload_bytes + ) + error.payload.Pack(unrecognized) + + if hunt_id is not None: + error.hunt_id = db_utils.IntToHuntID(hunt_id) + + if tag is not None: + error.tag = tag + + errors.append(error) + + return errors @db_utils.CallLogged @db_utils.CallAccounted @@ -146,7 +647,13 @@ def CountFlowResults( ) -> int: """Counts flow results of a given flow using given query options.""" - return 0 + query, params = _BuildCountFlowResultsErrorsConditions( + "FlowResults", client_id, flow_id, with_tag, with_type + ) + (count,) = self.db.ParamQuerySingle( + query, params, txn_tag="CountFlowResults" + ) + return count @db_utils.CallLogged @db_utils.CallAccounted @@ -159,7 +666,13 @@ def CountFlowErrors( ) -> int: """Counts flow errors of a given flow using given query options.""" - return 0 + query, params = _BuildCountFlowResultsErrorsConditions( + "FlowErrors", client_id, flow_id, with_tag, with_type + ) + (count,) = self.db.ParamQuerySingle( + query, params, txn_tag="CountFlowErrors" + ) + return count @db_utils.CallLogged @db_utils.CallAccounted @@ -175,7 +688,16 @@ def CountFlowResultsByType( GROUP BY RdfType """ + params = { + "client_id": spanner_clients.IntClientID(client_id), + "flow_id": IntFlowID(flow_id), + } + result = {} + for type_name, count in self.db.ParamQuery( + query, params, txn_tag="CountFlowResultsByType" + ): + result[type_name] = count return result @@ -185,7 +707,24 @@ def CountFlowErrorsByType( self, client_id: str, flow_id: str ) -> Mapping[str, int]: """Returns counts of flow errors grouped by error type.""" + + query = """ + SELECT e.RdfType, COUNT(*) + FROM FlowErrors AS e + WHERE e.ClientId = {client_id} AND e.FlowId = {flow_id} + GROUP BY RdfType + """ + + params = { + "client_id": spanner_clients.IntClientID(client_id), + "flow_id": IntFlowID(flow_id), + } + result = {} + for type_name, count in self.db.ParamQuery( + query, params, txn_tag="CountFlowErrorsByType" + ): + result[type_name] = count return result @@ -196,6 +735,26 @@ def _BuildFlowProcessingRequestWrites( ) -> None: """Builds db writes for a list of FlowProcessingRequests.""" + for r in requests: + key = ( + spanner_clients.IntClientID(r.client_id), + IntFlowID(r.flow_id), + spanner_lib.CommitTimestamp(), + ) + + ts = None + if r.delivery_time: + ts = rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + r.delivery_time + ).AsDatetimeUTC() + + mut.Send( + queue="FlowProcessingRequestsQueue", + key=key, + value=r, + column="Payload", + deliver_time=ts, + ) @db_utils.CallLogged @db_utils.CallAccounted @@ -205,6 +764,10 @@ def WriteFlowProcessingRequests( ) -> None: """Writes a list of flow processing requests to the database.""" + def Mutation(mut) -> None: + self._BuildFlowProcessingRequestWrites(mut, requests) + + self.db.BufferedMutate(Mutation, txn_tag="WriteFlowProcessingRequests") @db_utils.CallLogged @db_utils.CallAccounted @@ -217,6 +780,13 @@ def ReadFlowProcessingRequests( """ 
results = [] + for payload, creation_time in self.db.ParamQuery( + query, {}, txn_tag="ReadFlowProcessingRequests" + ): + req = flows_pb2.FlowProcessingRequest() + req.ParseFromString(payload) + req.creation_time = int(rdfvalue.RDFDatetime.FromDatetime(creation_time)) + results.append(req) return results @@ -227,6 +797,18 @@ def AckFlowProcessingRequests( ) -> None: """Acknowledges and deletes flow processing requests.""" + def Mutation(mut) -> None: + for r in requests: + key = ( + spanner_clients.IntClientID(r.client_id), + IntFlowID(r.flow_id), + rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + r.creation_time + ).AsDatetime(), + ) + mut.Ack("FlowProcessingRequestsQueue", key) + + self.db.BufferedMutate(Mutation, txn_tag="AckFlowProcessingRequests") @db_utils.CallLogged @db_utils.CallAccounted @@ -241,13 +823,39 @@ def RegisterFlowProcessingHandler( self, handler: Callable[[flows_pb2.FlowProcessingRequest], None] ): """Registers a handler to receive flow processing messages.""" - + self.UnregisterFlowProcessingHandler() + + def Callback(expanded_key: Sequence[Any], payload: bytes): + try: + req = flows_pb2.FlowProcessingRequest() + req.ParseFromString(payload) + req.creation_time = int( + rdfvalue.RDFDatetime.FromDatetime(expanded_key[2]) + ) + handler(req) + except Exception as e: # pylint: disable=broad-except + logging.exception("Exception raised during Flow processing: %s", e) + + receiver = self.db.NewQueueReceiver( + "FlowProcessingRequestsQueue", + Callback, + receiver_max_keepalive_seconds=3000, + receiver_max_active_callbacks=50, + receiver_max_messages_per_callback=1, + ) + receiver.Receive() + self._flow_processing_request_receiver = receiver def UnregisterFlowProcessingHandler( self, timeout: Optional[rdfvalue.Duration] = None ) -> None: """Unregisters any registered flow processing handler.""" - + del timeout # Unused. + if self._flow_processing_request_receiver is not None: + # Pytype doesn't understand that the if-check above ensures that + # _flow_processing_request_receiver is not None. + self._flow_processing_request_receiver.Stop() # pytype: disable=attribute-error + self._flow_processing_request_receiver = None @db_utils.CallLogged @db_utils.CallAccounted @@ -257,6 +865,565 @@ def WriteFlowRequests( ) -> None: """Writes a list of flow requests to the database.""" + flow_keys = [(r.client_id, r.flow_id) for r in requests] + + def Txn(txn) -> None: + needs_processing = {} + with txn.Mutate() as mut: + for r in requests: + if r.needs_processing: + needs_processing.setdefault((r.client_id, r.flow_id), []).append(r) + + client_id_int = spanner_clients.IntClientID(r.client_id) + flow_id_int = IntFlowID(r.flow_id) + + update_dict = { + "ClientId": client_id_int, + "FlowId": flow_id_int, + "RequestId": r.request_id, + "NeedsProcessing": r.needs_processing, + "NextResponseId": r.next_response_id, + "CallbackState": r.callback_state, + "Payload": r, + "CreationTime": spanner_lib.CommitTimestamp(), + } + if r.start_time: + update_dict["StartTime"] = ( + rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + r.start_time + ).AsDatetime() + ) + + mut.InsertOrUpdate("FlowRequests", update_dict) + + if needs_processing: + flow_processing_requests = [] + + rows = spanner_lib.RowSet() + # Note on linting: adding .keys() triggers a warning that + # .keys() should be omitted. Omitting keys leads to a + # mistaken warning that .items() was not called. 
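+      # Iterating the dict directly yields its (client_id, flow_id) key tuples,
+      # so calling .keys() is not needed here.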
+ for client_id, flow_id in needs_processing: # pylint: disable=dict-iter-missing-items + rows.Add( + spanner_lib.Key( + spanner_clients.IntClientID(client_id), IntFlowID(flow_id) + ) + ) + + cols = ( + "ClientId", + "FlowId", + "NextRequestToProcess", + ) + for row in txn.ReadSet(table="Flows", rows=rows, cols=cols): + client_id = db_utils.IntToClientID(row["ClientId"]) + flow_id = db_utils.IntToFlowID(row["FlowId"]) + + candidate_requests = needs_processing.get((client_id, flow_id), []) + for r in candidate_requests: + if row["NextRequestToProcess"] == r.request_id or r.start_time: + req = flows_pb2.FlowProcessingRequest( + client_id=client_id, flow_id=flow_id + ) + if r.start_time: + req.delivery_time = r.start_time + flow_processing_requests.append(req) + + if flow_processing_requests: + with txn.BufferedMutate() as mut: + self._BuildFlowProcessingRequestWrites( + mut, flow_processing_requests + ) + + try: + self.db.Transact(Txn, txn_tag="WriteFlowRequests") + except spanner_errors.RowNotFoundError as error: + if "Parent row is missing: Flows" in str(error): + raise db.AtLeastOneUnknownFlowError(flow_keys, cause=error) + else: + raise + + def _ReadRequestsInfo( + self, + responses: Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + txn: spanner_utils.Transaction, + ) -> tuple[dict[_RequestKey, int], dict[_RequestKey, str], set[_RequestKey]]: + """For given responses returns data about corresponding requests. + + Args: + responses: an iterable with responses. + txn: transaction to use. + + Returns: + A tuple of 3 dictionaries: ( + responses_expected_by_request, + callback_state_by_request, + currently_available_requests). + + responses_expected_by_request: for requests that already received + a Status response, maps each request id to the number of responses + expected for it. + + callback_state_by_request: for incremental requests, maps each request + id to the name of a flow callback state that has to be called on + every incoming response. + + currently_available_requests: a set with all the request ids corresponding + to given responses. + """ + + # Number of responses each affected request is waiting for (if available). + responses_expected_by_request = {} + + # We also store all requests we have in the db so we can discard responses + # for unknown requests right away. + currently_available_requests = set() + + # Callback states by request. 
+ callback_state_by_request = {} + + req_rows = spanner_lib.RowSet() + for r in responses: + req_rows.Add( + spanner_lib.Key( + spanner_clients.IntClientID(r.client_id), + IntFlowID(r.flow_id), + r.request_id, + ) + ) + + for row in txn.ReadSet( + table="FlowRequests", + rows=req_rows, + cols=[ + "ClientID", + "FlowID", + "RequestID", + "CallbackState", + "ExpectedResponseCount", + ], + ): + + request_key = _RequestKey( + db_utils.IntToClientID(row["ClientID"]), + db_utils.IntToFlowID(row["FlowID"]), + row["RequestID"], + ) + currently_available_requests.add(request_key) + + callback_state: str = row["CallbackState"] + if callback_state: + callback_state_by_request[request_key] = callback_state + + responses_expected: int = row["ExpectedResponseCount"] + if responses_expected: + responses_expected_by_request[request_key] = responses_expected + + return ( + responses_expected_by_request, + callback_state_by_request, + currently_available_requests, + ) + + def _BuildResponseWrites( + self, + responses: Collection[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + txn: spanner_utils.Transaction, + ) -> None: + """Builds the writes to store given responses in the db. + + Args: + responses: iterable with flow responses to write. + txn: transaction to use for the writes. + + Raises: + TypeError: if responses have objects other than FlowResponse, FlowStatus + or FlowIterator. + """ + + with txn.Mutate() as mut: + for r in responses: + row = { + "ClientId": spanner_clients.IntClientID(r.client_id), + "FlowId": IntFlowID(r.flow_id), + "RequestId": r.request_id, + "ResponseId": r.response_id, + "Response": None, + "Status": None, + "Iterator": None, + "CreationTime": spanner_lib.CommitTimestamp(), + } + + if isinstance(r, flows_pb2.FlowResponse): + row["Response"] = r + elif isinstance(r, flows_pb2.FlowStatus): + row["Status"] = r + elif isinstance(r, flows_pb2.FlowIterator): + row["Iterator"] = r + else: + # This can't really happen due to DB validator type checking. + raise TypeError(f"Got unexpected response type: {type(r)} {r}") + + mut.InsertOrUpdate("FlowResponses", row) + + def _BuildExpectedUpdates( + self, updates: dict[_RequestKey, int], txn: spanner_utils.Transaction + ) -> None: + """Builds updates for requests with known number of expected responses. + + Args: + updates: dict mapping requests to the number of expected responses. + txn: transaction to use for the writes. + """ + + with txn.Mutate() as mut: + for r_key, num_responses_expected in updates.items(): + row = { + "ClientId": spanner_clients.IntClientID(r_key.client_id), + "FlowId": IntFlowID(r_key.flow_id), + "RequestId": r_key.request_id, + "ExpectedResponseCount": num_responses_expected, + } + mut.Update("FlowRequests", row) + + def _WriteFlowResponsesAndExpectedUpdates( + self, + responses: Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ) -> tuple[dict[_RequestKey, int], dict[_RequestKey, str]]: + """Writes a flow responses and updates flow requests expected counts. + + Args: + responses: responses to write. + + Returns: + A tuple of (expected_responses_by_request, callback_state_by_request). + + expected_responses_by_request: number of expected responses by + request id. These numbers are collected from Status responses + discovered in the `responses` sequence. This data is later + passed to _BuildExpectedUpdates. + + callback_state_by_request: callback states by request. 
If incremental + requests are discovered during processing, their callback states end + up in this dictionary. This information is used later to make a + decision whether a flow should be notified about new responses: + incremental flows have to be notified even if Status responses were + not received. + """ + + if not responses: + return ({}, {}) + + def Txn( + txn: spanner_utils.Transaction, + ) -> tuple[dict[_RequestKey, int], dict[_RequestKey, str]]: + ( + responses_expected_by_request, + callback_state_by_request, + currently_available_requests, + ) = self._ReadRequestsInfo(responses, txn) + + # For some requests we will need to update the number of expected + # responses. + needs_expected_update = {} + + for r in responses: + req_key = _RequestKey(r.client_id, r.flow_id, r.request_id) + + # If the response is not a FlowStatus, we have nothing to do: it will be + # simply written to the DB. If it's a FlowStatus, we have to update + # the FlowRequest with the number of expected messages. + if not isinstance(r, flows_pb2.FlowStatus): + continue + + if req_key not in currently_available_requests: + logging.info("Dropping status for unknown request %s", req_key) + continue + + current = responses_expected_by_request.get(req_key) + if current: + logging.warning( + "Got duplicate status message for request %s", req_key + ) + + # If there is already responses_expected information, we need to make + # sure the current status doesn't disagree. + if current != r.response_id: + logging.error( + "Got conflicting status information for request %s: %s", + req_key, + r, + ) + else: + needs_expected_update[req_key] = r.response_id + + responses_expected_by_request[req_key] = r.response_id + + responses_to_write = {} + for r in responses: + req_key = _RequestKey(r.client_id, r.flow_id, r.request_id) + full_key = _ResponseKey( + r.client_id, r.flow_id, r.request_id, r.response_id + ) + + if req_key not in currently_available_requests: + continue + + if full_key in responses_to_write: + # Don't write a response if it was already written as part of the + # same batch. + prev = responses_to_write[full_key] + if r != prev: + logging.warning( + "WriteFlowResponses attempted to write two different " + "entries with identical key %s. First is %s and " + "second is %s.", + full_key, + prev, + r, + ) + continue + + responses_to_write[full_key] = r + + if responses_to_write or needs_expected_update: + self._BuildResponseWrites(responses_to_write.values(), txn) + if needs_expected_update: + self._BuildExpectedUpdates(needs_expected_update, txn) + + return responses_expected_by_request, callback_state_by_request + + return self.db.Transact( + Txn, txn_tag="WriteFlowResponsesAndExpectedUpdates" + ).value + + def _GetFlowResponsesPerRequestCounts( + self, + request_keys: Iterable[_RequestKey], + txn: spanner_utils.SnapshotTransaction, + ) -> dict[_RequestKey, int]: + """Gets counts of already received responses for given requests. + + Args: + request_keys: iterable with request keys. + txn: transaction to use. + + Returns: + A dictionary mapping request keys to the number of existing flow + responses. 
+ """ + + if not request_keys: + return {} + + conditions = [] + params = {} + for i, req_key in enumerate(request_keys): + if i > 0: + conditions.append("OR") + + conditions.append(f""" + (fr.ClientId = {{client_id_{i}}} AND + fr.FlowId = {{flow_id_{i}}} AND + fr.RequestId = {{request_id_{i}}}) + """) + + params[f"client_id_{i}"] = db_utils.ClientIDToInt(req_key.client_id) + params[f"flow_id_{i}"] = db_utils.FlowIDToInt(req_key.flow_id) + params[f"request_id_{i}"] = req_key.request_id + + query = f""" + SELECT fr.ClientId, fr.FlowId, fr.RequestId, COUNT(*) AS ResponseCount + FROM FlowResponses as fr + WHERE {" ".join(conditions)} + GROUP BY fr.ClientID, fr.FlowID, fr.RequestID + """ + + result = {} + for row in txn.ParamQuery(query, params): + client_id_int, flow_id_int, request_id, count = row + + req_key = _RequestKey( + db_utils.IntToClientID(client_id_int), + db_utils.IntToFlowID(flow_id_int), + request_id, + ) + result[req_key] = count + + return result + + def _ReadFlowRequestsNotYetMarkedForProcessing( + self, + requests: set[_RequestKey], + callback_states: dict[_RequestKey, str], + txn: spanner_utils.Transaction, + ) -> tuple[ + set[_RequestKey], set[tuple[_FlowKey, Optional[rdfvalue.RDFDatetime]]] + ]: + """Reads given requests and returns only ones not marked for processing. + + Args: + requests: request keys for requests to be read. + callback_states: dict containing incremental flow requests from the set. + For each such request the request key will be mapped to the callback + state of the flow. + txn: transaction to use. + + Returns: + A tuple of (requests_to_mark, flows_to_notify). + + requests_to_mark is a set of request keys for requests that have to be + marked as needing processing. + + flows_to_notify is a set of tuples (flow_key, start_time) for flows that + have to be notified of incoming responses. start_time in the tuple + corresponds to the intended notification delivery time. 
+ """ + flow_rows = spanner_lib.RowSet() + req_rows = spanner_lib.RowSet() + + unique_flow_keys = set() + + for req_key in set(requests) | set(callback_states): + client_id_int = spanner_clients.IntClientID(req_key.client_id) + flow_id_int = IntFlowID(req_key.flow_id) + + req_rows.AddPrefixRange( + spanner_lib.Key(client_id_int, flow_id_int, req_key.request_id) + ) + unique_flow_keys.add((client_id_int, flow_id_int)) + + for client_id_int, flow_id_int in unique_flow_keys: + flow_rows.AddPrefixRange(spanner_lib.Key(client_id_int, flow_id_int)) + + next_request_to_process_by_flow = {} + flow_cols = ( + "ClientId", + "FlowId", + "NextRequestToProcess", + ) + for row in txn.ReadSet("Flows", flow_rows, flow_cols): + client_id_int: int = row["ClientId"] + flow_id_int: int = row["FlowId"] + next_request_id: int = row["NextRequestToProcess"] + next_request_to_process_by_flow[(client_id_int, flow_id_int)] = ( + next_request_id + ) + + requests_to_mark = set() + requests_to_notify = set() + req_cols = ( + "ClientId", + "FlowId", + "RequestId", + "NeedsProcessing", + "StartTime", + ) + for row in txn.ReadSet("FlowRequests", req_rows, req_cols): + client_id_int: int = row["ClientId"] + flow_id_int: int = row["FlowId"] + request_id: int = row["RequestId"] + np: bool = row["NeedsProcessing"] + start_time: Optional[rdfvalue.RDFDatetime] = None + if row["StartTime"] is not None: + start_time = rdfvalue.RDFDatetime.FromDatetime(row["StartTime"]) + + if not np: + client_id = db_utils.IntToClientID(client_id_int) + flow_id = db_utils.IntToFlowID(flow_id_int) + + req_key = _RequestKey(client_id, flow_id, request_id) + if req_key in requests: + requests_to_mark.add(req_key) + + if ( + next_request_to_process_by_flow[(client_id_int, flow_id_int)] + == request_id + ): + requests_to_notify.add((_FlowKey(client_id, flow_id), start_time)) + + return requests_to_mark, requests_to_notify + + def _BuildNeedsProcessingUpdates( + self, requests: set[_RequestKey], txn: spanner_utils.Transaction + ) -> None: + """Builds updates for requests that have their NeedsProcessing flag set. + + Args: + requests: keys of requests to be updated. + txn: transaction to use. + """ + + with txn.Mutate() as mut: + for req_key in requests: + row = { + "ClientId": spanner_clients.IntClientID(req_key.client_id), + "FlowId": IntFlowID(req_key.flow_id), + "RequestId": req_key.request_id, + "NeedsProcessing": True, + } + mut.Update("FlowRequests", row) + + def _UpdateNeedsProcessingAndWriteFlowProcessingRequests( + self, + requests_ready_for_processing: set[_RequestKey], + callback_state_by_request: dict[_RequestKey, str], + txn: spanner_utils.Transaction, + ) -> None: + """Updates requests needs-processing flags, writes processing requests. + + Args: + requests_ready_for_processing: request keys for requests that have to be + updated. + callback_state_by_request: for incremental requests from the set - mapping + from request ids to callback states that are incrementally processing + incoming responses. + txn: transaction to use. 
+ """ + + if not requests_ready_for_processing and not callback_state_by_request: + return + + (requests_to_mark, flows_to_notify) = ( + self._ReadFlowRequestsNotYetMarkedForProcessing( + requests_ready_for_processing, callback_state_by_request, txn + ) + ) + + if requests_to_mark: + self._BuildNeedsProcessingUpdates(requests_to_mark, txn) + + if flows_to_notify: + flow_processing_requests = [] + for flow_key, start_time in flows_to_notify: + fpr = flows_pb2.FlowProcessingRequest( + client_id=flow_key.client_id, + flow_id=flow_key.flow_id, + ) + if start_time is not None: + fpr.delivery_time = int(start_time) + flow_processing_requests.append(fpr) + + with txn.BufferedMutate() as mut: + self._BuildFlowProcessingRequestWrites(mut, flow_processing_requests) @db_utils.CallLogged @db_utils.CallAccounted @@ -271,7 +1438,41 @@ def WriteFlowResponses( ], ) -> None: """Writes Flow ressages and updates corresponding requests.""" - + responses_expected_by_request = {} + callback_state_by_request = {} + for batch in collection.Batch(responses, self._write_rows_batch_size): + res_exp_by_req_iter, callback_state_by_req_iter = ( + self._WriteFlowResponsesAndExpectedUpdates(batch) + ) + + responses_expected_by_request.update(res_exp_by_req_iter) + callback_state_by_request.update(callback_state_by_req_iter) + + # If we didn't get any status messages, then there's nothing to process. + if not responses_expected_by_request and not callback_state_by_request: + return + + # Get actual per-request responses counts using a separate transaction. + read_txn = self.db.Snapshot() + counts = self._GetFlowResponsesPerRequestCounts( + responses_expected_by_request, read_txn + ) + + requests_ready_for_processing = set() + for req_key, responses_expected in responses_expected_by_request.items(): + if counts.get(req_key) == responses_expected: + requests_ready_for_processing.add(req_key) + + # requests_to_notify is a subset of requests_ready_for_processing, so no + # need to check if it's empty or not. 
+ if requests_ready_for_processing or callback_state_by_request: + + def Txn(txn) -> None: + self._UpdateNeedsProcessingAndWriteFlowProcessingRequests( + requests_ready_for_processing, callback_state_by_request, txn + ) + + self.db.Transact(Txn, txn_tag="WriteFlowResponses") @db_utils.CallLogged @db_utils.CallAccounted @@ -281,7 +1482,13 @@ def DeleteAllFlowRequestsAndResponses( flow_id: str, ) -> None: """Deletes all requests and responses for a given flow from the database.""" - + int_client_id = spanner_clients.IntClientID(client_id) + int_flow_id = IntFlowID(flow_id) + self.db.DeleteWithPrefix( + "FlowRequests", + (int_client_id, int_flow_id), + txn_tag="DeleteAllFlowRequestsAndResponses", + ) @db_utils.CallLogged @db_utils.CallAccounted @@ -290,9 +1497,9 @@ def ReadAllFlowRequestsAndResponses( client_id: str, flow_id: str, ) -> Iterable[ - Tuple[ + tuple[ flows_pb2.FlowRequest, - Dict[ + dict[ int, Union[ flows_pb2.FlowResponse, @@ -304,8 +1511,71 @@ def ReadAllFlowRequestsAndResponses( ]: """Reads all requests and responses for a given flow from the database.""" - ret = [] + txn = self.db.Snapshot() + + req_rows = spanner_lib.RowSet() + req_rows.AddPrefixRange( + spanner_lib.Key( + spanner_clients.IntClientID(client_id), IntFlowID(flow_id) + ) + ) + req_cols = ( + "Payload", + "NeedsProcessing", + "ExpectedResponseCount", + "CallbackState", + "NextResponseId", + "CreationTime", + ) + requests = [] + for row in txn.ReadSet(table="FlowRequests", rows=req_rows, cols=req_cols): + request = flows_pb2.FlowRequest() + request.ParseFromString(row["Payload"]) + request.needs_processing = row["NeedsProcessing"] + if row["ExpectedResponseCount"] is not None: + request.nr_responses_expected = row["ExpectedResponseCount"] + request.callback_state = row["CallbackState"] + request.next_response_id = row["NextResponseId"] + request.timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(row["CreationTime"]) + ) + requests.append(request) + + resp_rows = spanner_lib.RowSet() + resp_rows.AddPrefixRange( + spanner_lib.Key( + spanner_clients.IntClientID(client_id), IntFlowID(flow_id) + ) + ) + resp_cols = ( + "Response", + "Status", + "Iterator", + "CreationTime", + ) + responses = {} + for row in txn.ReadSet( + table="FlowResponses", rows=resp_rows, cols=resp_cols + ): + if row["Status"] is not None: + response = flows_pb2.FlowStatus() + response.ParseFromString(row["Status"]) + elif row["Iterator"] is not None: + response = flows_pb2.FlowIterator() + response.ParseFromString(row["Iterator"]) + else: + response = flows_pb2.FlowResponse() + response.ParseFromString(row["Response"]) + response.timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(row["CreationTime"]) + ) + responses.setdefault(response.request_id, {})[ + response.response_id + ] = response + ret = [] + for req in sorted(requests, key=lambda r: r.request_id): + ret.append((req, responses.get(req.request_id, {}))) return ret @db_utils.CallLogged @@ -315,8 +1585,51 @@ def DeleteFlowRequests( requests: Sequence[flows_pb2.FlowRequest], ) -> None: """Deletes a list of flow requests from the database.""" - - + if not requests: + return + + def Mutation(mut) -> None: + for request in requests: + key = [ + spanner_clients.IntClientID(request.client_id), + IntFlowID(request.flow_id), + request.request_id, + ] + mut.Delete(table="FlowRequests", key=key) + + try: + self.db.Mutate(Mutation, txn_tag="DeleteFlowRequests") + # TODO(b/196379916): Narrow the exception types (cl/450440276). 
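+      # As handled below, spanner_errors.BadUsageError is treated as the delete
+      # mutation having exceeded Spanner's per-transaction limits.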
+ except spanner_errors.BadUsageError: + if len(requests) == 1: + # If there is only one request and we still hit Spanner limits it means + # that the request has a lot of responses. It should be extremely rare + # to end up in such a situation, so we just leave the request in the + # database. Eventually, these rows will be deleted automatically due to + # our row retention policies [1] for this table. + # + # [1]: go/spanner-row-deletion-policies. + SPANNER_DELETE_FLOW_REQUESTS_FAILURES.Increment() + logging.error( + "Transaction too big to delete flow request '%s'", requests[0] + ) + else: + # If there is more than one request, we attempt to divide the data into + # smaller parts and delete these. + # + # Note that dividing in two does not mean that the number of deleted + # rows will spread evenly as it might be the case that one request in + # one part has significantly more responses than requests in the other + # part. However, as a cheap and reasonable approximation, this should do + # just fine. + # + # Notice that both of these `DeleteFlowRequests` calls happen in separate + # transactions. Since we are just deleting "obsolete" rows we do + # not really care about atomicity. If one of them succeeds and the other + # one fails, rows are going to be deleted eventually anyway (see the + # comment for a single request case). + self.DeleteFlowRequests(requests[: len(requests) // 2]) + self.DeleteFlowRequests(requests[len(requests) // 2 :]) @db_utils.CallLogged @db_utils.CallAccounted @@ -324,11 +1637,11 @@ def ReadFlowRequests( self, client_id: str, flow_id: str, - ) -> Dict[ + ) -> dict[ int, - Tuple[ + tuple[ flows_pb2.FlowRequest, - List[ + list[ Union[ flows_pb2.FlowResponse, flows_pb2.FlowStatus, @@ -339,7 +1652,84 @@ ]: """Reads all requests for a flow that can be processed by the worker.""" - return {} + txn = self.db.Snapshot() + rows = spanner_lib.RowSet() + rows.AddPrefixRange( + spanner_lib.Key( + spanner_clients.IntClientID(client_id), IntFlowID(flow_id) + ) + ) + + responses: dict[ + int, + list[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ] + ], + ] = {} + resp_cols = ( + "Response", + "Status", + "Iterator", + "CreationTime", + ) + for row in txn.ReadSet(table="FlowResponses", rows=rows, cols=resp_cols): + if row["Status"]: + response = flows_pb2.FlowStatus() + response.ParseFromString(row["Status"]) + elif row["Iterator"]: + response = flows_pb2.FlowIterator() + response.ParseFromString(row["Iterator"]) + else: + response = flows_pb2.FlowResponse() + response.ParseFromString(row["Response"]) + response.timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(row["CreationTime"]) + ) + responses.setdefault(response.request_id, []).append(response) + + requests: dict[ + int, + tuple[ + flows_pb2.FlowRequest, + list[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ], + ] = {} + req_cols = ( + "Payload", + "NeedsProcessing", + "ExpectedResponseCount", + "NextResponseId", + "CallbackState", + "CreationTime", + ) + for row in txn.ReadSet(table="FlowRequests", rows=rows, cols=req_cols): + request = flows_pb2.FlowRequest() + request.ParseFromString(row["Payload"]) + request.needs_processing = row["NeedsProcessing"] + if row["ExpectedResponseCount"] is not None: + request.nr_responses_expected = row["ExpectedResponseCount"] + request.callback_state = row["CallbackState"] + request.next_response_id = row["NextResponseId"] + request.timestamp = int( + 
rdfvalue.RDFDatetime.FromDatetime(row["CreationTime"]) + ) + requests[request.request_id] = ( + request, + sorted( + responses.get(request.request_id, []), key=lambda r: r.response_id + ), + ) + return requests @db_utils.CallLogged @db_utils.CallAccounted @@ -351,11 +1741,40 @@ def UpdateIncrementalFlowRequests( ) -> None: """Updates next response ids of given requests.""" + int_client_id = spanner_clients.IntClientID(client_id) + int_flow_id = IntFlowID(flow_id) + + with self.db.MutationPool() as mp: + for request_id, response_id in next_response_id_updates.items(): + with mp.Apply() as mut: + mut.Update( + table="FlowRequests", + row={ + "ClientId": int_client_id, + "FlowId": int_flow_id, + "RequestId": request_id, + "NextResponseId": response_id, + }, + ) @db_utils.CallLogged @db_utils.CallAccounted def WriteFlowLogEntry(self, entry: flows_pb2.FlowLogEntry) -> None: """Writes a single flow log entry to the database.""" + row = { + "ClientId": spanner_clients.IntClientID(entry.client_id), + "FlowId": IntFlowID(entry.flow_id), + "CreationTime": spanner_lib.CommitTimestamp(), + "Message": entry.message, + } + + if entry.hunt_id: + row["HuntId"] = IntHuntID(entry.hunt_id) + + try: + self.db.Insert(table="FlowLogEntries", row=row) + except spanner_errors.RowNotFoundError as error: + raise db.UnknownFlowError(entry.client_id, entry.flow_id) from error def ReadFlowLogEntries( self, @@ -368,14 +1787,66 @@ def ReadFlowLogEntries( """Reads flow log entries of a given flow using given query options.""" results = [] + query = """ + SELECT l.HuntId, + l.CreationTime, + l.Message + FROM FlowLogEntries AS l + WHERE l.ClientId = {client_id} + AND l.FlowId = {flow_id} + """ + params = { + "client_id": spanner_clients.IntClientID(client_id), + "flow_id": IntFlowID(flow_id), + } + + if with_substring is not None: + query += " AND STRPOS(l.Message, {substring}) != 0" + params["substring"] = with_substring + + query += """ + LIMIT {count} + OFFSET {offset} + """ + params["offset"] = offset + params["count"] = count + + for row in self.db.ParamQuery(query, params): + int_hunt_id, creation_time, message = row + + result = flows_pb2.FlowLogEntry() + result.client_id = client_id + result.flow_id = flow_id + + if int_hunt_id is not None: + result.hunt_id = db_utils.IntToHuntID(int_hunt_id) + + result.timestamp = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + result.message = message + + results.append(result) + return results @db_utils.CallLogged @db_utils.CallAccounted def CountFlowLogEntries(self, client_id: str, flow_id: str) -> int: """Returns number of flow log entries of a given flow.""" + query = """ + SELECT COUNT(*) + FROM FlowLogEntries AS l + WHERE l.ClientId = {client_id} + AND l.FlowId = {flow_id} + """ + params = { + "client_id": spanner_clients.IntClientID(client_id), + "flow_id": IntFlowID(flow_id), + } - return 0 + (count,) = self.db.ParamQuerySingle(query, params) + return count @db_utils.CallLogged @db_utils.CallAccounted @@ -387,7 +1858,28 @@ def WriteFlowRRGLogs( logs: Mapping[int, rrg_pb2.Log], ) -> None: """Writes new log entries for a particular action request.""" - + # Mutations cannot be empty, so we exit early to avoid that if needed. 
+ if not logs: + return + + def Mutation(mut) -> None: + for response_id, log in logs.items(): + row = { + "ClientId": db_utils.ClientIDToInt(client_id), + "FlowId": db_utils.FlowIDToInt(flow_id), + "RequestId": request_id, + "ResponseId": response_id, + "LogLevel": log.level, + "LogTime": log.timestamp.ToDatetime(), + "LogMessage": log.message, + "CreationTime": spanner_lib.CommitTimestamp(), + } + mut.Insert(table="FlowRRGLogs", row=row) + + try: + self.db.Mutate(Mutation) + except spanner_errors.RowNotFoundError as error: + raise db.UnknownFlowError(client_id, flow_id, cause=error) from error @db_utils.CallLogged @db_utils.CallAccounted @@ -399,8 +1891,40 @@ def ReadFlowRRGLogs( count: int, ) -> Sequence[rrg_pb2.Log]: """Reads log entries logged by actions issued by a particular flow.""" + query = """ + SELECT + l.LogLevel, l.LogTime, l.LogMessage + FROM + FlowRRGLogs AS l + WHERE + l.ClientId = {client_id} AND l.FlowId = {flow_id} + ORDER BY + l.RequestId, l.ResponseId + LIMIT + {count} + OFFSET + {offset} + """ + params = { + "client_id": db_utils.ClientIDToInt(client_id), + "flow_id": db_utils.FlowIDToInt(flow_id), + "offset": offset, + "count": count, + } + + results: list[rrg_pb2.Log] = [] + + for row in self.db.ParamQuery(query, params, txn_tag="ReadFlowRRGLogs"): + log_level, log_time, log_message = row + + log = rrg_pb2.Log() + log.level = log_level + log.timestamp.FromDatetime(log_time) + log.message = log_message - return [] + results.append(log) + + return results @db_utils.CallLogged @db_utils.CallAccounted @@ -413,7 +1937,22 @@ def WriteFlowOutputPluginLogEntry( Args: entry: An output plugin flow entry to write. """ - + row = { + "ClientId": spanner_clients.IntClientID(entry.client_id), + "FlowId": IntFlowID(entry.flow_id), + "OutputPluginId": IntOutputPluginID(entry.output_plugin_id), + "CreationTime": spanner_lib.CommitTimestamp(), + "Type": int(entry.log_entry_type), + "Message": entry.message, + } + + if entry.hunt_id: + row["HuntId"] = IntHuntID(entry.hunt_id) + + try: + self.db.Insert(table="FlowOutputPluginLogEntries", row=row) + except spanner_errors.RowNotFoundError as error: + raise db.UnknownFlowError(entry.client_id, entry.flow_id) from error @db_utils.CallLogged @db_utils.CallAccounted @@ -431,6 +1970,51 @@ def ReadFlowOutputPluginLogEntries( """Reads flow output plugin log entries.""" results = [] + query = """ + SELECT l.HuntId, + l.CreationTime, + l.Type, l.Message + FROM FlowOutputPluginLogEntries AS l + WHERE l.ClientId = {client_id} + AND l.FlowId = {flow_id} + AND l.OutputPluginId = {output_plugin_id} + """ + params = { + "client_id": spanner_clients.IntClientID(client_id), + "flow_id": IntFlowID(flow_id), + "output_plugin_id": IntOutputPluginID(output_plugin_id), + } + + if with_type is not None: + query += " AND l.Type = {type}" + params["type"] = int(with_type) + + query += """ + LIMIT {count} + OFFSET {offset} + """ + params["offset"] = offset + params["count"] = count + + for row in self.db.ParamQuery(query, params): + int_hunt_id, creation_time, int_type, message = row + + result = flows_pb2.FlowOutputPluginLogEntry() + result.client_id = client_id + result.flow_id = flow_id + result.output_plugin_id = output_plugin_id + + if int_hunt_id is not None: + result.hunt_id = db_utils.IntToHuntID(int_hunt_id) + + result.timestamp = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + result.log_entry_type = int_type + result.message = message + + results.append(result) + return results @db_utils.CallLogged @@ -445,8 +2029,25 
@@ def CountFlowOutputPluginLogEntries( ] = None, ) -> int: """Returns the number of flow output plugin log entries of a given flow.""" + query = """ + SELECT COUNT(*) + FROM FlowOutputPluginLogEntries AS l + WHERE l.ClientId = {client_id} + AND l.FlowId = {flow_id} + AND l.OutputPluginId = {output_plugin_id} + """ + params = { + "client_id": spanner_clients.IntClientID(client_id), + "flow_id": IntFlowID(flow_id), + "output_plugin_id": IntOutputPluginID(output_plugin_id), + } + + if with_type is not None: + query += " AND l.Type = {type}" + params["type"] = int(with_type) - return 0 + (count,) = self.db.ParamQuerySingle(query, params) + return count @db_utils.CallLogged @db_utils.CallAccounted @@ -455,7 +2056,28 @@ def WriteScheduledFlow( scheduled_flow: flows_pb2.ScheduledFlow, ) -> None: """Inserts or updates the ScheduledFlow in the database.""" - + row = { + "ClientID": spanner_clients.IntClientID(scheduled_flow.client_id), + "Creator": scheduled_flow.creator, + "ScheduledFlowId": IntFlowID(scheduled_flow.scheduled_flow_id), + "FlowName": scheduled_flow.flow_name, + "FlowArgs": scheduled_flow.flow_args, + "RunnerArgs": scheduled_flow.runner_args, + "CreationTime": rdfvalue.RDFDatetime( + scheduled_flow.create_time + ).AsDatetime(), + "Error": scheduled_flow.error, + } + + try: + self.db.InsertOrUpdate(table="ScheduledFlows", row=row) + except spanner_errors.RowNotFoundError as error: + if "Parent row is missing: Clients" in str(error): + raise db.UnknownClientError(scheduled_flow.client_id) from error + elif "fk_creator_users_username" in str(error): + raise db.UnknownGRRUserError(scheduled_flow.creator) from error + else: + raise @db_utils.CallLogged @db_utils.CallAccounted @@ -467,6 +2089,26 @@ def DeleteScheduledFlow( ) -> None: """Deletes the ScheduledFlow from the database.""" + key = ( + spanner_clients.IntClientID(client_id), + creator, + IntFlowID(scheduled_flow_id), + ) + + def Transaction(txn) -> None: + try: + txn.Read(table="ScheduledFlows", cols=["ScheduledFlowId"], key=key) + except spanner_errors.RowNotFoundError as e: + raise db.UnknownScheduledFlowError( + client_id=client_id, + creator=creator, + scheduled_flow_id=scheduled_flow_id, + ) from e + + with txn.Mutate() as mut: + mut.Delete(table="ScheduledFlows", key=key) + + self.db.Transact(Transaction) @db_utils.CallLogged @db_utils.CallAccounted @@ -476,9 +2118,38 @@ def ListScheduledFlows( creator: str, ) -> Sequence[flows_pb2.ScheduledFlow]: """Lists all ScheduledFlows for the client and creator.""" - + rows = spanner_lib.RowSet() + rows.AddPrefixRange( + spanner_lib.Key(spanner_clients.IntClientID(client_id), creator) + ) + + cols = ( + "ClientId", + "Creator", + "ScheduledFlowId", + "FlowName", + "FlowArgs", + "RunnerArgs", + "CreationTime", + "Error", + ) results = [] + for row in self.db.ReadSet("ScheduledFlows", rows, cols): + sf = flows_pb2.ScheduledFlow() + sf.client_id = db_utils.IntToClientID(row["ClientId"]) + sf.creator = row["Creator"] + sf.scheduled_flow_id = db_utils.IntToFlowID(row["ScheduledFlowId"]) + sf.flow_name = row["FlowName"] + sf.flow_args.ParseFromString(row["FlowArgs"]) + sf.runner_args.ParseFromString(row["RunnerArgs"]) + sf.create_time = int( + rdfvalue.RDFDatetime.FromDatetime(row["CreationTime"]) + ) + sf.error = row["Error"] + + results.append(sf) + return results @db_utils.CallLogged @@ -488,6 +2159,20 @@ def WriteMessageHandlerRequests( ) -> None: """Writes a list of message handler requests to the database.""" + def Mutation(mut: spanner_utils.Mutation) -> None: + creation_timestamp 
= spanner_lib.CommitTimestamp() + + for r in requests: + key = (r.handler_name, r.request_id, creation_timestamp) + + mut.Send( + queue="MessageHandlerRequestsQueue", + key=key, + value=r, + column="Payload", + ) + + self.db.BufferedMutate(Mutation) @db_utils.CallLogged @db_utils.CallAccounted @@ -496,10 +2181,42 @@ def ReadMessageHandlerRequests( ) -> Sequence[objects_pb2.MessageHandlerRequest]: """Reads all message handler requests from the database.""" + query = """ + SELECT t.Payload, t.CreationTime FROM MessageHandlerRequestsQueue AS t + """ + results = [] - return results + for payload, creation_time in self.db.ParamQuery(query, {}): + req = objects_pb2.MessageHandlerRequest() + req.ParseFromString(payload) + req.timestamp = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + results.append(req) + return results + def _BuildDeleteMessageHandlerRequestWrites( + self, + txn: spanner_utils.Transaction, + requests: Iterable[objects_pb2.MessageHandlerRequest], + ) -> None: + """Deletes given requests within a given transaction.""" + req_rows = spanner_lib.RowSet() + for r in requests: + req_rows.AddPrefixRange(spanner_lib.Key(r.handler_name, r.request_id)) + + to_delete = [] + req_cols = ("HandlerName", "RequestId", "CreationTime") + for row in txn.ReadSet("MessageHandlerRequestsQueue", req_rows, req_cols): + handler_name: str = row["HandlerName"] + request_id: int = row["RequestId"] + creation_time: datetime.datetime = row["CreationTime"] + to_delete.append((handler_name, request_id, creation_time)) + + with txn.BufferedMutate() as mut: + for td_key in to_delete: + mut.Ack("MessageHandlerRequestsQueue", td_key) @db_utils.CallLogged @db_utils.CallAccounted @@ -508,6 +2225,62 @@ def DeleteMessageHandlerRequests( ) -> None: """Deletes a list of message handler requests from the database.""" + def Txn(txn: spanner_utils.Transaction) -> None: + self._BuildDeleteMessageHandlerRequestWrites(txn, requests) + + self.db.Transact(Txn) + + def _LeaseMessageHandlerRequest( + self, + req: objects_pb2.MessageHandlerRequest, + lease_time: rdfvalue.Duration, + ) -> objects_pb2.MessageHandlerRequest: + """Leases the given message handler request. + + Leasing of the message amounts to the following: + 1. The message gets deleted from the queue. + 2. It gets rescheduled in the future (at now + lease_time) with + "leased_until" and "leased_by" attributes set. + + Args: + req: MessageHandlerRequest to lease. + lease_time: Lease duration. + + Returns: + Copy of the original request object with "leased_until" and "leased_by" + attributes set. + """ + delivery_time = rdfvalue.RDFDatetime.Now() + lease_time + + clone = objects_pb2.MessageHandlerRequest() + clone.CopyFrom(req) + clone.leased_until = delivery_time.AsMicrosecondsSinceEpoch() + clone.leased_by = utils.ProcessIdString() + + def Txn(txn) -> None: + self._BuildDeleteMessageHandlerRequestWrites(txn, [clone]) + + with txn.BufferedMutate() as mut: + key = (req.handler_name, req.request_id, spanner_lib.CommitTimestamp()) + mut.Send( + queue="MessageHandlerRequestsQueue", + key=key, + value=clone, + column="Payload", + deliver_time=delivery_time.AsDatetimeUTC(), + ) + + result = self.db.Transact(Txn) + # Using the transaction's commit timestamp and modifying the request object + # with it allows us to avoid doing a separate database read to read + # the same object back with a Spanner-provided timestamp + # (please see ReadMessageHandlerRequests that uses the CreationTime + # column to set the request's 'timestamp' attribute). 
+ clone.timestamp = rdfvalue.RDFDatetime.FromDatetime( + result.commit_time + ).AsMicrosecondsSinceEpoch() + + return clone def RegisterMessageHandler( self, @@ -516,13 +2289,48 @@ def RegisterMessageHandler( limit: int = 1000, ) -> None: """Registers a message handler to receive batches of messages.""" - + self.UnregisterMessageHandler() + + def Callback(expanded_key: Sequence[Any], payload: bytes): + del expanded_key + try: + req = objects_pb2.MessageHandlerRequest() + req.ParseFromString(payload) + leased = self._LeaseMessageHandlerRequest(req, lease_time) + logging.info("Leased message handler request: %s", req.request_id) + handler([leased]) + except Exception as e: # pylint: disable=broad-except + logging.exception( + "Exception raised during MessageHandlerRequest processing: %s", e + ) + + receiver = self.db.NewQueueReceiver( + "MessageHandlerRequestsQueue", + Callback, + receiver_max_keepalive_seconds=_MESSAGE_HANDLER_MAX_KEEPALIVE_SECONDS, + receiver_max_active_callbacks=_MESSAGE_HANDLER_MAX_ACTIVE_CALLBACKS, + receiver_max_messages_per_callback=limit, + ) + receiver.Receive() + self._message_handler_receiver = receiver def UnregisterMessageHandler( self, timeout: Optional[rdfvalue.Duration] = None ) -> None: """Unregisters any registered message handler.""" - + del timeout # Unused. + if self._message_handler_receiver: + self._message_handler_receiver.Stop() # pytype: disable=attribute-error # always-use-return-annotations + self._message_handler_receiver = None + + def _ReadHuntState( + self, txn: spanner_utils.Transaction, hunt_id: str + ) -> Optional[int]: + try: + row = txn.Read(table="Hunts", key=(IntHuntID(hunt_id),), cols=("State",)) + return row["State"] + except spanner_errors.RowNotFoundError: + return None @db_utils.CallLogged @db_utils.CallAccounted @@ -533,12 +2341,162 @@ def LeaseFlowForProcessing( processing_time: rdfvalue.Duration, ) -> flows_pb2.Flow: """Marks a flow as being processed on this worker and returns it.""" - - return None + int_client_id = spanner_clients.IntClientID(client_id) + int_flow_id = IntFlowID(flow_id) + + def Txn(txn) -> flows_pb2.Flow: + try: + row = txn.Read( + table="Flows", + key=(int_client_id, int_flow_id), + cols=_READ_FLOW_OBJECT_COLS, + ) + except spanner_errors.RowNotFoundError as error: + raise db.UnknownFlowError(client_id, flow_id, cause=error) + + flow = _ParseReadFlowObjectRow(client_id, flow_id, row) + now = rdfvalue.RDFDatetime.Now() + if flow.processing_on and flow.processing_deadline > int(now): + raise ValueError( + "Flow {}/{} is already being processed on {} since {} " + "with deadline {} (now: {})).".format( + client_id, + flow_id, + flow.processing_on, + rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + flow.processing_since + ), + rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + flow.processing_deadline + ), + now, + ) + ) + + if flow.parent_hunt_id is not None: + hunt_state = self._ReadHuntState(txn, flow.parent_hunt_id) + if ( + hunt_state is not None + and not models_hunts.IsHuntSuitableForFlowProcessing(hunt_state) + ): + raise db.ParentHuntIsNotRunningError( + client_id, flow_id, flow.parent_hunt_id, hunt_state + ) + + flow.processing_on = utils.ProcessIdString() + flow.processing_deadline = int(now + processing_time) + + txn.Update( + table="Flows", + row={ + "ClientId": int_client_id, + "FlowId": int_flow_id, + "ProcessingWorker": flow.processing_on, + "ProcessingEndTime": ( + rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + flow.processing_deadline + ).AsDatetime() + ), + "ProcessingStartTime": 
spanner_lib.CommitTimestamp(), + }, + ) + + return flow + + result = self.db.Transact(Txn) + + leased_flow = result.value + leased_flow.processing_since = int( + rdfvalue.RDFDatetime.FromDatetime(result.commit_time) + ) + return leased_flow @db_utils.CallLogged @db_utils.CallAccounted def ReleaseProcessedFlow(self, flow_obj: flows_pb2.Flow) -> bool: """Releases a flow that the worker was processing to the database.""" - - return False + int_client_id = spanner_clients.IntClientID(flow_obj.client_id) + int_flow_id = IntFlowID(flow_obj.flow_id) + + def Txn(txn) -> bool: + try: + row = txn.Read( + table="FlowRequests", + key=(int_client_id, int_flow_id, flow_obj.next_request_to_process), + cols=("NeedsProcessing", "StartTime"), + ) + if row["NeedsProcessing"]: + start_time = row["StartTime"] + if start_time is None: + return False + elif ( + rdfvalue.RDFDatetime.FromDatetime(start_time) + < rdfvalue.RDFDatetime.Now() + ): + return False + except spanner_errors.RowNotFoundError: + pass + txn.Update( + table="Flows", + row={ + "ClientId": int_client_id, + "FlowId": int_flow_id, + "Flow": flow_obj, + "State": int(flow_obj.flow_state), + "UserCpuTimeUsed": float(flow_obj.cpu_time_used.user_cpu_time), + "SystemCpuTimeUsed": float( + flow_obj.cpu_time_used.system_cpu_time + ), + "NetworkBytesSent": spanner_lib.UInt64( + flow_obj.network_bytes_sent + ), + "ProcessingWorker": None, + "ProcessingStartTime": None, + "ProcessingEndTime": None, + "NextRequesttoProcess": spanner_lib.UInt64( + flow_obj.next_request_to_process + ), + "UpdateTime": spanner_lib.CommitTimestamp(), + "ReplyCount": flow_obj.num_replies_sent, + }, + ) + + return True + + return self.db.Transact(Txn).value + + +def IntFlowID(flow_id: str) -> int: + """Converts a flow identifier to its integer representation. + + Args: + flow_id: A flow identifier to convert. + + Returns: + An integer representation of the given flow identifier. + """ + return db_utils.FlowIDToInt(flow_id) + + +def IntHuntID(hunt_id: str) -> int: + """Converts a hunt identifier to its integer representation. + + Args: + hunt_id: A hunt identifier to convert. + + Returns: + An integer representation of the given hunt identifier. + """ + return db_utils.HuntIDToInt(hunt_id) + + +def IntOutputPluginID(output_plugin_id: str) -> int: + """Converts an output plugin identifier to its integer representation. + + Args: + output_plugin_id: An output plugin identifier to convert. + + Returns: + An integer representation of the given output plugin identifier. + """ + return db_utils.OutputPluginIDToInt(output_plugin_id) diff --git a/grr/server/grr_response_server/databases/spanner_flows_large_test.py b/grr/server/grr_response_server/databases/spanner_flows_large_test.py new file mode 100644 index 000000000..143ac3c96 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_flows_large_test.py @@ -0,0 +1,53 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_flows_test +from grr_response_server.databases import spanner_test_lib +from grr_response_server.databases import spanner_utils + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseFlowsTest( + db_flows_test.DatabaseLargeTestFlowMixin, spanner_test_lib.TestCase +): + # Test methods are defined in the base mixin class. 
+ + # To cleanup the database we use `DeleteWithPrefix` (to do multiple deletions + # within a single mutation) but this method is super slow for cleaning up huge + # amounts of data. Thus, for certain methods that populate the database with + # a lot of rows we manually clean up using the `DELETE` DML statement which is + # faster in such cases. + + def test40001RequestsCanBeWrittenAndRead(self): + super().test40001RequestsCanBeWrittenAndRead() + + db: spanner_utils.Database = self.db.delegate.db # pytype: disable=attribute-error + db.ExecutePartitioned("DELETE FROM FlowRequests WHERE TRUE") + + def test40001ResponsesCanBeWrittenAndRead(self): + super().test40001ResponsesCanBeWrittenAndRead() + + db: spanner_utils.Database = self.db.delegate.db # pytype: disable=attribute-error + db.ExecutePartitioned("DELETE FROM FlowResponses WHERE TRUE") + + def testWritesAndCounts40001FlowResults(self): + super().testWritesAndCounts40001FlowResults() + + db: spanner_utils.Database = self.db.delegate.db # pytype: disable=attribute-error + db.ExecutePartitioned("DELETE FROM FlowResults WHERE TRUE") + + def testWritesAndCounts40001FlowErrors(self): + super().testWritesAndCounts40001FlowErrors() + + db: spanner_utils.Database = self.db.delegate.db # pytype: disable=attribute-error + db.ExecutePartitioned("DELETE FROM FlowErrors WHERE TRUE") + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_flows_test.py b/grr/server/grr_response_server/databases/spanner_flows_test.py new file mode 100644 index 000000000..1e981ef08 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_flows_test.py @@ -0,0 +1,29 @@ +import random +from unittest import mock + +from absl.testing import absltest + +from google.cloud import spanner as spanner_lib +from grr_response_proto import flows_pb2 +from grr_response_server.databases import db_flows_test +from grr_response_server.databases import db_test_utils +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseFlowsTest( + db_flows_test.DatabaseTestFlowMixin, spanner_test_lib.TestCase +): + """Spanner flow tests.""" + pass # Test methods are defined in the base mixin class. 
+ + +if __name__ == "__main__": + absltest.main() \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_users.py b/grr/server/grr_response_server/databases/spanner_users.py index 864e1a7a9..1f863eec7 100644 --- a/grr/server/grr_response_server/databases/spanner_users.py +++ b/grr/server/grr_response_server/databases/spanner_users.py @@ -3,11 +3,17 @@ import datetime import logging +import random +import sys + from typing import Optional, Sequence, Tuple +from google.api_core.exceptions import NotFound + +from google.cloud import spanner as spanner_lib + from grr_response_core.lib import rdfvalue from grr_response_core.lib.util import iterator -from grr_response_core.lib.util import random from grr_response_proto import jobs_pb2 from grr_response_proto import objects_pb2 from grr_response_proto import user_pb2 @@ -57,14 +63,59 @@ def WriteGRRUser( @db_utils.CallAccounted def DeleteGRRUser(self, username: str) -> None: """Deletes the user and all related metadata with the given username.""" + keyset = spanner_lib.KeySet(keys=[(username,)]) + + def Transaction(txn) -> None: + try: + txn.read(table="Users", columns=("Username",), keyset=keyset) + except NotFound: + raise abstract_db.UnknownGRRUserError(username) + + query = f""" + DELETE + FROM ApprovalGrants@{{FORCE_INDEX=ApprovalGrantsByGrantor}} AS g + WHERE g.Grantor = {username} + """ + txn.execute_sql(query) + + query = f""" + DELETE + FROM ScheduledFlows@{{FORCE_INDEX=ScheduledFlowsByCreator}} AS f + WHERE f.Creator = {username} + """ + txn.execute_sql(query) + + username_range = spanner_lib.KeyRange(start_closed=[username], end_closed=[username]) + txn.delete(table="ApprovalRequests", keyset=spanner_lib.KeySet(ranges=[username_range])) + txn.delete(table="Users", keyset=keyset) + + self.db.Transact(Transaction, txn_tag="DeleteGRRUser") @db_utils.CallLogged @db_utils.CallAccounted def ReadGRRUser(self, username: str) -> objects_pb2.GRRUser: """Reads a user object corresponding to a given name.""" - - return None + cols = ("Email", "Password", "Type", "CanaryMode", "UiMode") + try: + row = self.db.Read(table="Users", key=[username], cols=cols) + except NotFound as error: + raise abstract_db.UnknownGRRUserError(username) from error + + user = objects_pb2.GRRUser( + username=username, + email=row[0], + user_type=row[2], + canary_mode=row[3], + ui_mode=row[4], + ) + + if row[1]: + pw = jobs_pb2.Password() + pw.ParseFromString(row[1]) + user.password.CopyFrom(pw) + + return user @db_utils.CallLogged @db_utils.CallAccounted @@ -74,22 +125,92 @@ def ReadGRRUsers( count: Optional[int] = None, ) -> Sequence[objects_pb2.GRRUser]: """Reads GRR users with optional pagination, sorted by username.""" - - return [] + if count is None: + # TODO(b/196379916): We use the same value as F1 implementation does. But + # a better solution would be to dynamically not ignore the `LIMIT` clause + # in the query if the count parameter is not provided. This is not trivial + # as queries have to be provided as docstrings (an utility that does it + # on the fly has to be created to hack around this limitation). 
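+      # 2147483647 is the largest signed 32-bit integer, so the query below is
+      # effectively unbounded when no explicit count is requested.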
+ count = 2147483647 + + users = [] + + query = """ + SELECT u.Username, u.Email, u.Password, u.Type, u.CanaryMode, u.UiMode + FROM Users AS u + ORDER BY u.Username + LIMIT {count} + OFFSET {offset} + """ + params = { + "offset": offset, + "count": count, + } + + for row in self.db.ParamQuery(query, params, txn_tag="ReadGRRUsers"): + username, email, password, typ, canary_mode, ui_mode = row + + user = objects_pb2.GRRUser( + username=username, + email=email, + user_type=typ, + canary_mode=canary_mode, + ui_mode=ui_mode, + ) + + if password: + user.password.ParseFromString(password) + + users.append(user) + + return users @db_utils.CallLogged @db_utils.CallAccounted def CountGRRUsers(self) -> int: """Returns the total count of GRR users.""" + query = """ + SELECT COUNT(*) + FROM Users + """ - return 0 + (count,) = self.db.QuerySingle(query, txn_tag="CountGRRUsers") + return count @db_utils.CallLogged @db_utils.CallAccounted def WriteApprovalRequest(self, request: objects_pb2.ApprovalRequest) -> str: """Writes an approval request object.""" - - return "" + approval_id = random.randint(0, sys.maxsize) + + row = { + "Requestor": request.requestor_username, + "ApprovalId": approval_id, + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, + "ExpirationTime": ( + rdfvalue.RDFDatetime() + .FromMicrosecondsSinceEpoch(request.expiration_time) + .AsDatetime() + ), + "Reason": request.reason, + "NotifiedUsers": list(request.notified_users), + "CcEmails": list(request.email_cc_addresses), + } + + if request.approval_type == _APPROVAL_TYPE_CLIENT: + row["SubjectClientId"] = db_utils.ClientIDToInt(request.subject_id) + elif request.approval_type == _APPROVAL_TYPE_HUNT: + row["SubjectHuntId"] = db_utils.HuntIDToInt(request.subject_id) + elif request.approval_type == _APPROVAL_TYPE_CRON_JOB: + row["SubjectCronJobId"] = request.subject_id + else: + raise ValueError(f"Unsupported approval type: {request.approval_type}") + + self.db.Insert( + table="ApprovalRequests", row=row, txn_tag="WriteApprovalRequest" + ) + + return _HexApprovalID(approval_id) @db_utils.CallLogged @db_utils.CallAccounted @@ -99,8 +220,72 @@ def ReadApprovalRequest( approval_id: str, ) -> objects_pb2.ApprovalRequest: """Reads an approval request object with a given id.""" - - return None + approval_id = _UnhexApprovalID(approval_id) + + query = """ + SELECT r.SubjectClientId, r.SubjectHuntId, r.SubjectCronJobId, + r.Reason, + r.CreationTime, r.ExpirationTime, + r.NotifiedUsers, r.CcEmails, + ARRAY(SELECT AS STRUCT g.Grantor, + g.CreationTime + FROM ApprovalGrants AS g + WHERE g.Requestor = r.Requestor + AND g.ApprovalId = r.ApprovalId) AS Grants + FROM ApprovalRequests AS r + WHERE r.Requestor = {requestor} + AND r.ApprovalId = {approval_id} + """ + params = { + "requestor": username, + "approval_id": approval_id, + } + + try: + row = self.db.ParamQuerySingle( + query, params, txn_tag="ReadApprovalRequest" + ) + except NotFound: + # TODO: Improve error message of this error class. 
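+      # `ParamQuerySingle` raises `NotFound` when no row matches, i.e. there
+      # is no approval request with this id for the given requestor.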
+ raise abstract_db.UnknownApprovalRequestError(approval_id) + + subject_client_id, subject_hunt_id, subject_cron_job_id, *row = row + reason, *row = row + creation_time, expiration_time, *row = row + notified_users, cc_emails, grants = row + + request = objects_pb2.ApprovalRequest( + requestor_username=username, + approval_id=_HexApprovalID(approval_id), + reason=reason, + timestamp=RDFDatetime(creation_time).AsMicrosecondsSinceEpoch(), + expiration_time=RDFDatetime(expiration_time).AsMicrosecondsSinceEpoch(), + notified_users=notified_users, + email_cc_addresses=cc_emails, + ) + + if subject_client_id is not None: + request.subject_id = db_utils.IntToClientID(subject_client_id) + request.approval_type = _APPROVAL_TYPE_CLIENT + elif subject_hunt_id is not None: + request.subject_id = db_utils.IntToHuntID(subject_hunt_id) + request.approval_type = _APPROVAL_TYPE_HUNT + elif subject_cron_job_id is not None: + request.subject_id = subject_cron_job_id + request.approval_type = _APPROVAL_TYPE_CRON_JOB + else: + # This should not happen as the condition to one of these being always + # set if enforced by the database schema. + message = "No subject set for approval '%s' of user '%s'" + logging.error(message, approval_id, username) + + for grantor, creation_time in grants: + grant = objects_pb2.ApprovalGrant() + grant.grantor_username = grantor + grant.timestamp = RDFDatetime(creation_time).AsMicrosecondsSinceEpoch() + request.grants.add().CopyFrom(grant) + + return request @db_utils.CallLogged @db_utils.CallAccounted @@ -114,6 +299,100 @@ def ReadApprovalRequests( """Reads approval requests of a given type for a given user.""" requests = [] + # We need to use double curly braces for parameters as we also parametrize + # over index that is substituted using standard Python templating. + query = """ + SELECT r.ApprovalId, + r.SubjectClientId, r.SubjectHuntId, r.SubjectCronJobId, + r.Reason, + r.CreationTime, r.ExpirationTime, + r.NotifiedUsers, r.CcEmails, + ARRAY(SELECT AS STRUCT g.Grantor, + g.CreationTime + FROM ApprovalGrants AS g + WHERE g.Requestor = r.Requestor + AND g.ApprovalId = r.ApprovalId) AS Grants + FROM ApprovalRequests@{{{{FORCE_INDEX={index}}}}} AS r + WHERE r.Requestor = {{requestor}} + """ + params = { + "requestor": username, + } + + # By default we use the "by requestor" index but in case a specific subject + # is given we can also use a more specific index (overridden below). 
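+    # The selected index name is substituted into the FORCE_INDEX hint via
+    # `query.format(index=index)` further below, which is why the literal
+    # braces in the query above are doubled.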
+ index = "ApprovalRequestsByRequestor" + + if typ == _APPROVAL_TYPE_CLIENT: + query += " AND r.SubjectClientId IS NOT NULL" + if subject_id is not None: + query += " AND r.SubjectClientId = {{subject_client_id}}" + params["subject_client_id"] = db_utils.ClientIDToInt(subject_id) + index = "ApprovalRequestsByRequestorSubjectClientId" + elif typ == _APPROVAL_TYPE_HUNT: + query += " AND r.SubjectHuntId IS NOT NULL" + if subject_id is not None: + query += " AND r.SubjectHuntId = {{subject_hunt_id}}" + params["subject_hunt_id"] = db_utils.HuntIDToInt(subject_id) + index = "ApprovalRequestsByRequestorSubjectHuntId" + elif typ == _APPROVAL_TYPE_CRON_JOB: + query += " AND r.SubjectCronJobId IS NOT NULL" + if subject_id is not None: + query += " AND r.SubjectCronJobId = {{subject_cron_job_id}}" + params["subject_cron_job_id"] = subject_id + index = "ApprovalRequestsByRequestorSubjectCronJobId" + else: + raise ValueError(f"Unsupported approval type: {typ}") + + if not include_expired: + query += " AND r.ExpirationTime > CURRENT_TIMESTAMP()" + + query = query.format(index=index) + + for row in self.db.ParamQuery( + query, params, txn_tag="ReadApprovalRequests" + ): + approval_id, *row = row + subject_client_id, subject_hunt_id, subject_cron_job_id, *row = row + reason, *row = row + creation_time, expiration_time, *row = row + notified_users, cc_emails, grants = row + + request = objects_pb2.ApprovalRequest( + requestor_username=username, + approval_id=_HexApprovalID(approval_id), + reason=reason, + timestamp=RDFDatetime(creation_time).AsMicrosecondsSinceEpoch(), + expiration_time=RDFDatetime( + expiration_time + ).AsMicrosecondsSinceEpoch(), + notified_users=notified_users, + email_cc_addresses=cc_emails, + ) + + if subject_client_id is not None: + request.subject_id = db_utils.IntToClientID(subject_client_id) + request.approval_type = _APPROVAL_TYPE_CLIENT + elif subject_hunt_id is not None: + request.subject_id = db_utils.IntToHuntID(subject_hunt_id) + request.approval_type = _APPROVAL_TYPE_HUNT + elif subject_cron_job_id is not None: + request.subject_id = subject_cron_job_id + request.approval_type = _APPROVAL_TYPE_CRON_JOB + else: + # This should not happen as the condition to one of these being always + # set if enforced by the database schema. + message = "No subject set for approval '%s' of user '%s'" + logging.error(message, approval_id, username) + + for grantor, creation_time in grants: + grant = objects_pb2.ApprovalGrant() + grant.grantor_username = grantor + grant.timestamp = RDFDatetime(creation_time).AsMicrosecondsSinceEpoch() + request.grants.add().CopyFrom(grant) + + requests.append(request) + return requests @db_utils.CallLogged @@ -125,6 +404,16 @@ def GrantApproval( grantor_username: str, ) -> None: """Grants approval for a given request using given username.""" + row = { + "Requestor": requestor_username, + "ApprovalId": _UnhexApprovalID(approval_id), + "Grantor": grantor_username, + # TODO: Look into Spanner sequences to generate unique IDs. + "GrantId": random.randint(0, sys.maxsize), + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, + } + + self.db.Insert(table="ApprovalGrants", row=row, txn_tag="GrantApproval") @db_utils.CallLogged @db_utils.CallAccounted @@ -133,6 +422,24 @@ def WriteUserNotification( notification: objects_pb2.UserNotification, ) -> None: """Writes a notification for a given user.""" + row = { + "Username": notification.username, + # TODO: Look into Spanner sequences to generate unique IDs. 
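+        # For now the id is a random non-negative integer (up to
+        # `sys.maxsize`), which makes collisions unlikely but not impossible.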
+ "NotificationId": random.randint(0, sys.maxsize), + "Type": int(notification.notification_type), + "State": int(notification.state), + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, + "Message": notification.message, + } + if notification.reference: + row["Reference"] = notification.reference.SerializeToString() + + try: + self.db.Insert( + table="UserNotifications", row=row, txn_tag="WriteUserNotification" + ) + except NotFound: + raise abstract_db.UnknownGRRUserError(notification.username) @db_utils.CallLogged @db_utils.CallAccounted @@ -145,8 +452,52 @@ def ReadUserNotifications( ] = None, ) -> Sequence[objects_pb2.UserNotification]: """Reads notifications scheduled for a user within a given timerange.""" - - return [] + notifications = [] + + params = { + "username": username, + } + query = """ + SELECT n.Type, n.State, n.CreationTime, + n.Message, n.Reference + FROM UserNotifications AS n + WHERE n.Username = {username} + """ + + if state is not None: + params["state"] = int(state) + query += " AND n.state = {state}" + + if timerange is not None: + begin_time, end_time = timerange + if begin_time is not None: + params["begin_time"] = begin_time.AsDatetime() + query += " AND n.CreationTime >= {begin_time}" + if end_time is not None: + params["end_time"] = end_time.AsDatetime() + query += " AND n.CreationTime <= {end_time}" + + query += " ORDER BY n.CreationTime DESC" + + for row in self.db.ParamQuery( + query, params, txn_tag="ReadUserNotifications" + ): + typ, state, creation_time, message, reference = row + + notification = objects_pb2.UserNotification( + username=username, + notification_type=typ, + state=state, + timestamp=RDFDatetime(creation_time).AsMicrosecondsSinceEpoch(), + message=message, + ) + + if reference: + notification.reference.ParseFromString(reference) + + notifications.append(notification) + + return notifications @db_utils.CallLogged @db_utils.CallAccounted @@ -157,7 +508,37 @@ def UpdateUserNotifications( state: Optional["objects_pb2.UserNotification.State"] = None, ): """Updates existing user notification objects.""" + # `UNNEST` used in the query does not like empty arrays, so we return early + # in such cases. 
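+    # An empty `timestamps` list would also produce an empty `IN (...)` clause
+    # in the statement built below, so there is nothing to update.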
+ if not timestamps: + return + + params = { + "username": username, + "state": int(state), + } + + param_placeholders = ", ".join([f"{{ts{i}}}" for i in range(len(timestamps))]) + for i, timestamp in enumerate(timestamps): + param_name = f"ts{i}" + params[param_name] = timestamp.AsDatetime() + + query = f""" + UPDATE UserNotifications n + SET n.State = {state} + WHERE n.Username = {username} + AND n.CreationTime IN ({param_placeholders}) + """ + + self.db.ParamExecute(query, params, txn_tag="UpdateUserNotifications") + + +def _HexApprovalID(approval_id: int) -> str: + return f"{approval_id:016x}" + +def _UnhexApprovalID(approval_id: str) -> int: + return int(approval_id, base=16) diff --git a/grr/server/grr_response_server/databases/spanner_users_test.py b/grr/server/grr_response_server/databases/spanner_users_test.py new file mode 100644 index 000000000..2ca44bb93 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_users_test.py @@ -0,0 +1,22 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_users_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseUsersTest( + db_users_test.DatabaseTestUsersMixin, spanner_test_lib.TestCase +): + pass # Test methods are defined in the base mixin class. + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 1dfc197f9..88b206db3 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -18,6 +18,8 @@ from typing import Type from typing import TypeVar +from concurrent import futures +from google.cloud import pubsub_v1 from google.cloud import spanner_v1 as spanner_lib from google.cloud.spanner import KeyRange, KeySet from google.cloud.spanner_admin_database_v1.types import spanner_database_admin @@ -33,6 +35,54 @@ _T = TypeVar("_T") +class QueueReceiver: + """ + This stores the callback internally, and will continue to deliver messages to + the callback as long as it is referenced in python code and Stop is not + called. + """ + + def __init__( + self, + queue_type: str, + callback, # : Callable[[spanner.KeyBuilder, List[Any], Any, bytes, + receiver_max_keepalive_seconds: int, + receiver_max_active_callbacks: int, + receiver_max_messages_per_callback: int, + ): + # An optional executor to use. If not specified, a default one with maximum 10 + # threads will be created. + executor = futures.ThreadPoolExecutor(max_workers=receiver_max_messages_per_callback) + # A thread pool-based scheduler. It must not be shared across SubscriberClients. + scheduler = pubsub_v1.subscriber.scheduler.ThreadScheduler(executor) + + subscriber = pubsub_v1.SubscriberClient() + subscription_path = subscriber.subscription_path(project_id, subscription_id) + + def subcallback(message: pubsub_v1.subscriber.message.Message) -> None: + callback() + message.ack() + + flow_control = pubsub_v1.types.FlowControl(max_messages=receiver_max_messages_per_callback) + + streaming_pull_future = subscriber.subscribe( + subscription_path, callback=subcallback, scheduler=scheduler, flow_control=flow_control + ) + + # Wrap subscriber in a 'with' block to automatically call close() when done. 
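+    # Note that `result()` below keeps the constructor blocked for up to
+    # `receiver_max_keepalive_seconds` before the pull is cancelled.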
+ with subscriber: + try: + # When `timeout` is not set, result() will block indefinitely, + # unless an exception is encountered first. + streaming_pull_future.result(timeout=receiver_max_keepalive_seconds) + except TimeoutError: + streaming_pull_future.cancel() # Trigger the shutdown. + streaming_pull_future.result() # Block until the shutdown is complete. + + def Stop(self): + streaming_pull_future.cancel() + + class Database: """A wrapper around the PySpanner class. @@ -466,4 +516,4 @@ def ReadSet( keyset=rows ) - return results \ No newline at end of file + return results From 4c9113dccd5262ac00a8ff17334187c1e1c98495 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Sun, 1 Jun 2025 16:36:05 +0000 Subject: [PATCH 012/168] Adds PubSub to Database wrapper --- .../grr_response_server/databases/spanner.py | 17 +- .../databases/spanner_flows.py | 26 +-- .../databases/spanner_test_lib.py | 67 +++----- .../databases/spanner_utils.py | 134 +++++++++++++--- .../databases/spanner_utils_test.py | 151 ++++++++++-------- 5 files changed, 249 insertions(+), 146 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner.py b/grr/server/grr_response_server/databases/spanner.py index a2265304f..b25610dd8 100644 --- a/grr/server/grr_response_server/databases/spanner.py +++ b/grr/server/grr_response_server/databases/spanner.py @@ -1,4 +1,4 @@ -# Imports the Google Cloud Client Library. + from google.cloud.spanner import Client from grr_response_core.lib import rdfvalue @@ -26,7 +26,6 @@ class SpannerDB( spanner_blob_keys.BlobKeysMixin, spanner_blob_references.BlobReferencesMixin, spanner_clients.ClientsMixin, - spanner_signed_commands.SignedCommandsMixin, spanner_cron_jobs.CronJobsMixin, spanner_events.EventsMixin, spanner_flows.FlowsMixin, @@ -34,6 +33,7 @@ class SpannerDB( spanner_hunts.HuntsMixin, spanner_paths.PathsMixin, spanner_signed_binaries.SignedBinariesMixin, + spanner_signed_commands.SignedCommandsMixin, spanner_users.UsersMixin, spanner_yara.YaraMixin, db_module.Database, @@ -52,11 +52,18 @@ def FromConfig(cls) -> "Database": Returns: A GRR database instance. 
""" - spanner_client = Client(onfig.CONFIG["ProjectID"]) + project_id = config.CONFIG["ProjectID"] + spanner_client = Client(project_id) spanner_instance = spanner_client.instance(config.CONFIG["Spanner.instance"]) spanner_database = spanner_instance.database(config.CONFIG["Spanner.database"]) + msg_handler_top_id = config.CONFIG["MessageHandler.topic_id"] + msg_handler_sub_id = config.CONFIG["MessageHandler.subscription_id"] + flow_processing_top_id = config.CONFIG["FlowProcessing.topic_id"] + flow_processing_sub_id = config.CONFIG["FlowProcessing.subscription_id"] - return cls(spanner_utils.Database(spanner_database)) + return cls(spanner_utils.Database(spanner_database, project_id, + msg_handler_top_id, msg_handler_sub_id, + flow_processing_top_id, flow_processing_sub_id)) def Now(self) -> rdfvalue.RDFDatetime: """Retrieves current time as reported by the database.""" @@ -65,4 +72,4 @@ def Now(self) -> rdfvalue.RDFDatetime: def MinTimestamp(self) -> rdfvalue.RDFDatetime: """Returns minimal timestamp allowed by the DB.""" - return rdfvalue.RDFDatetime.FromSecondsSinceEpoch(0) \ No newline at end of file + return rdfvalue.RDFDatetime.FromSecondsSinceEpoch(0) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 45019fe5c..86cab5cdb 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -229,22 +229,22 @@ class FlowsMixin: @property def _flow_processing_request_receiver( self, - ) -> Optional[spanner_lib.QueueReceiver]: + ) -> Optional[spanner_utils.RequestQueue]: return getattr(self, "__flow_processing_request_receiver", None) @_flow_processing_request_receiver.setter def _flow_processing_request_receiver( - self, value: Optional[spanner_lib.QueueReceiver] + self, value: Optional[spanner_utils.RequestQueue] ) -> None: setattr(self, "__flow_processing_request_receiver", value) @property - def _message_handler_receiver(self) -> Optional[spanner_lib.QueueReceiver]: + def _message_handler_receiver(self) -> Optional[spanner_utils.RequestQueue]: return getattr(self, "__message_handler_receiver", None) @_message_handler_receiver.setter def _message_handler_receiver( - self, value: Optional[spanner_lib.QueueReceiver] + self, value: Optional[spanner_utils.RequestQueue] ) -> None: setattr(self, "__message_handler_receiver", value) @@ -952,7 +952,7 @@ def _ReadRequestsInfo( flows_pb2.FlowIterator, ], ], - txn: spanner_utils.Transaction, + txn, ) -> tuple[dict[_RequestKey, int], dict[_RequestKey, str], set[_RequestKey]]: """For given responses returns data about corresponding requests. @@ -1040,7 +1040,7 @@ def _BuildResponseWrites( flows_pb2.FlowIterator, ], ], - txn: spanner_utils.Transaction, + txn, ) -> None: """Builds the writes to store given responses in the db. @@ -1079,7 +1079,7 @@ def _BuildResponseWrites( mut.InsertOrUpdate("FlowResponses", row) def _BuildExpectedUpdates( - self, updates: dict[_RequestKey, int], txn: spanner_utils.Transaction + self, updates: dict[_RequestKey, int], txn ) -> None: """Builds updates for requests with known number of expected responses. @@ -1218,7 +1218,7 @@ def Txn( def _GetFlowResponsesPerRequestCounts( self, request_keys: Iterable[_RequestKey], - txn: spanner_utils.SnapshotTransaction, + txn, ) -> dict[_RequestKey, int]: """Gets counts of already received responses for given requests. 
@@ -1274,7 +1274,7 @@ def _ReadFlowRequestsNotYetMarkedForProcessing( self, requests: set[_RequestKey], callback_states: dict[_RequestKey, str], - txn: spanner_utils.Transaction, + txn, ) -> tuple[ set[_RequestKey], set[tuple[_FlowKey, Optional[rdfvalue.RDFDatetime]]] ]: @@ -1363,7 +1363,7 @@ def _ReadFlowRequestsNotYetMarkedForProcessing( return requests_to_mark, requests_to_notify def _BuildNeedsProcessingUpdates( - self, requests: set[_RequestKey], txn: spanner_utils.Transaction + self, requests: set[_RequestKey], txn ) -> None: """Builds updates for requests that have their NeedsProcessing flag set. @@ -1386,7 +1386,7 @@ def _UpdateNeedsProcessingAndWriteFlowProcessingRequests( self, requests_ready_for_processing: set[_RequestKey], callback_state_by_request: dict[_RequestKey, str], - txn: spanner_utils.Transaction, + txn, ) -> None: """Updates requests needs-processing flags, writes processing requests. @@ -2198,7 +2198,7 @@ def ReadMessageHandlerRequests( def _BuildDeleteMessageHandlerRequestWrites( self, - txn: spanner_utils.Transaction, + txn, requests: Iterable[objects_pb2.MessageHandlerRequest], ) -> None: """Deletes given requests within a given transaction.""" @@ -2324,7 +2324,7 @@ def UnregisterMessageHandler( self._message_handler_receiver = None def _ReadHuntState( - self, txn: spanner_utils.Transaction, hunt_id: str + self, txn, hunt_id: str ) -> Optional[int]: try: row = txn.Read(table="Hunts", key=(IntHuntID(hunt_id),), cols=("State",)) diff --git a/grr/server/grr_response_server/databases/spanner_test_lib.py b/grr/server/grr_response_server/databases/spanner_test_lib.py index 4ca66bdbc..c84db358b 100644 --- a/grr/server/grr_response_server/databases/spanner_test_lib.py +++ b/grr/server/grr_response_server/databases/spanner_test_lib.py @@ -94,8 +94,8 @@ def Init(sdl_path: str, proto_bundle: bool) -> None: def TearDown() -> None: """Tears down the Spanner testing environment. - This must be called once per process after all the tests. A `tearDownModule` - is a perfect place for it. + This must be called once per process after all the tests. + A `tearDownModule` is a perfect place for it. """ if _TEST_DB is not None: # Create a client @@ -112,51 +112,38 @@ class TestCase(absltest.TestCase): def setUp(self): super().setUp() - self.raw_db = spanner_utils.Database(CreateTestDatabase()) + project_id = _GetEnvironOrSkip("PROJECT_ID") + msg_handler_top_id = _GetEnvironOrSkip("MESSAGE_HANDLER_TOPIC_ID") + msg_handler_sub_id = _GetEnvironOrSkip("MESSAGE_HANDLER_SUBSCRIPTION_ID") + flow_processing_top_id = _GetEnvironOrSkip("FLOW_PROCESSING_TOPIC_ID") + flow_processing_sub_id = _GetEnvironOrSkip("FLOW_PROCESSING_SUBSCRIPTION_ID") - db = spanner_db.SpannerDB(self.raw_db) - self.db = abstract_db.DatabaseValidationWrapper(db) + _clean_database() + self.raw_db = spanner_utils.Database(_TEST_DB, project_id, + msg_handler_top_id, msg_handler_sub_id, + flow_processing_top_id, flow_processing_sub_id) -def CreateTestDatabase() -> spanner_lib.database: - """Creates an empty test spanner database. + spannerDB = spanner_db.SpannerDB(self.raw_db) + self.db = abstract_db.DatabaseValidationWrapper(spannerDB) - Returns: - A PySpanner instance pointing to the created database. 
- """ - #if _TEST_DB is None: - # raise AssertionError("Spanner test database not initialized") - db = spanner_utils.Database(_TEST_DB) +def _get_table_names(db): + with db.snapshot() as snapshot: + query_result = snapshot.execute_sql( + "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE';" + ) + table_names = set() + for row in query_result: + table_names.add(row[0]) - query = """ - SELECT t.table_name - FROM information_schema.tables AS t - WHERE t.table_catalog = "" - AND t.table_schema = "" - ORDER BY t.table_name ASC - """ + return table_names - table_names = set() - for (table_name,) in db.Query(query): - table_names.add(table_name) - query = """ - SELECT v.table_name - FROM information_schema.views AS v - WHERE v.table_catalog = "" - AND v.table_schema = "" - ORDER BY v.table_name ASC - """ - view_names = set() - for (view_name,) in db.Query(query): - view_names.add(view_name) - - # `table_names` is a superset of `view_names` (since the `VIEWS` table is, - # well, just a view to the `TABLES` table [1]). Since deleting from views - # makes no sense, we have to exclude them from the tables we want to clean. - table_names -= view_names - +def _clean_database() -> None: + """Creates an empty test spanner database.""" + + table_names = _get_table_names(_TEST_DB) keyset = KeySet(all_=True) with _TEST_DB.batch() as batch: @@ -164,6 +151,4 @@ def CreateTestDatabase() -> spanner_lib.database: for table_name in table_names: batch.delete(table_name, keyset) - return _TEST_DB - _TEST_DB: spanner_lib.database = None \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 88b206db3..8d9d249a9 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -4,6 +4,7 @@ import datetime import decimal import re +import time from typing import Any from typing import Callable @@ -19,8 +20,10 @@ from typing import TypeVar from concurrent import futures + from google.cloud import pubsub_v1 from google.cloud import spanner_v1 as spanner_lib + from google.cloud.spanner import KeyRange, KeySet from google.cloud.spanner_admin_database_v1.types import spanner_database_admin from google.cloud.spanner_v1 import Mutation, param_types @@ -30,12 +33,16 @@ from grr_response_core.lib.util import collection from grr_response_core.lib.util import iterator +from grr_response_proto import flows_pb2 +from grr_response_proto import objects_pb2 + Row = Tuple[Any, ...] Cursor = Iterator[Row] _T = TypeVar("_T") -class QueueReceiver: + +class RequestQueue: """ This stores the callback internally, and will continue to deliver messages to the callback as long as it is referenced in python code and Stop is not @@ -44,43 +51,64 @@ class QueueReceiver: def __init__( self, - queue_type: str, - callback, # : Callable[[spanner.KeyBuilder, List[Any], Any, bytes, + project_id: str, + topic_id: str, + subscription_id: str, + callback, # : Callable receiver_max_keepalive_seconds: int, receiver_max_active_callbacks: int, receiver_max_messages_per_callback: int, ): + self.project_id = project_id + self.topic_id = topic_id + self.subscriber = pubsub_v1.SubscriberClient() + subscription_path = self.subscriber.subscription_path(project_id, subscription_id) + + def queueCallback(message: pubsub_v1.subscriber.message.Message) -> None: + callback(message.data) + # An optional executor to use. 
If not specified, a default one with maximum 10 # threads will be created. executor = futures.ThreadPoolExecutor(max_workers=receiver_max_messages_per_callback) # A thread pool-based scheduler. It must not be shared across SubscriberClients. scheduler = pubsub_v1.subscriber.scheduler.ThreadScheduler(executor) - subscriber = pubsub_v1.SubscriberClient() - subscription_path = subscriber.subscription_path(project_id, subscription_id) - - def subcallback(message: pubsub_v1.subscriber.message.Message) -> None: - callback() - message.ack() - flow_control = pubsub_v1.types.FlowControl(max_messages=receiver_max_messages_per_callback) - streaming_pull_future = subscriber.subscribe( - subscription_path, callback=subcallback, scheduler=scheduler, flow_control=flow_control + self.streaming_pull_future = self.subscriber.subscribe( + subscription_path, callback=queueCallback, scheduler=scheduler, flow_control=flow_control ) - # Wrap subscriber in a 'with' block to automatically call close() when done. - with subscriber: - try: - # When `timeout` is not set, result() will block indefinitely, - # unless an exception is encountered first. - streaming_pull_future.result(timeout=receiver_max_keepalive_seconds) - except TimeoutError: - streaming_pull_future.cancel() # Trigger the shutdown. - streaming_pull_future.result() # Block until the shutdown is complete. + def publish(self, data: str) -> None: + publisher = pubsub_v1.PublisherClient() + topic_path = publisher.topic_path(self.project_id, self.topic_id) + publisher.publish(topic_path, data.encode("utf-8")) + + def pull(self): + # Create a client + client = pubsub_v1.SubscriberClient() + # Initialize request argument(s) + request = pubsub_v1.PullRequest( + subscription=subscription_path, + return_immediately=True, + max_messages=10000, + ) + # Make the request + response = client.pull(request=request) def Stop(self): - streaming_pull_future.cancel() + if self.streaming_pull_future: + try: + self.streaming_pull_future.cancel() + except asyncio.CancelledError: + pass # Expected when cancelling + except Exception as e: + print(f"Warning: Exception while cancelling future: {e}") + + if self.subscriber: + self.subscriber.close() + # Add a small sleep to allow threads to fully terminate + time.sleep(0.1) # Give a short buffer for threads to clean up class Database: @@ -93,9 +121,16 @@ class Database: _PYSPANNER_PARAM_REGEX = re.compile(r"@p\d+") - def __init__(self, pyspanner: spanner_lib.database) -> None: + def __init__(self, pyspanner: spanner_lib.database, project_id: str, + msg_handler_top_id: str, msg_handler_sub_id: str, + flow_processing_top_id: str, flow_processing_sub_id: str) -> None: super().__init__() self._pyspanner = pyspanner + self.project_id = project_id + self.msg_handler_top_id = msg_handler_top_id + self.msg_handler_sub_id = msg_handler_sub_id + self.flow_processing_top_id = flow_processing_top_id + self.flow_processing_sub_id = flow_processing_sub_id def _parametrize(self, query: str, names: Iterable[str]) -> str: match = self._PYSPANNER_PARAM_REGEX.search(query) @@ -517,3 +552,56 @@ def ReadSet( ) return results + + def NewRequestQueue( + self, + queue: str, + callback: Callable[[Sequence[Any], bytes], None], + receiver_max_keepalive_seconds: Optional[int] = None, + receiver_max_active_callbacks: Optional[int] = None, + receiver_max_messages_per_callback: Optional[int] = None, + ) -> RequestQueue: + """Registers a queue callback in a given queue. + + Args: + queue: Name of the queue. 
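+        Recognized values are "MessageHandler", "FlowProcessing" and the empty
+        string (which is treated like "MessageHandler").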
+ callback: Callback with 2 args (expanded_key, payload). expanded_key is a + sequence where each item corresponds to an item of the message's key. + Payload is the message itself, serialized as bytes. + receiver_max_keepalive_seconds: Num seconds before the lease on the + message expires (if the message is not acked before the lease expires, + it will be delivered again). + receiver_max_active_callbacks: Max number of callback to be called in + parallel. + receiver_max_messages_per_callback: Max messages to receive per callback. + + Returns: + New queue receiver objects. + """ + + def _Callback(data: str): + data = data.decode("utf-8") + if queue == "": + payload = data + elif queue == "MessageHandler": + payload = objects_pb2.MessageHandlerRequest.ParseFromString(data) + elif queue == "FlowProcessing": + payload = flows_pb2.FlowProcessingRequest.ParseFromString(data) + callback(payload) + + if queue == "MessageHandler" or queue == "": + topic_id = self.msg_handler_top_id + subscription_id = self.msg_handler_sub_id + elif queue == "FlowProcessing": + topic_id = self.flow_processing_top_id + subscription_id = self.flow_processing_sub_id + + return RequestQueue( + self.project_id, + topic_id, + subscription_id, + _Callback, + receiver_max_keepalive_seconds=receiver_max_keepalive_seconds, + receiver_max_active_callbacks=receiver_max_active_callbacks, + receiver_max_messages_per_callback=receiver_max_messages_per_callback, + ) diff --git a/grr/server/grr_response_server/databases/spanner_utils_test.py b/grr/server/grr_response_server/databases/spanner_utils_test.py index a2bd3fafd..3c7df8bdd 100644 --- a/grr/server/grr_response_server/databases/spanner_utils_test.py +++ b/grr/server/grr_response_server/databases/spanner_utils_test.py @@ -23,14 +23,12 @@ def setUpModule() -> None: def tearDownModule() -> None: spanner_test_lib.TearDown() -class DatabaseTest(absltest.TestCase): + +class DatabaseTest(spanner_test_lib.TestCase): def setUp(self): super().setUp() - pyspanner = spanner_test_lib.CreateTestDatabase() - self.db = spanner_utils.Database(pyspanner) - ####################################### # Transact Tests ####################################### @@ -47,35 +45,35 @@ def TransactionRead(txn) -> List[Any]: result = list(txn.execute_sql("SELECT t.Key FROM Table AS t")) return result - self.db.Transact(TransactionWrite) - results = self.db.Transact(TransactionRead) + self.raw_db.Transact(TransactionWrite) + results = self.raw_db.Transact(TransactionRead) self.assertCountEqual(results, [["foo"], ["bar"]]) ####################################### # Query Tests ####################################### def testQuerySimple(self): - results = list(self.db.Query("SELECT 'foo', 42")) + results = list(self.raw_db.Query("SELECT 'foo', 42")) self.assertEqual(results, [["foo", 42]]) def testQueryWithPlaceholders(self): - results = list(self.db.Query("SELECT '{}', '@p0'")) + results = list(self.raw_db.Query("SELECT '{}', '@p0'")) self.assertEqual(results, [["{}", "@p0"]]) ####################################### # QuerySingle Tests ####################################### def testQuerySingle(self): - result = self.db.QuerySingle("SELECT 'foo', 42") + result = self.raw_db.QuerySingle("SELECT 'foo', 42") self.assertEqual(result, ["foo", 42]) def testQuerySingleEmpty(self): with self.assertRaises(NotFound): - self.db.QuerySingle("SELECT 'foo', 42 FROM UNNEST([])") + self.raw_db.QuerySingle("SELECT 'foo', 42 FROM UNNEST([])") def testQuerySingleMultiple(self): with self.assertRaises(ValueError): - 
self.db.QuerySingle("SELECT 'foo', 42 FROM UNNEST([1, 2])") + self.raw_db.QuerySingle("SELECT 'foo', 42 FROM UNNEST([1, 2])") ####################################### # ParamQuery Tests @@ -84,7 +82,7 @@ def testParamQuerySingleParam(self): query = "SELECT {abc}" params = {"abc": 1337} - results = list(self.db.ParamQuery(query, params)) + results = list(self.raw_db.ParamQuery(query, params)) self.assertEqual(results, [[1337,]]) def testParamQueryMultipleParams(self): @@ -93,30 +91,30 @@ def testParamQueryMultipleParams(self): query = "SELECT {int}, {str}, {timestamp}" params = {"int": 1337, "str": "quux", "timestamp": timestamp} - results = list(self.db.ParamQuery(query, params)) + results = list(self.raw_db.ParamQuery(query, params)) self.assertEqual(results, [[1337, "quux", timestamp]]) def testParamQueryMissingParams(self): with self.assertRaisesRegex(KeyError, "bar"): - self.db.ParamQuery("SELECT {foo}, {bar}", {"foo": 42}) + self.raw_db.ParamQuery("SELECT {foo}, {bar}", {"foo": 42}) def testParamQueryExtraParams(self): query = "SELECT 42, {foo}" params = {"foo": "foo", "bar": "bar"} - results = list(self.db.ParamQuery(query, params)) + results = list(self.raw_db.ParamQuery(query, params)) self.assertEqual(results, [[42, "foo"]]) def testParamQueryIllegalSequence(self): with self.assertRaisesRegex(ValueError, "@p1337"): - self.db.ParamQuery("SELECT @p1337", {}) + self.raw_db.ParamQuery("SELECT @p1337", {}) def testParamQueryLegalSequence(self): - results = list(self.db.ParamQuery("SELECT '@p', '@q'", {})) + results = list(self.raw_db.ParamQuery("SELECT '@p', '@q'", {})) self.assertEqual(results, [["@p", "@q"]]) def testParamQueryBraceEscape(self): - results = list(self.db.ParamQuery("SELECT '{{foo}}'", {})) + results = list(self.raw_db.ParamQuery("SELECT '{{foo}}'", {})) self.assertEqual(results, [["{foo}",]]) ####################################### @@ -129,7 +127,7 @@ def testParamExecuteSingleParam(self): """ params = {"key": "foo"} - self.db.ParamExecute(query, params) + self.raw_db.ParamExecute(query, params) ####################################### # ParamQuerySingle Tests @@ -138,7 +136,7 @@ def testParamQuerySingle(self): query = "SELECT {str}, {int}" params = {"str": "foo", "int": 42} - result = self.db.ParamQuerySingle(query, params) + result = self.raw_db.ParamQuerySingle(query, params) self.assertEqual(result, ["foo", 42]) def testParamQuerySingleEmpty(self): @@ -146,26 +144,26 @@ def testParamQuerySingleEmpty(self): params = {"str": "foo", "int": 42} with self.assertRaises(NotFound): - self.db.ParamQuerySingle(query, params) + self.raw_db.ParamQuerySingle(query, params) def testParamQuerySingleMultiple(self): query = "SELECT {str}, {int} FROM UNNEST([1, 2])" params = {"str": "foo", "int": 42} with self.assertRaises(ValueError): - self.db.ParamQuerySingle(query, params) + self.raw_db.ParamQuerySingle(query, params) ####################################### # ExecutePartitioned Tests ####################################### def testExecutePartitioned(self): - self.db.Insert(table="Table", row={"Key": "foo"}) - self.db.Insert(table="Table", row={"Key": "bar"}) - self.db.Insert(table="Table", row={"Key": "baz"}) + self.raw_db.Insert(table="Table", row={"Key": "foo"}) + self.raw_db.Insert(table="Table", row={"Key": "bar"}) + self.raw_db.Insert(table="Table", row={"Key": "baz"}) - self.db.ExecutePartitioned("DELETE FROM Table AS t WHERE t.Key LIKE 'ba%'") + self.raw_db.ExecutePartitioned("DELETE FROM Table AS t WHERE t.Key LIKE 'ba%'") - results = list(self.db.Query("SELECT t.Key FROM 
Table AS t")) + results = list(self.raw_db.Query("SELECT t.Key FROM Table AS t")) self.assertLen(results, 1) self.assertEqual(results[0], ["foo",]) @@ -173,25 +171,25 @@ def testExecutePartitioned(self): # Insert Tests ####################################### def testInsert(self): - self.db.Insert(table="Table", row={"Key": "foo", "Column": "foo@x.com"}) - self.db.Insert(table="Table", row={"Key": "bar", "Column": "bar@x.com"}) + self.raw_db.Insert(table="Table", row={"Key": "foo", "Column": "foo@x.com"}) + self.raw_db.Insert(table="Table", row={"Key": "bar", "Column": "bar@x.com"}) - results = list(self.db.Query("SELECT t.Column FROM Table AS t")) + results = list(self.raw_db.Query("SELECT t.Column FROM Table AS t")) self.assertCountEqual(results, [["foo@x.com",], ["bar@x.com",]]) ####################################### # Update Tests ####################################### def testUpdate(self): - self.db.Insert(table="Table", row={"Key": "foo", "Column": "bar@y.com"}) - self.db.Update(table="Table", row={"Key": "foo", "Column": "qux@y.com"}) + self.raw_db.Insert(table="Table", row={"Key": "foo", "Column": "bar@y.com"}) + self.raw_db.Update(table="Table", row={"Key": "foo", "Column": "qux@y.com"}) - results = list(self.db.Query("SELECT t.Column FROM Table AS t")) + results = list(self.raw_db.Query("SELECT t.Column FROM Table AS t")) self.assertEqual(results, [["qux@y.com",]]) def testUpdateNotExisting(self): with self.assertRaises(NotFound): - self.db.Update(table="Table", row={"Key": "foo", "Column": "x@y.com"}) + self.raw_db.Update(table="Table", row={"Key": "foo", "Column": "x@y.com"}) ####################################### # InsertOrUpdate Tests @@ -200,50 +198,50 @@ def testInsertOrUpdate(self): row = {"Key": "foo"} row["Column"] = "bar@example.com" - self.db.InsertOrUpdate(table="Table", row=row) + self.raw_db.InsertOrUpdate(table="Table", row=row) row["Column"] = "baz@example.com" - self.db.InsertOrUpdate(table="Table", row=row) + self.raw_db.InsertOrUpdate(table="Table", row=row) - results = list(self.db.Query("SELECT t.Column FROM Table AS t")) + results = list(self.raw_db.Query("SELECT t.Column FROM Table AS t")) self.assertEqual(results, [["baz@example.com",]]) ####################################### # Delete Tests ####################################### def testDelete(self): - self.db.InsertOrUpdate(table="Table", row={"Key": "foo"}) - self.db.Delete(table="Table", key=("foo",)) + self.raw_db.InsertOrUpdate(table="Table", row={"Key": "foo"}) + self.raw_db.Delete(table="Table", key=("foo",)) - results = list(self.db.Query("SELECT t.Key FROM Table AS t")) + results = list(self.raw_db.Query("SELECT t.Key FROM Table AS t")) self.assertEmpty(results) def testDeleteSingle(self): - self.db.Insert(table="Table", row={"Key": "foo"}) - self.db.InsertOrUpdate(table="Table", row={"Key": "bar"}) - self.db.Delete(table="Table", key=("foo",)) + self.raw_db.Insert(table="Table", row={"Key": "foo"}) + self.raw_db.InsertOrUpdate(table="Table", row={"Key": "bar"}) + self.raw_db.Delete(table="Table", key=("foo",)) - results = list(self.db.Query("SELECT t.Key FROM Table AS t")) + results = list(self.raw_db.Query("SELECT t.Key FROM Table AS t")) self.assertEqual(results, [["bar",]]) def testDeleteNotExisting(self): # Should not raise. 
- self.db.Delete(table="Table", key=("foo",)) + self.raw_db.Delete(table="Table", key=("foo",)) ####################################### # DeleteWithPrefix Tests ####################################### def testDeleteWithPrefix(self): - self.db.Insert(table="Table", row={"Key": "foo"}) - self.db.Insert(table="Table", row={"Key": "quux"}) + self.raw_db.Insert(table="Table", row={"Key": "foo"}) + self.raw_db.Insert(table="Table", row={"Key": "quux"}) - self.db.Insert(table="Subtable", row={"Key": "foo", "Subkey": "bar"}) - self.db.Insert(table="Subtable", row={"Key": "foo", "Subkey": "baz"}) - self.db.Insert(table="Subtable", row={"Key": "quux", "Subkey": "norf"}) + self.raw_db.Insert(table="Subtable", row={"Key": "foo", "Subkey": "bar"}) + self.raw_db.Insert(table="Subtable", row={"Key": "foo", "Subkey": "baz"}) + self.raw_db.Insert(table="Subtable", row={"Key": "quux", "Subkey": "norf"}) - self.db.DeleteWithPrefix(table="Subtable", key_prefix=["foo"]) + self.raw_db.DeleteWithPrefix(table="Subtable", key_prefix=["foo"]) - results = list(self.db.Query("SELECT t.Key, t.Subkey FROM Subtable AS t")) + results = list(self.raw_db.Query("SELECT t.Key, t.Subkey FROM Subtable AS t")) self.assertLen(results, 1) self.assertEqual(results[0], ["quux", "norf"]) @@ -251,33 +249,33 @@ def testDeleteWithPrefix(self): # Read Tests ####################################### def testReadSimple(self): - self.db.Insert(table="Table", row={"Key": "foo", "Column": "foo@x.com"}) + self.raw_db.Insert(table="Table", row={"Key": "foo", "Column": "foo@x.com"}) - result = self.db.Read(table="Table", key=("foo",), cols=("Column",)) + result = self.raw_db.Read(table="Table", key=("foo",), cols=("Column",)) self.assertEqual(result, ['foo@x.com']) def testReadNotExisting(self): with self.assertRaises(NotFound): - self.db.Read(table="Table", key=("foo",), cols=("Column",)) + self.raw_db.Read(table="Table", key=("foo",), cols=("Column",)) ####################################### # ReadSet Tests ####################################### def testReadSetEmpty(self): - self.db.Insert(table="Table", row={"Key": "foo", "Column": "foo@x.com"}) + self.raw_db.Insert(table="Table", row={"Key": "foo", "Column": "foo@x.com"}) rows = spanner_lib.KeySet() - results = list(self.db.ReadSet(table="Table", rows=rows, cols=("Column",))) + results = list(self.raw_db.ReadSet(table="Table", rows=rows, cols=("Column",))) self.assertEmpty(results) def testReadSetSimple(self): - self.db.Insert(table="Table", row={"Key": "foo", "Column": "foo@x.com"}) - self.db.Insert(table="Table", row={"Key": "bar", "Column": "bar@y.com"}) - self.db.Insert(table="Table", row={"Key": "baz", "Column": "baz@z.com"}) + self.raw_db.Insert(table="Table", row={"Key": "foo", "Column": "foo@x.com"}) + self.raw_db.Insert(table="Table", row={"Key": "bar", "Column": "bar@y.com"}) + self.raw_db.Insert(table="Table", row={"Key": "baz", "Column": "baz@z.com"}) keyset = spanner_lib.KeySet(keys=[["foo"], ["bar"]]) - results = list(self.db.ReadSet(table="Table", rows=keyset, cols=("Column",))) + results = list(self.raw_db.ReadSet(table="Table", rows=keyset, cols=("Column",))) self.assertIn(["foo@x.com"], results) self.assertIn(["bar@y.com"], results) @@ -300,9 +298,9 @@ def Mutation(txn) -> None: values=[("bar",)] ) - self.db.Mutate(Mutation) + self.raw_db.Mutate(Mutation) - results = list(self.db.Query("SELECT t.Key FROM Table AS t")) + results = list(self.raw_db.Query("SELECT t.Key FROM Table AS t")) self.assertCountEqual(results, [["foo",], ["bar",]]) def testMutateException(self): @@ 
-316,7 +314,32 @@ def Mutation(txn) -> None: raise RuntimeError() with self.assertRaises(RuntimeError): - self.db.Mutate(Mutation) + self.raw_db.Mutate(Mutation) + + ####################################### + # Queue Tests + ####################################### + def testNewRequestQueueCallbackGetsCalled(self): + callback_func = mock.Mock() + + requestQueue = self.raw_db.NewRequestQueue( + "", + callback_func, + receiver_max_keepalive_seconds=10, + receiver_max_active_callbacks=1, + receiver_max_messages_per_callback=1, + ) + + start_time = time.time() + requestQueue.publish("foo") + + while callback_func.call_count == 0: + time.sleep(0.1) + if time.time() - start_time > 10: + self.fail("Request was not processed in time.") + + callback_func.assert_called_once_with("foo") + requestQueue.Stop() if __name__ == "__main__": absltest.main() \ No newline at end of file From fc32e13dd176ca71e35f87fb02aa5b963cda84a1 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Mon, 2 Jun 2025 17:34:06 +0000 Subject: [PATCH 013/168] Adds PubSub pull --- .../databases/spanner_utils.py | 71 ++++++++++--------- .../databases/spanner_utils_test.py | 41 ++++++++++- 2 files changed, 74 insertions(+), 38 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 8d9d249a9..806d88d6b 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -41,7 +41,6 @@ _T = TypeVar("_T") - class RequestQueue: """ This stores the callback internally, and will continue to deliver messages to @@ -54,6 +53,7 @@ def __init__( project_id: str, topic_id: str, subscription_id: str, + do_subscribe: bool, callback, # : Callable receiver_max_keepalive_seconds: int, receiver_max_active_callbacks: int, @@ -62,42 +62,44 @@ def __init__( self.project_id = project_id self.topic_id = topic_id self.subscriber = pubsub_v1.SubscriberClient() - subscription_path = self.subscriber.subscription_path(project_id, subscription_id) - - def queueCallback(message: pubsub_v1.subscriber.message.Message) -> None: - callback(message.data) + self.subscription_path = self.subscriber.subscription_path(project_id, subscription_id) + self.do_subscribe = do_subscribe - # An optional executor to use. If not specified, a default one with maximum 10 - # threads will be created. - executor = futures.ThreadPoolExecutor(max_workers=receiver_max_messages_per_callback) - # A thread pool-based scheduler. It must not be shared across SubscriberClients. - scheduler = pubsub_v1.subscriber.scheduler.ThreadScheduler(executor) + if do_subscribe: + # An optional executor to use. If not specified, a default one with maximum 10 + # threads will be created. + executor = futures.ThreadPoolExecutor(max_workers=receiver_max_active_callbacks) + # A thread pool-based scheduler. It must not be shared across SubscriberClients. 
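+    # The scheduler, flow-control settings and callback below are wired into a
+    # single streaming pull on the shared subscriber client.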
+ scheduler = pubsub_v1.subscriber.scheduler.ThreadScheduler(executor) - flow_control = pubsub_v1.types.FlowControl(max_messages=receiver_max_messages_per_callback) + flow_control = pubsub_v1.types.FlowControl(max_messages=receiver_max_messages_per_callback) - self.streaming_pull_future = self.subscriber.subscribe( - subscription_path, callback=queueCallback, scheduler=scheduler, flow_control=flow_control - ) + self.streaming_pull_future = self.subscriber.subscribe( + self.subscription_path, callback=callback, scheduler=scheduler, flow_control=flow_control + ) def publish(self, data: str) -> None: publisher = pubsub_v1.PublisherClient() topic_path = publisher.topic_path(self.project_id, self.topic_id) publisher.publish(topic_path, data.encode("utf-8")) - def pull(self): - # Create a client - client = pubsub_v1.SubscriberClient() - # Initialize request argument(s) - request = pubsub_v1.PullRequest( - subscription=subscription_path, - return_immediately=True, - max_messages=10000, + def ack(self, ack_ids: [str]) -> None: + self.subscriber.acknowledge( + request={"subscription": self.subscription_path, "ack_ids": ack_ids} ) + + def pull(self): # Make the request - response = client.pull(request=request) + response = self.subscriber.pull( + request={ + "subscription": self.subscription_path, + "max_messages": 10000, + }, + ) + return response.received_messages - def Stop(self): - if self.streaming_pull_future: + def stop(self): + if self.do_subscribe and self.streaming_pull_future: try: self.streaming_pull_future.cancel() except asyncio.CancelledError: @@ -110,7 +112,6 @@ def Stop(self): # Add a small sleep to allow threads to fully terminate time.sleep(0.1) # Give a short buffer for threads to clean up - class Database: """A wrapper around the PySpanner class. @@ -556,6 +557,7 @@ def ReadSet( def NewRequestQueue( self, queue: str, + do_subscribe: bool, callback: Callable[[Sequence[Any], bytes], None], receiver_max_keepalive_seconds: Optional[int] = None, receiver_max_active_callbacks: Optional[int] = None, @@ -579,15 +581,13 @@ def NewRequestQueue( New queue receiver objects. 
""" - def _Callback(data: str): - data = data.decode("utf-8") - if queue == "": - payload = data - elif queue == "MessageHandler": - payload = objects_pb2.MessageHandlerRequest.ParseFromString(data) - elif queue == "FlowProcessing": - payload = flows_pb2.FlowProcessingRequest.ParseFromString(data) - callback(payload) + def _Callback(message: pubsub_v1.subscriber.message.Message): + payload = message.data.decode("utf-8") + #if queue == "MessageHandler": + # payload = objects_pb2.MessageHandlerRequest.ParseFromString(data) + #elif queue == "FlowProcessing": + # payload = flows_pb2.FlowProcessingRequest.ParseFromString(data) + callback(payload=payload, ack_id=message.ack_id, publish_time=message.publish_time) if queue == "MessageHandler" or queue == "": topic_id = self.msg_handler_top_id @@ -600,6 +600,7 @@ def _Callback(data: str): self.project_id, topic_id, subscription_id, + do_subscribe, _Callback, receiver_max_keepalive_seconds=receiver_max_keepalive_seconds, receiver_max_active_callbacks=receiver_max_active_callbacks, diff --git a/grr/server/grr_response_server/databases/spanner_utils_test.py b/grr/server/grr_response_server/databases/spanner_utils_test.py index 3c7df8bdd..ef47f6169 100644 --- a/grr/server/grr_response_server/databases/spanner_utils_test.py +++ b/grr/server/grr_response_server/databases/spanner_utils_test.py @@ -324,6 +324,7 @@ def testNewRequestQueueCallbackGetsCalled(self): requestQueue = self.raw_db.NewRequestQueue( "", + True, callback_func, receiver_max_keepalive_seconds=10, receiver_max_active_callbacks=1, @@ -338,8 +339,42 @@ def testNewRequestQueueCallbackGetsCalled(self): if time.time() - start_time > 10: self.fail("Request was not processed in time.") - callback_func.assert_called_once_with("foo") - requestQueue.Stop() + callback_func.assert_called_once() + result = callback_func.call_args.kwargs + ack_ids = [] + ack_ids.append(result["ack_id"]) + requestQueue.ack(ack_ids) + requestQueue.stop() + + def testNewRequestQueueCount(self): + callback_func = mock.Mock() + + requestQueue = self.raw_db.NewRequestQueue( + "", + False, + callback_func, + receiver_max_keepalive_seconds=1, + receiver_max_active_callbacks=1, + receiver_max_messages_per_callback=1, + ) + + start_time = time.time() + requestQueue.publish("foo") + requestQueue.publish("bar") + + results = {} + + while len(results) < 2 and time.time() - start_time < 10: + time.sleep(0.1) + messages = requestQueue.pull() + ack_ids = [] + for msg in messages: + results.update({msg.message.message_id: msg}) + ack_ids.append(msg.ack_id) + requestQueue.ack(ack_ids) + + self.assertLen(results, 2) + requestQueue.stop() if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From 6aff24baafbb5b5a8f70c8fa87f4ae24207cd31e Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Tue, 3 Jun 2025 11:27:05 +0000 Subject: [PATCH 014/168] Refactor RequestQueue class --- .../databases/spanner_utils.py | 115 +++++++++--------- .../databases/spanner_utils_test.py | 39 +++--- 2 files changed, 74 insertions(+), 80 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 806d88d6b..06bb7dad5 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -50,56 +50,28 @@ class RequestQueue: def __init__( self, - project_id: str, - topic_id: str, - subscription_id: str, - do_subscribe: bool, + subscriber, + subscription_path: str, callback, 
# : Callable receiver_max_keepalive_seconds: int, receiver_max_active_callbacks: int, receiver_max_messages_per_callback: int, ): - self.project_id = project_id - self.topic_id = topic_id - self.subscriber = pubsub_v1.SubscriberClient() - self.subscription_path = self.subscriber.subscription_path(project_id, subscription_id) - self.do_subscribe = do_subscribe - - if do_subscribe: - # An optional executor to use. If not specified, a default one with maximum 10 - # threads will be created. - executor = futures.ThreadPoolExecutor(max_workers=receiver_max_active_callbacks) - # A thread pool-based scheduler. It must not be shared across SubscriberClients. - scheduler = pubsub_v1.subscriber.scheduler.ThreadScheduler(executor) - flow_control = pubsub_v1.types.FlowControl(max_messages=receiver_max_messages_per_callback) - - self.streaming_pull_future = self.subscriber.subscribe( - self.subscription_path, callback=callback, scheduler=scheduler, flow_control=flow_control - ) + # An optional executor to use. If not specified, a default one with maximum 10 + # threads will be created. + executor = futures.ThreadPoolExecutor(max_workers=receiver_max_active_callbacks) + # A thread pool-based scheduler. It must not be shared across SubscriberClients. + scheduler = pubsub_v1.subscriber.scheduler.ThreadScheduler(executor) - def publish(self, data: str) -> None: - publisher = pubsub_v1.PublisherClient() - topic_path = publisher.topic_path(self.project_id, self.topic_id) - publisher.publish(topic_path, data.encode("utf-8")) + flow_control = pubsub_v1.types.FlowControl(max_messages=receiver_max_messages_per_callback) - def ack(self, ack_ids: [str]) -> None: - self.subscriber.acknowledge( - request={"subscription": self.subscription_path, "ack_ids": ack_ids} + self.streaming_pull_future = subscriber.subscribe( + subscription_path, callback=callback, scheduler=scheduler, flow_control=flow_control ) - def pull(self): - # Make the request - response = self.subscriber.pull( - request={ - "subscription": self.subscription_path, - "max_messages": 10000, - }, - ) - return response.received_messages - - def stop(self): - if self.do_subscribe and self.streaming_pull_future: + def Stop(self): + if self.streaming_pull_future: try: self.streaming_pull_future.cancel() except asyncio.CancelledError: @@ -107,9 +79,6 @@ def stop(self): except Exception as e: print(f"Warning: Exception while cancelling future: {e}") - if self.subscriber: - self.subscriber.close() - # Add a small sleep to allow threads to fully terminate time.sleep(0.1) # Give a short buffer for threads to clean up class Database: @@ -128,10 +97,12 @@ def __init__(self, pyspanner: spanner_lib.database, project_id: str, super().__init__() self._pyspanner = pyspanner self.project_id = project_id - self.msg_handler_top_id = msg_handler_top_id - self.msg_handler_sub_id = msg_handler_sub_id - self.flow_processing_top_id = flow_processing_top_id - self.flow_processing_sub_id = flow_processing_sub_id + self.publisher = pubsub_v1.PublisherClient() + self.subscriber = pubsub_v1.SubscriberClient() + self.flow_proccessing_sub_path = self.subscriber.subscription_path(project_id, flow_processing_sub_id) + self.flow_processing_top_path = self.publisher.topic_path(project_id, flow_processing_top_id) + self.message_handler_sub_path = self.subscriber.subscription_path(project_id, msg_handler_sub_id) + self.message_handler_top_path = self.publisher.topic_path(project_id, msg_handler_top_id) def _parametrize(self, query: str, names: Iterable[str]) -> str: match = 
self._PYSPANNER_PARAM_REGEX.search(query) @@ -554,10 +525,46 @@ def ReadSet( return results + def PublishMessageHandlerRequests(self, requests: [str]) -> None: + self.PublishRequests(requests, self.message_handler_top_path) + + def PublishFlowProcessingRequests(self, requests: [str]) -> None: + self.PublishRequests(requests, self.flow_processing_top_path) + + def ReadMessageHandlerRequests(self): + return self.ReadRequests(self.message_handler_sub_path) + + def ReadFlowProcessingRequests(self): + return self.ReadRequests(self.flow_processing_sub_path) + + def AckMessageHandlerRequests(self, ack_ids: [str]) -> None: + self.AckRequests(ack_ids, self.message_handler_sub_path) + + def AckFlowProcessingRequests(self, ack_ids: [str]) -> None: + self.AckRequests(ack_ids, self.flow_proccessing_sub_path) + + def PublishRequests(self, requests: [str], top_path: str) -> None: + for req in requests: + self.publisher.publish(top_path, req.encode("utf-8")) + + def AckRequests(self, ack_ids: [str], sub_path: str) -> None: + self.subscriber.acknowledge( + request={"subscription": sub_path, "ack_ids": ack_ids} + ) + + def ReadRequests(self, sub_path: str): + # Make the request + response = self.subscriber.pull( + request={ + "subscription": sub_path, + "max_messages": 10000, + }, + ) + return response.received_messages + def NewRequestQueue( self, queue: str, - do_subscribe: bool, callback: Callable[[Sequence[Any], bytes], None], receiver_max_keepalive_seconds: Optional[int] = None, receiver_max_active_callbacks: Optional[int] = None, @@ -590,17 +597,13 @@ def _Callback(message: pubsub_v1.subscriber.message.Message): callback(payload=payload, ack_id=message.ack_id, publish_time=message.publish_time) if queue == "MessageHandler" or queue == "": - topic_id = self.msg_handler_top_id - subscription_id = self.msg_handler_sub_id + subscription_path = self.message_handler_sub_path elif queue == "FlowProcessing": - topic_id = self.flow_processing_top_id - subscription_id = self.flow_processing_sub_id + subscription_path = self.flow_processing_sub_path return RequestQueue( - self.project_id, - topic_id, - subscription_id, - do_subscribe, + self.subscriber, + subscription_path, _Callback, receiver_max_keepalive_seconds=receiver_max_keepalive_seconds, receiver_max_active_callbacks=receiver_max_active_callbacks, diff --git a/grr/server/grr_response_server/databases/spanner_utils_test.py b/grr/server/grr_response_server/databases/spanner_utils_test.py index ef47f6169..8dd393a26 100644 --- a/grr/server/grr_response_server/databases/spanner_utils_test.py +++ b/grr/server/grr_response_server/databases/spanner_utils_test.py @@ -323,8 +323,7 @@ def testNewRequestQueueCallbackGetsCalled(self): callback_func = mock.Mock() requestQueue = self.raw_db.NewRequestQueue( - "", - True, + "MessageHandler", callback_func, receiver_max_keepalive_seconds=10, receiver_max_active_callbacks=1, @@ -332,7 +331,9 @@ def testNewRequestQueueCallbackGetsCalled(self): ) start_time = time.time() - requestQueue.publish("foo") + requests = [] + requests.append("foo") + self.raw_db.PublishMessageHandlerRequests(requests) while callback_func.call_count == 0: time.sleep(0.1) @@ -343,38 +344,28 @@ def testNewRequestQueueCallbackGetsCalled(self): result = callback_func.call_args.kwargs ack_ids = [] ack_ids.append(result["ack_id"]) - requestQueue.ack(ack_ids) - requestQueue.stop() + self.raw_db.AckMessageHandlerRequests(ack_ids) + requestQueue.Stop() - def testNewRequestQueueCount(self): - callback_func = mock.Mock() - - requestQueue = 
self.raw_db.NewRequestQueue( - "", - False, - callback_func, - receiver_max_keepalive_seconds=1, - receiver_max_active_callbacks=1, - receiver_max_messages_per_callback=1, - ) + def testNewRequestCount(self): start_time = time.time() - requestQueue.publish("foo") - requestQueue.publish("bar") + requests = [] + requests.append("foo") + requests.append("bar") + self.raw_db.PublishMessageHandlerRequests(requests) results = {} - - while len(results) < 2 and time.time() - start_time < 10: + ack_ids = [] + while len(results) < 2 or time.time() - start_time < 10: time.sleep(0.1) - messages = requestQueue.pull() - ack_ids = [] + messages = self.raw_db.ReadMessageHandlerRequests() for msg in messages: results.update({msg.message.message_id: msg}) ack_ids.append(msg.ack_id) - requestQueue.ack(ack_ids) + self.raw_db.AckMessageHandlerRequests(ack_ids) self.assertLen(results, 2) - requestQueue.stop() if __name__ == "__main__": absltest.main() From f874977fba54377cf96bf4ab91a858de88404a6e Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Thu, 5 Jun 2025 09:35:19 +0000 Subject: [PATCH 015/168] Adds MessageHandler read/write --- grr/proto/grr_response_proto/objects.proto | 1 + .../databases/spanner_flows.py | 94 +++++++------------ .../databases/spanner_message_handler_test.py | 74 +++++++++++++++ .../databases/spanner_utils.py | 40 ++++++-- .../databases/spanner_utils_test.py | 17 ++-- 5 files changed, 145 insertions(+), 81 deletions(-) create mode 100644 grr/server/grr_response_server/databases/spanner_message_handler_test.py diff --git a/grr/proto/grr_response_proto/objects.proto b/grr/proto/grr_response_proto/objects.proto index 58440aafd..318a486c7 100644 --- a/grr/proto/grr_response_proto/objects.proto +++ b/grr/proto/grr_response_proto/objects.proto @@ -331,6 +331,7 @@ message MessageHandlerRequest { optional uint64 leased_until = 5 [(sem_type) = { type: "RDFDatetime" }]; optional string leased_by = 6; optional EmbeddedRDFValue request = 7; + optional string ack_id = 8; } message SerializedValueOfUnrecognizedType { diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 86cab5cdb..0ff817dc8 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -28,10 +28,11 @@ name="spanner_delete_flow_requests_failures" ) - _MESSAGE_HANDLER_MAX_KEEPALIVE_SECONDS = 300 _MESSAGE_HANDLER_MAX_ACTIVE_CALLBACKS = 20 +_MILLISECONDS = 1000 +_SECONDS = 1000 * _MILLISECONDS @dataclasses.dataclass(frozen=True) class _FlowKey: @@ -2157,41 +2158,29 @@ def ListScheduledFlows( def WriteMessageHandlerRequests( self, requests: Iterable[objects_pb2.MessageHandlerRequest] ) -> None: - """Writes a list of message handler requests to the database.""" - - def Mutation(mut: spanner_utils.Mutation) -> None: - creation_timestamp = spanner_lib.CommitTimestamp() + """Writes a list of message handler requests to the queue.""" - for r in requests: - key = (r.handler_name, r.request_id, creation_timestamp) - - mut.Send( - queue="MessageHandlerRequestsQueue", - key=key, - value=r, - column="Payload", - ) + msgRequests = [] + for request in requests: + msgRequests.append(request.SerializeToString()) - self.db.BufferedMutate(Mutation) + self.db.PublishMessageHandlerRequests(msgRequests) @db_utils.CallLogged @db_utils.CallAccounted def ReadMessageHandlerRequests( self, ) -> Sequence[objects_pb2.MessageHandlerRequest]: - """Reads all message handler requests from the database.""" - - 
query = """ - SELECT t.Payload, t.CreationTime FROM MessageHandlerRequestsQueue AS t - """ + """Reads all message handler requests from the queue.""" results = [] - for payload, creation_time in self.db.ParamQuery(query, {}): + for result in self.db.ReadMessageHandlerRequests(): req = objects_pb2.MessageHandlerRequest() - req.ParseFromString(payload) + req.ParseFromString(result["payload"]) req.timestamp = rdfvalue.RDFDatetime.FromDatetime( - creation_time + result["publish_time"] ).AsMicrosecondsSinceEpoch() + req.ack_id = result["ack_id"] results.append(req) return results @@ -2225,16 +2214,17 @@ def DeleteMessageHandlerRequests( ) -> None: """Deletes a list of message handler requests from the database.""" - def Txn(txn: spanner_utils.Transaction) -> None: - self._BuildDeleteMessageHandlerRequestWrites(txn, requests) + ack_ids = [] + for request in requests: + ack_ids.append(request.ack_id) - self.db.Transact(Txn) + self.db.AckMessageHandlerRequests(ack_ids) def _LeaseMessageHandlerRequest( self, req: objects_pb2.MessageHandlerRequest, lease_time: rdfvalue.Duration, - ) -> objects_pb2.MessageHandlerRequest: + ) -> None: """Leases the given message handler request. Leasing of the message amounts to the following: @@ -2252,35 +2242,13 @@ def _LeaseMessageHandlerRequest( """ delivery_time = rdfvalue.RDFDatetime.Now() + lease_time - clone = objects_pb2.MessageHandlerRequest() - clone.CopyFrom(req) - clone.leased_until = delivery_time.AsMicrosecondsSinceEpoch() - clone.leased_by = utils.ProcessIdString() - - def Txn(txn) -> None: - self._BuildDeleteMessageHandlerRequestWrites(txn, [clone]) - - with txn.BufferedMutate() as mut: - key = (req.handler_name, req.request_id, spanner_lib.CommitTimestamp()) - mut.Send( - queue="MessageHandlerRequestsQueue", - key=key, - value=clone, - column="Payload", - deliver_time=delivery_time.AsDatetimeUTC(), - ) + leased_until = str(delivery_time.AsMicrosecondsSinceEpoch()) + leased_by = utils.ProcessIdString() - result = self.db.Transact(Txn) - # Using the transaction's commit timestamp and modifying the request object - # with it allows us to avoid doing a separate database read to read - # the same object back with a Spanner-provided timestamp - # (please see ReadMessageHandlerRequests that uses the CreationTime - # column to set the request's 'timestamp' attribute). 
- clone.timestamp = rdfvalue.RDFDatetime.FromDatetime( - result.commit_time - ).AsMicrosecondsSinceEpoch() + ack_ids = [] + ack_ids.append(req.ack_id) - return clone + self.db.LeaseMessageHandlerRequest(ack_ids, lease_time.ToInt(_SECONDS)) def RegisterMessageHandler( self, @@ -2291,27 +2259,31 @@ def RegisterMessageHandler( """Registers a message handler to receive batches of messages.""" self.UnregisterMessageHandler() - def Callback(expanded_key: Sequence[Any], payload: bytes): - del expanded_key + def Callback(payload: bytes, msg_id: str, ack_id: str, publish_time, attributes): try: req = objects_pb2.MessageHandlerRequest() req.ParseFromString(payload) - leased = self._LeaseMessageHandlerRequest(req, lease_time) + req.ack_id = ack_id + for attr in attributes: + if attr[0] == "leased_until": + req.leased_until = int(attr[1]) + elif attr[0] == "leased_by": + req.leased_by = attr[1] + self._LeaseMessageHandlerRequest(req, lease_time) logging.info("Leased message handler request: %s", req.request_id) - handler([leased]) + handler([req]) except Exception as e: # pylint: disable=broad-except logging.exception( "Exception raised during MessageHandlerRequest processing: %s", e ) - receiver = self.db.NewQueueReceiver( - "MessageHandlerRequestsQueue", + receiver = self.db.NewRequestQueue( + "MessageHandler", Callback, receiver_max_keepalive_seconds=_MESSAGE_HANDLER_MAX_KEEPALIVE_SECONDS, receiver_max_active_callbacks=_MESSAGE_HANDLER_MAX_ACTIVE_CALLBACKS, receiver_max_messages_per_callback=limit, ) - receiver.Receive() self._message_handler_receiver = receiver def UnregisterMessageHandler( diff --git a/grr/server/grr_response_server/databases/spanner_message_handler_test.py b/grr/server/grr_response_server/databases/spanner_message_handler_test.py new file mode 100644 index 000000000..7c840d9bc --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_message_handler_test.py @@ -0,0 +1,74 @@ +import queue + +from absl.testing import absltest + +from grr_response_core.lib import rdfvalue +from grr_response_core.lib import utils +from grr_response_core.lib.rdfvalues import mig_protodict +from grr_response_core.lib.rdfvalues import protodict as rdf_protodict +from grr_response_proto import objects_pb2 + +from grr_response_server.databases import db_message_handler_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseHandlerTest(spanner_test_lib.TestCase +): + def setUp(self): + super().setUp() + + def testMessageHandlerRequests(self): + + requests = [] + for i in range(5): + emb = mig_protodict.ToProtoEmbeddedRDFValue( + rdf_protodict.EmbeddedRDFValue(rdfvalue.RDFInteger(i)) + ) + requests.append( + objects_pb2.MessageHandlerRequest( + client_id="C.1000000000000000", + handler_name="Testhandler", + request_id=i * 100, + request=emb, + ) + ) + + self.db.WriteMessageHandlerRequests(requests) + + read = self.db.ReadMessageHandlerRequests() + self.assertLen(read, 5) + + self.db.DeleteMessageHandlerRequests(read[:2]) + self.db.DeleteMessageHandlerRequests(read[4:5]) + + for r in read: + self.assertTrue(r.timestamp) + r.ClearField("timestamp") + self.assertTrue(r.ack_id) + r.ClearField("ack_id") + + self.assertCountEqual(read, requests) + + read = self.db.ReadMessageHandlerRequests() + self.assertLen(read, 2) + self.db.DeleteMessageHandlerRequests(read) + + for r in read: + 
self.assertTrue(r.timestamp) + r.ClearField("timestamp") + self.assertTrue(r.ack_id) + r.ClearField("ack_id") + + self.assertCountEqual(requests[2:4], read) + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 06bb7dad5..52dd249f7 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -543,9 +543,18 @@ def AckMessageHandlerRequests(self, ack_ids: [str]) -> None: def AckFlowProcessingRequests(self, ack_ids: [str]) -> None: self.AckRequests(ack_ids, self.flow_proccessing_sub_path) + def LeaseMessageHandlerRequest(self, ack_ids: [str], ack_deadline: int) -> None: + self.subscriber.modify_ack_deadline( + request={ + "subscription": self.message_handler_sub_path, + "ack_ids": ack_ids, + "ack_deadline_seconds": ack_deadline, + } + ) + def PublishRequests(self, requests: [str], top_path: str) -> None: for req in requests: - self.publisher.publish(top_path, req.encode("utf-8")) + self.publisher.publish(top_path, req) def AckRequests(self, ack_ids: [str], sub_path: str) -> None: self.subscriber.acknowledge( @@ -554,13 +563,27 @@ def AckRequests(self, ack_ids: [str], sub_path: str) -> None: def ReadRequests(self, sub_path: str): # Make the request - response = self.subscriber.pull( + + start_time = time.time() + results = {} + while time.time() - start_time < 10: + time.sleep(0.1) + + response = self.subscriber.pull( request={ "subscription": sub_path, "max_messages": 10000, }, - ) - return response.received_messages + ) + for resp in response.received_messages: + results.update({resp.message.message_id: { + "payload": resp.message.data, + "msg_id": resp.message.message_id, + "ack_id": resp.ack_id, + "publish_time": resp.message.publish_time} + }) + + return results.values() def NewRequestQueue( self, @@ -589,12 +612,9 @@ def NewRequestQueue( """ def _Callback(message: pubsub_v1.subscriber.message.Message): - payload = message.data.decode("utf-8") - #if queue == "MessageHandler": - # payload = objects_pb2.MessageHandlerRequest.ParseFromString(data) - #elif queue == "FlowProcessing": - # payload = flows_pb2.FlowProcessingRequest.ParseFromString(data) - callback(payload=payload, ack_id=message.ack_id, publish_time=message.publish_time) + payload = message.data + callback(payload=payload, msg_id=message.message_id, ack_id=message.ack_id, + publish_time=message.publish_time) if queue == "MessageHandler" or queue == "": subscription_path = self.message_handler_sub_path diff --git a/grr/server/grr_response_server/databases/spanner_utils_test.py b/grr/server/grr_response_server/databases/spanner_utils_test.py index 8dd393a26..a1166d874 100644 --- a/grr/server/grr_response_server/databases/spanner_utils_test.py +++ b/grr/server/grr_response_server/databases/spanner_utils_test.py @@ -332,7 +332,7 @@ def testNewRequestQueueCallbackGetsCalled(self): start_time = time.time() requests = [] - requests.append("foo") + requests.append("foo".encode("utf-8")) self.raw_db.PublishMessageHandlerRequests(requests) while callback_func.call_count == 0: @@ -351,18 +351,15 @@ def testNewRequestCount(self): start_time = time.time() requests = [] - requests.append("foo") - requests.append("bar") + requests.append("foo".encode("utf-8")) + requests.append("bar".encode("utf-8")) self.raw_db.PublishMessageHandlerRequests(requests) - results = {} + results = 
self.raw_db.ReadMessageHandlerRequests() ack_ids = [] - while len(results) < 2 or time.time() - start_time < 10: - time.sleep(0.1) - messages = self.raw_db.ReadMessageHandlerRequests() - for msg in messages: - results.update({msg.message.message_id: msg}) - ack_ids.append(msg.ack_id) + + for result in results: + ack_ids.append(result["ack_id"]) self.raw_db.AckMessageHandlerRequests(ack_ids) self.assertLen(results, 2) From 9a164962c4c034f0881fe77f52ff0d644b580371 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Thu, 5 Jun 2025 17:59:23 +0000 Subject: [PATCH 016/168] Adds MessageHandler lease mgmt --- .../databases/spanner_flows.py | 50 ++++++++++------ .../databases/spanner_message_handler_test.py | 58 ++++++++++++++++++- .../databases/spanner_utils.py | 7 ++- 3 files changed, 90 insertions(+), 25 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 0ff817dc8..957f27f09 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -2224,7 +2224,7 @@ def _LeaseMessageHandlerRequest( self, req: objects_pb2.MessageHandlerRequest, lease_time: rdfvalue.Duration, - ) -> None: + ) -> bool: """Leases the given message handler request. Leasing of the message amounts to the following: @@ -2240,15 +2240,31 @@ def _LeaseMessageHandlerRequest( Copy of the original request object with "leased_until" and "leased_by" attributes set. """ - delivery_time = rdfvalue.RDFDatetime.Now() + lease_time - - leased_until = str(delivery_time.AsMicrosecondsSinceEpoch()) - leased_by = utils.ProcessIdString() - - ack_ids = [] - ack_ids.append(req.ack_id) - - self.db.LeaseMessageHandlerRequest(ack_ids, lease_time.ToInt(_SECONDS)) + date_time_now = rdfvalue.RDFDatetime.Now() + epoch_now = date_time_now.AsMicrosecondsSinceEpoch() + delivery_time = date_time_now + lease_time + + leased = False + if not req.leased_by or req.leased_until <= epoch_now: + # If the message has not been leased yet or the lease has expired + # then take and write back the clone back to the queue + # and delete the original message + clone = objects_pb2.MessageHandlerRequest() + clone.CopyFrom(req) + clone.leased_until = delivery_time.AsMicrosecondsSinceEpoch() + clone.leased_by = utils.ProcessIdString() + clone.ack_id = "" + self.WriteMessageHandlerRequests([clone]) + self.DeleteMessageHandlerRequests([req]) + elif req.leased_until > epoch_now: + # if we have leased the message (leased_until set and in the future) + # then we modify ack deadline to match the leased_until time + leased = True + ack_ids = [] + ack_ids.append(req.ack_id) + self.db.LeaseMessageHandlerRequests(ack_ids, int((req.leased_until - epoch_now)/1000000)) + + return leased def RegisterMessageHandler( self, @@ -2259,19 +2275,15 @@ def RegisterMessageHandler( """Registers a message handler to receive batches of messages.""" self.UnregisterMessageHandler() - def Callback(payload: bytes, msg_id: str, ack_id: str, publish_time, attributes): + def Callback(payload: bytes, msg_id: str, ack_id: str, publish_time): try: req = objects_pb2.MessageHandlerRequest() req.ParseFromString(payload) req.ack_id = ack_id - for attr in attributes: - if attr[0] == "leased_until": - req.leased_until = int(attr[1]) - elif attr[0] == "leased_by": - req.leased_by = attr[1] - self._LeaseMessageHandlerRequest(req, lease_time) - logging.info("Leased message handler request: %s", req.request_id) - handler([req]) + leased = 
self._LeaseMessageHandlerRequest(req, lease_time) + if leased: + logging.info("Leased message handler request: %s", req.request_id) + handler([req]) except Exception as e: # pylint: disable=broad-except logging.exception( "Exception raised during MessageHandlerRequest processing: %s", e diff --git a/grr/server/grr_response_server/databases/spanner_message_handler_test.py b/grr/server/grr_response_server/databases/spanner_message_handler_test.py index 7c840d9bc..088dd6b0c 100644 --- a/grr/server/grr_response_server/databases/spanner_message_handler_test.py +++ b/grr/server/grr_response_server/databases/spanner_message_handler_test.py @@ -20,13 +20,15 @@ def tearDownModule() -> None: spanner_test_lib.TearDown() -class SpannerDatabaseHandlerTest(spanner_test_lib.TestCase -): +class SpannerDatabaseHandlerTest(spanner_test_lib.TestCase): def setUp(self): super().setUp() def testMessageHandlerRequests(self): + ######################## + # Read / Write tests + ######################## requests = [] for i in range(5): emb = mig_protodict.ToProtoEmbeddedRDFValue( @@ -35,7 +37,7 @@ def testMessageHandlerRequests(self): requests.append( objects_pb2.MessageHandlerRequest( client_id="C.1000000000000000", - handler_name="Testhandler", + handler_name="Testhandler 0", request_id=i * 100, request=emb, ) @@ -69,6 +71,56 @@ def testMessageHandlerRequests(self): self.assertCountEqual(requests[2:4], read) + ######################## + # Lease Management tests + ######################## + + requests = [] + for i in range(10): + emb = mig_protodict.ToProtoEmbeddedRDFValue( + rdf_protodict.EmbeddedRDFValue(rdfvalue.RDFInteger(i)) + ) + requests.append( + objects_pb2.MessageHandlerRequest( + client_id="C.1000000000000001", + handler_name="Testhandler 1", + request_id=i * 100, + request=emb, + ) + ) + + lease_time = rdfvalue.Duration.From(5, rdfvalue.MINUTES) + + leased = queue.Queue() + self.db.RegisterMessageHandler(leased.put, lease_time, limit=10) + + self.db.WriteMessageHandlerRequests(requests) + + got = [] + while len(got) < 10: + try: + l = leased.get(True, timeout=6) + except queue.Empty: + self.fail( + "Timed out waiting for messages, expected 10, got %d" % len(got) + ) + self.assertLessEqual(len(l), 10) + for m in l: + self.assertEqual(m.leased_by, utils.ProcessIdString()) + self.assertGreater(m.leased_until, rdfvalue.RDFDatetime.Now()) + self.assertLess(m.timestamp, rdfvalue.RDFDatetime.Now()) + got += l + self.db.DeleteMessageHandlerRequests(got) + + got.sort(key=lambda req: req.request_id) + for m in got: + m.ClearField("leased_by") + m.ClearField("leased_until") + m.ClearField("timestamp") + m.ClearField("ack_id") + self.assertEqual(requests, got) + + self.db.UnregisterMessageHandler() if __name__ == "__main__": absltest.main() \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 52dd249f7..c9b116ab6 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -64,10 +64,11 @@ def __init__( # A thread pool-based scheduler. It must not be shared across SubscriberClients. 
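# A standalone sketch of the streaming-pull pattern this constructor sets up,
# assuming only the public google-cloud-pubsub API; the project and
# subscription ids are placeholders. FlowControl caps how many messages may be
# outstanding in callbacks at any one time.
from concurrent import futures
from google.cloud import pubsub_v1

def example_streaming_pull(project_id: str, subscription_id: str) -> None:
  subscriber = pubsub_v1.SubscriberClient()
  subscription_path = subscriber.subscription_path(project_id, subscription_id)

  def on_message(message: pubsub_v1.subscriber.message.Message) -> None:
    # Each delivery runs on a thread of the scheduler's executor; ack() removes
    # the message from the subscription.
    print(message.message_id, message.data)
    message.ack()

  scheduler = pubsub_v1.subscriber.scheduler.ThreadScheduler(
      futures.ThreadPoolExecutor(max_workers=4)
  )
  flow_control = pubsub_v1.types.FlowControl(max_messages=10)
  future = subscriber.subscribe(
      subscription_path,
      callback=on_message,
      scheduler=scheduler,
      flow_control=flow_control,
  )
  try:
    future.result(timeout=30)
  except futures.TimeoutError:
    future.cancel()
    future.result()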
scheduler = pubsub_v1.subscriber.scheduler.ThreadScheduler(executor) - flow_control = pubsub_v1.types.FlowControl(max_messages=receiver_max_messages_per_callback) + #flow_control = pubsub_v1.types.FlowControl(max_messages=receiver_max_messages_per_callback) self.streaming_pull_future = subscriber.subscribe( - subscription_path, callback=callback, scheduler=scheduler, flow_control=flow_control + subscription_path, callback=callback, scheduler=scheduler + #subscription_path, callback=callback, scheduler=scheduler, flow_control=flow_control ) def Stop(self): @@ -543,7 +544,7 @@ def AckMessageHandlerRequests(self, ack_ids: [str]) -> None: def AckFlowProcessingRequests(self, ack_ids: [str]) -> None: self.AckRequests(ack_ids, self.flow_proccessing_sub_path) - def LeaseMessageHandlerRequest(self, ack_ids: [str], ack_deadline: int) -> None: + def LeaseMessageHandlerRequests(self, ack_ids: [str], ack_deadline: int) -> None: self.subscriber.modify_ack_deadline( request={ "subscription": self.message_handler_sub_path, From 20fad7d723c3c83491a1e2a468f767f3fb6bc6c5 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Sat, 7 Jun 2025 16:48:59 +0000 Subject: [PATCH 017/168] Adds Users table --- .../grr_response_server/databases/spanner.sdl | 104 +-- .../databases/spanner_clients.py | 614 ++++++++++++++++-- .../databases/spanner_clients_test.py | 35 + .../databases/spanner_cron_jobs.py | 16 +- .../databases/spanner_flows.py | 60 +- .../databases/spanner_hunts.py | 29 +- .../databases/spanner_users.py | 43 +- 7 files changed, 754 insertions(+), 147 deletions(-) create mode 100644 grr/server/grr_response_server/databases/spanner_clients_test.py diff --git a/grr/server/grr_response_server/databases/spanner.sdl b/grr/server/grr_response_server/databases/spanner.sdl index 7d57e529b..7aaff306e 100644 --- a/grr/server/grr_response_server/databases/spanner.sdl +++ b/grr/server/grr_response_server/databases/spanner.sdl @@ -167,7 +167,7 @@ CREATE TABLE Labels ( ) PRIMARY KEY (Label); CREATE TABLE Clients ( - ClientId INT64 NOT NULL, + ClientId STRING(18) NOT NULL, LastSnapshotTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), LastStartupTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), LastRRGStartupTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), @@ -181,28 +181,28 @@ CREATE TABLE Clients ( ) PRIMARY KEY (ClientId); CREATE TABLE ClientSnapshots( - ClientId INT64 NOT NULL, + ClientId STRING(18) NOT NULL, CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), Snapshot `grr.ClientSnapshot` NOT NULL, ) PRIMARY KEY (ClientId, CreationTime), INTERLEAVE IN PARENT Clients ON DELETE CASCADE; CREATE TABLE ClientStartups( - ClientId INT64 NOT NULL, + ClientId STRING(18) NOT NULL, CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), Startup `grr.StartupInfo` NOT NULL, ) PRIMARY KEY (ClientId, CreationTime), INTERLEAVE IN PARENT Clients ON DELETE CASCADE; CREATE TABLE ClientRRGStartups( - ClientId INT64 NOT NULL, + ClientId STRING(18) NOT NULL, CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), Startup `rrg.startup.Startup` NOT NULL, ) PRIMARY KEY (ClientId, CreationTime), INTERLEAVE IN PARENT Clients ON DELETE CASCADE; CREATE TABLE ClientCrashes( - ClientId INT64 NOT NULL, + ClientId STRING(18) NOT NULL, CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), Crash `grr.ClientCrash` NOT NULL, ) PRIMARY KEY (ClientId, CreationTime), @@ -219,7 +219,7 @@ CREATE TABLE Users ( CREATE TABLE UserNotifications( Username STRING(256) NOT NULL, 
- NotificationId INT64 NOT NULL, + NotificationId STRING(36) NOT NULL, Type `grr.UserNotification.Type` NOT NULL, State `grr.UserNotification.State` NOT NULL, CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), @@ -230,9 +230,9 @@ CREATE TABLE UserNotifications( CREATE TABLE ApprovalRequests( Requestor STRING(256) NOT NULL, - ApprovalId INT64 NOT NULL, - SubjectClientId INT64, - SubjectHuntId INT64, + ApprovalId STRING(36) NOT NULL, + SubjectClientId STRING(18), + SubjectHuntId STRING(8), SubjectCronJobId STRING(100), Reason STRING(MAX) NOT NULL, CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), @@ -268,9 +268,9 @@ CREATE INDEX ApprovalRequestsByRequestorSubjectCronJobId CREATE TABLE ApprovalGrants( Requestor STRING(256) NOT NULL, - ApprovalId INT64 NOT NULL, + ApprovalId STRING(36) NOT NULL, Grantor STRING(256) NOT NULL, - GrantId INT64 NOT NULL, + GrantId STRING(36) NOT NULL, CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), CONSTRAINT fk_approval_grant_grantor_username @@ -283,7 +283,7 @@ CREATE INDEX ApprovalGrantsByGrantor ON ApprovalGrants(Grantor); CREATE TABLE ClientLabels( - ClientId INT64 NOT NULL, + ClientId STRING(18) NOT NULL, Owner STRING(256) NOT NULL, Label STRING(128) NOT NULL, @@ -298,7 +298,7 @@ CREATE TABLE ClientLabels( INTERLEAVE IN PARENT Clients ON DELETE CASCADE; CREATE TABLE ClientKeywords( - ClientId INT64 NOT NULL, + ClientId STRING(18) NOT NULL, Keyword STRING(256) NOT NULL, CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), @@ -311,10 +311,10 @@ CREATE INDEX ClientKeywordsByKeywordCreationTime ON ClientKeywords(Keyword, CreationTime); CREATE TABLE Flows( - ClientId INT64 NOT NULL, - FlowId INT64 NOT NULL, - ParentFlowId INT64, - ParentHuntId INT64, + ClientId STRING(18) NOT NULL, + FlowId STRING(8) NOT NULL, + ParentFlowId STRING(8), + ParentHuntId STRING(8), LongFlowId STRING(256) NOT NULL, Creator STRING(256) NOT NULL, Name STRING(256) NOT NULL, @@ -322,7 +322,7 @@ CREATE TABLE Flows( CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), UpdateTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), Crash `grr.ClientCrash`, - NextRequestToProcess INT64, + NextRequestToProcess STRING(8), ProcessingWorker STRING(MAX), ProcessingStartTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), ProcessingEndTime TIMESTAMP, @@ -347,9 +347,9 @@ CREATE INDEX FlowsByParentHuntIdFlowIdState STORING (ReplyCount, NetworkBytesSent, UserCpuTimeUsed, SystemCpuTimeUsed); CREATE TABLE FlowResults( - ClientId INT64 NOT NULL, - FlowId INT64 NOT NULL, - HuntId INT64, + ClientId STRING(18) NOT NULL, + FlowId STRING(8) NOT NULL, + HuntId STRING(8), CreationTime TIMESTAMP NOT NULL, Payload `google.protobuf.Any`, RdfType STRING(MAX), @@ -365,9 +365,9 @@ CREATE INDEX FlowResultsByHuntIdFlowIdRdfTypeTagCreationTime ON FlowResults(HuntId, FlowId, RdfType, Tag, CreationTime); CREATE TABLE FlowErrors( - ClientId INT64 NOT NULL, - FlowId INT64 NOT NULL, - HuntId INT64 NOT NULL, + ClientId STRING(18) NOT NULL, + FlowId STRING(8) NOT NULL, + HuntId STRING(8) NOT NULL, CreationTime TIMESTAMP NOT NULL, Payload `google.protobuf.Any`, RdfType STRING(MAX), @@ -381,12 +381,12 @@ CREATE INDEX FlowErrorsByHuntIdFlowIdRdfTypeTagCreationTime ON FlowErrors(HuntId, FlowId, RdfType, Tag, CreationTime); CREATE TABLE FlowRequests( - ClientId INT64 NOT NULL, - FlowId INT64 NOT NULL, - RequestId INT64 NOT NULL, + ClientId STRING(18) NOT NULL, + FlowId STRING(8) NOT NULL, + RequestId STRING(8) NOT 
NULL, NeedsProcessing BOOL, ExpectedResponseCount INT64, - NextResponseId INT64, + NextResponseId STRING(8), CallbackState STRING(256), Payload `grr.FlowRequest` NOT NULL, StartTime TIMESTAMP, @@ -396,10 +396,10 @@ CREATE TABLE FlowRequests( ROW DELETION POLICY (OLDER_THAN(CreationTime, INTERVAL 84 DAY)); CREATE TABLE FlowResponses( - ClientId INT64 NOT NULL, - FlowId INT64 NOT NULL, - RequestId INT64 NOT NULL, - ResponseId INT64 NOT NULL, + ClientId STRING(18) NOT NULL, + FlowId STRING(8) NOT NULL, + RequestId STRING(8) NOT NULL, + ResponseId STRING(8) NOT NULL, Response `grr.FlowResponse`, Status `grr.FlowStatus`, Iterator `grr.FlowIterator`, @@ -413,9 +413,9 @@ CREATE TABLE FlowResponses( INTERLEAVE IN PARENT FlowRequests ON DELETE CASCADE; CREATE TABLE FlowLogEntries( - ClientId INT64 NOT NULL, - FlowId INT64 NOT NULL, - HuntId INT64, + ClientId STRING(18) NOT NULL, + FlowId STRING(8) NOT NULL, + HuntId STRING(8), CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), Message STRING(MAX) NOT NULL, ) PRIMARY KEY (ClientId, FlowId, CreationTime), @@ -425,10 +425,10 @@ CREATE INDEX FlowLogEntriesByHuntIdCreationTime ON FlowLogEntries(HuntId, CreationTime); CREATE TABLE FlowRRGLogs( - ClientId INT64 NOT NULL, - FlowId INT64 NOT NULL, - RequestId INT64 NOT NULL, - ResponseId INT64 NOT NULL, + ClientId STRING(18) NOT NULL, + FlowId STRING(8) NOT NULL, + RequestId STRING(8) NOT NULL, + ResponseId STRING(8) NOT NULL, LogLevel `rrg.Log.Level` NOT NULL, LogTime TIMESTAMP NOT NULL, LogMessage STRING(MAX) NOT NULL, @@ -437,10 +437,10 @@ CREATE TABLE FlowRRGLogs( INTERLEAVE IN PARENT Flows ON DELETE CASCADE; CREATE TABLE FlowOutputPluginLogEntries( - ClientId INT64 NOT NULL, - FlowId INT64 NOT NULL, - OutputPluginId INT64 NOT NULL, - HuntId INT64, + ClientId STRING(18) NOT NULL, + FlowId STRING(8) NOT NULL, + OutputPluginId STRING(8) NOT NULL, + HuntId STRING(8), CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), Type `grr.FlowOutputPluginLogEntry.LogEntryType` NOT NULL, Message STRING(MAX) NOT NULL, @@ -451,9 +451,9 @@ CREATE INDEX FlowOutputPluginLogEntriesByHuntIdCreationTime ON FlowOutputPluginLogEntries(HuntId, CreationTime); CREATE TABLE ScheduledFlows( - ClientId INT64 NOT NULL, + ClientId STRING(18) NOT NULL, Creator STRING(256) NOT NULL, - ScheduledFlowId INT64 NOT NULL, + ScheduledFlowId STRING(8) NOT NULL, FlowName STRING(256) NOT NULL, FlowArgs `google.protobuf.Any` NOT NULL, RunnerArgs `grr.FlowRunnerArgs` NOT NULL, @@ -470,7 +470,7 @@ CREATE INDEX ScheduledFlowsByCreator ON ScheduledFlows(Creator); CREATE TABLE Paths( - ClientId INT64 NOT NULL, + ClientId STRING(18) NOT NULL, Type `grr.PathInfo.PathType` NOT NULL, Path BYTES(MAX) NOT NULL, CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), @@ -485,7 +485,7 @@ CREATE INDEX PathsByClientIdTypePathDepth ON Paths(ClientId, Type, Path, Depth); CREATE TABLE PathFileStats( - ClientId INT64 NOT NULL, + ClientId STRING(18) NOT NULL, Type `grr.PathInfo.PathType` NOT NULL, Path BYTES(MAX) NOT NULL, CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), @@ -494,7 +494,7 @@ CREATE TABLE PathFileStats( INTERLEAVE IN PARENT Paths ON DELETE CASCADE; CREATE TABLE PathFileHashes( - ClientId INT64 NOT NULL, + ClientId STRING(18) NOT NULL, Type `grr.PathInfo.PathType` NOT NULL, Path BYTES(MAX) NOT NULL, CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), @@ -542,7 +542,7 @@ CREATE TABLE YaraSignatureReferences( ) PRIMARY KEY (BlobId); CREATE TABLE Hunts( 
- HuntId INT64 NOT NULL, + HuntId STRING(8) NOT NULL, CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), LastUpdateTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), Creator STRING(256) NOT NULL, @@ -579,8 +579,8 @@ CREATE INDEX HuntsByCreationTime CREATE INDEX HuntsByCreator ON Hunts(Creator); CREATE TABLE HuntOutputPlugins( - HuntId INT64 NOT NULL, - OutputPluginId INT64 NOT NULL, + HuntId STRING(8) NOT NULL, + OutputPluginId STRING(8) NOT NULL, Name STRING(256) NOT NULL, Args `google.protobuf.Any`, State `google.protobuf.Any` NOT NULL, @@ -588,7 +588,7 @@ CREATE TABLE HuntOutputPlugins( INTERLEAVE IN PARENT Hunts ON DELETE CASCADE; CREATE TABLE ForemanRules( - HuntId INT64 NOT NULL, + HuntId STRING(8) NOT NULL, ExpirationTime TIMESTAMP NOT NULL, Payload `grr.ForemanCondition`, diff --git a/grr/server/grr_response_server/databases/spanner_clients.py b/grr/server/grr_response_server/databases/spanner_clients.py index a67e9e8d0..8e6cc3d59 100644 --- a/grr/server/grr_response_server/databases/spanner_clients.py +++ b/grr/server/grr_response_server/databases/spanner_clients.py @@ -6,6 +6,9 @@ import re from typing import Collection, Iterator, Mapping, Optional, Sequence, Tuple +from google.api_core.exceptions import NotFound +from google.cloud import spanner as spanner_lib + from grr_response_core.lib import rdfvalue from grr_response_core.lib.util import iterator from grr_response_proto import jobs_pb2 @@ -23,8 +26,6 @@ class ClientsMixin: db: spanner_utils.Database - # TODO(b/196379916): Implement client methods. - @db_utils.CallLogged @db_utils.CallAccounted def MultiWriteClientMetadata( @@ -36,8 +37,38 @@ def MultiWriteClientMetadata( fleetspeak_validation_info: Optional[Mapping[str, str]] = None, ) -> None: """Writes metadata about the clients.""" - # Early return to avoid generating empty mutation. 
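# A hedged, standalone sketch of the COMMIT_TIMESTAMP sentinel used by several
# writes in this file: columns declared with OPTIONS (allow_commit_timestamp =
# true) in the schema accept the sentinel, and Spanner fills in the
# transaction's commit time. The instance and database ids are placeholders.
from google.cloud import spanner

def example_commit_timestamp_write(instance_id: str, database_id: str) -> None:
  database = spanner.Client().instance(instance_id).database(database_id)
  with database.batch() as batch:
    # LastSnapshotTime allows commit timestamps, so the value written here is
    # replaced by the commit time rather than a client-supplied datetime.
    batch.insert_or_update(
        table="Clients",
        columns=("ClientId", "LastSnapshotTime"),
        values=[("C.0000000000000001", spanner.COMMIT_TIMESTAMP)],
    )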
- + if not client_ids: + return + + row = {} + + if first_seen is not None: + row["FirstSeenTime"] = first_seen.AsDatetime() + if last_ping is not None: + row["LastPingTime"] = last_ping.AsDatetime() + if last_foreman is not None: + row["LastForemanTime"] = last_foreman.AsDatetime() + + if fleetspeak_validation_info is not None: + if fleetspeak_validation_info: + row["FleetspeakValidationInfo"] = ( + models_clients.FleetspeakValidationInfoFromDict( + fleetspeak_validation_info + ) + ) + else: + row["FleetspeakValidationInfo"] = None + + def Mutation(mut) -> None: + columns = [] + rows = [] + for client_id in client_ids: + client_row = {"ClientId": client_id, **row} + columns, values = zip(*client_row.items()) + rows.append(values) + mut.insert_or_update(table="Clients", columns=columns, values=rows) + + self.db.Mutate(Mutation, txn_tag="MultiWriteClientMetadata") @db_utils.CallLogged @db_utils.CallAccounted @@ -48,6 +79,56 @@ def MultiReadClientMetadata( """Reads ClientMetadata records for a list of clients.""" result = {} + keys = [] + for client_id in client_ids: + keys.append([client_id]) + keyset = spanner_lib.KeySet(keys=keys) + + cols = ( + "ClientId", + "LastStartupTime", + "LastCrashTime", + "Certificate", + "FirstSeenTime", + "LastPingTime", + "LastForemanTime", + "FleetspeakValidationInfo", + ) + + for row in self.db.ReadSet(table="Clients", rows=keyset, cols=cols): + client_id = row[0] + + metadata = objects_pb2.ClientMetadata() + + if row[1] is not None: + metadata.startup_info_timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(row[1]) + ) + if row[2] is not None: + metadata.last_crash_timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(row[2]) + ) + if row[3] is not None: + metadata.certificate = row[3] + if row[4] is not None: + metadata.first_seen = int( + rdfvalue.RDFDatetime.FromDatetime(row[4]) + ) + if row[5] is not None: + metadata.ping = int( + rdfvalue.RDFDatetime.FromDatetime(row[5]) + ) + if row[6] is not None: + metadata.last_foreman_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[6]) + ) + if row[7] is not None: + metadata.last_fleetspeak_validation_info.ParseFromString( + row[7] + ) + + result[client_id] = metadata + return result @db_utils.CallLogged @@ -60,7 +141,44 @@ def MultiAddClientLabels( ) -> None: """Attaches user labels to the specified clients.""" # Early return to avoid generating empty mutation. 
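# A standalone sketch of the keyed read that MultiReadClientMetadata performs
# through the ReadSet helper above, written directly against the public
# google-cloud-spanner snapshot API; the instance and database ids are
# placeholders.
from typing import Sequence

from google.cloud import spanner

def example_read_clients(
    instance_id: str, database_id: str, client_ids: Sequence[str]
) -> None:
  database = spanner.Client().instance(instance_id).database(database_id)
  # One key list per primary-key value; the read returns only existing rows.
  keyset = spanner.KeySet(keys=[[client_id] for client_id in client_ids])
  with database.snapshot() as snapshot:
    rows = snapshot.read(
        table="Clients",
        columns=("ClientId", "LastPingTime"),
        keyset=keyset,
    )
    for client_id, last_ping_time in rows:
      print(client_id, last_ping_time)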
-
+    if not client_ids or not labels:
+      return
+
+    def Mutation(mut) -> None:
+      label_rows = []
+      for label in labels:
+        label_rows.append([label])
+      mut.insert_or_update(table="Labels", columns=["Label"], values=label_rows)
+
+      client_rows = []
+      for client_id in client_ids:
+        for label in labels:
+          client_rows.append([client_id, owner, label])
+      columns = ["ClientId", "Owner", "Label"]
+      mut.insert_or_update(table="ClientLabels", columns=columns, values=client_rows)
+
+    try:
+      self.db.Mutate(Mutation, txn_tag="MultiAddClientLabels")
+    except NotFound as error:
+      message = str(error)
+      if "Parent row is missing: Clients" in message:
+        raise db_lib.AtLeastOneUnknownClientError(client_ids) from error
+      elif "fk_client_label_owner_username" in message:
+        match = re.search(
+            r"\((?P<username>\w+)\) in Users\(Username\)",
+            message,
+        )
+        if match is not None:
+          username = match["username"]
+        else:
+          username = ""
+          logging.error(
+              "Couldn't extract username from foreign key constraint: %s",
+              message,
+          )
+        raise db_lib.UnknownGRRUserError(username=username, cause=error)
+      else:
+        raise
 
   @db_utils.CallLogged
   @db_utils.CallAccounted
@@ -71,6 +189,26 @@ def MultiReadClientLabels(
     """Reads the user labels for a list of clients."""
     result = {client_id: [] for client_id in client_ids}
 
+    query = """
+      SELECT l.ClientId, ARRAY_AGG((l.Owner, l.Label))
+      FROM ClientLabels AS l
+      WHERE l.ClientId IN UNNEST({client_ids})
+      GROUP BY l.ClientId
+    """
+    params = {"client_ids": client_ids}
+
+    for client_id, labels in self.db.ParamQuery(
+        query, params, txn_tag="MultiReadClientLabels"
+    ):
+      for owner, label in labels:
+        label_proto = objects_pb2.ClientLabel()
+        label_proto.name = label
+        label_proto.owner = owner
+        result[client_id].append(label_proto)
+
+    for labels in result.values():
+      labels.sort(key=lambda label: (label.owner, label.name))
+
     return result
 
   @db_utils.CallLogged
@@ -83,6 +221,11 @@ def RemoveClientLabels(
   ) -> None:
     """Removes a list of user labels from a given client."""
 
+    def Mutation(mut: spanner_utils.Mutation) -> None:
+      for label in labels:
+        mut.Delete(table="ClientLabels", key=(client_id, owner, label))
+
+    self.db.Mutate(Mutation, txn_tag="RemoveClientLabels")
 
   @db_utils.CallLogged
   @db_utils.CallAccounted
@@ -90,6 +233,13 @@ def ReadAllClientLabels(self) -> Collection[str]:
     """Lists all client labels known to the system."""
     result = []
 
+    query = """
+      SELECT l.Label
+      FROM Labels AS l
+    """
+
+    for (label,) in self.db.Query(query, txn_tag="ReadAllClientLabels"):
+      result.append(label)
 
     return result
 
   @db_utils.CallLogged
   @db_utils.CallAccounted
   def WriteClientSnapshot(self, snapshot: objects_pb2.ClientSnapshot) -> None:
     """Writes new client snapshot."""
 
+    startup = snapshot.startup_info
+    snapshot_without_startup_info = objects_pb2.ClientSnapshot()
+    snapshot_without_startup_info.CopyFrom(snapshot)
+    snapshot_without_startup_info.ClearField("startup_info")
+
+    def Mutation(mut) -> None:
+      clients_rows = []
+      clients_rows.append([snapshot.client_id,spanner_lib.COMMIT_TIMESTAMP,spanner_lib.COMMIT_TIMESTAMP])
+      clients_columns = ["ClientId", "LastSnapshotTime", "LastStartupTime"]
+      mut.update(table="Clients", columns=clients_columns, values=clients_rows)
+
+      snapshots_rows = []
+      snapshots_rows.append([snapshot.client_id, spanner_lib.COMMIT_TIMESTAMP, snapshot_without_startup_info])
+      snapshots_columns= ["ClientId", "CreationTime", "Snapshot"]
+      mut.insert(table="ClientSnapshots", columns=snapshots_columns, values=snapshots_rows)
+
+      startups_rows = []
+      
startups_rows.append([snapshot.client_id, spanner_lib.COMMIT_TIMESTAMP, startup]) + startups_columns = [ "ClientId", "CreationTime", "Startup"] + mut.insert(table="ClientStartups", columns=startups_columns, values=startups_rows) + + try: + self.db.Mutate(Mutation, txn_tag="WriteClientSnapshot") + except NotFound as error: + raise db_lib.UnknownClientError(snapshot.client_id, cause=error) @db_utils.CallLogged @db_utils.CallAccounted @@ -106,10 +281,33 @@ def MultiReadClientSnapshot( client_ids: Collection[str], ) -> Mapping[str, Optional[objects_pb2.ClientSnapshot]]: """Reads the latest client snapshots for a list of clients.""" - # Unfortunately, Spanner has troubles with handling `UNNEST` expressions if - # the given array is empty, so we just handle such case separately. + if not client_ids: + return {} + + result = {client_id: None for client_id in client_ids} - return {} + query = """ + SELECT c.ClientId, ss.CreationTime, ss.Snapshot, su.Startup + FROM Clients AS c, ClientSnapshots AS ss, ClientStartups AS su + WHERE c.ClientId IN UNNEST({client_ids}) + AND ss.ClientId = c.ClientId + AND ss.CreationTime = c.LastSnapshotTime + AND su.ClientId = c.ClientId + AND su.CreationTime = c.LastStartupTime + """ + for row in self.db.ParamQuery( + query, {"client_ids": client_ids}, txn_tag="MultiReadClientSnapshot" + ): + client_id, creation_time, snapshot_bytes, startup_bytes = row + + snapshot = objects_pb2.ClientSnapshot() + snapshot.ParseFromString(snapshot_bytes) + snapshot.startup_info.ParseFromString(startup_bytes) + snapshot.timestamp = int(rdfvalue.RDFDatetime.FromDatetime(creation_time)) + + result[client_id] = snapshot + + return result @db_utils.CallLogged @db_utils.CallAccounted @@ -117,12 +315,44 @@ def ReadClientSnapshotHistory( self, client_id: str, timerange: Optional[ - Tuple[Optional[rdfvalue.RDFDatetime], Optional[rdfvalue.RDFDatetime]] + tuple[Optional[rdfvalue.RDFDatetime], Optional[rdfvalue.RDFDatetime]] ] = None, ) -> Sequence[objects_pb2.ClientSnapshot]: """Reads the full history for a particular client.""" result = [] + query = """ + SELECT ss.CreationTime, ss.Snapshot, su.Startup + FROM ClientSnapshots AS ss, ClientStartups AS su + WHERE ss.ClientId = {client_id} + AND ss.ClientId = su.ClientId + AND ss.CreationTime = su.CreationTime + """ + params = {"client_id": client_id} + + if timerange is not None: + time_since, time_until = timerange + if time_since is not None: + query += " AND ss.CreationTime >= {time_since}" + params["time_since"] = time_since.AsDatetime() + if time_until is not None: + query += " AND ss.CreationTime <= {time_until}" + params["time_until"] = time_until.AsDatetime() + + query += " ORDER BY ss.CreationTime DESC" + + for row in self.db.ParamQuery( + query, params=params, txn_tag="ReadClientSnapshotHistory" + ): + creation_time, snapshot_bytes, startup_bytes = row + + snapshot = objects_pb2.ClientSnapshot() + snapshot.ParseFromString(snapshot_bytes) + snapshot.startup_info.ParseFromString(startup_bytes) + snapshot.timestamp = int(rdfvalue.RDFDatetime.FromDatetime(creation_time)) + + result.append(snapshot) + return result @db_utils.CallLogged @@ -134,6 +364,19 @@ def WriteClientStartupInfo( ) -> None: """Writes a new client startup record.""" + def Mutation(mut: spanner_utils.Mutation) -> None: + mut.update(table="Clients", + columns=["ClientId", "LastStartupTime"], + values=[[client_id, spanner_lib.COMMIT_TIMESTAMP]]) + + mut.insert(table="ClientStartups", + columns=["ClientId", "CreationTime", "Startup"], + values=[client_id 
,spanner_lib.COMMIT_TIMESTAMP, startup]) + + try: + self.db.Mutate(Mutation, txn_tag="WriteClientStartupInfo") + except NotFound as error: + raise db_lib.UnknownClientError(client_id, cause=error) @db_utils.CallLogged @db_utils.CallAccounted @@ -144,6 +387,23 @@ def WriteClientRRGStartup( ) -> None: """Writes a new RRG startup entry to the database.""" + def Mutation(mut: spanner_utils.Mutation) -> None: + mut.update( + table="Clients", + columns=("ClientId", "LastRRGStartuptime"), + values=[(client_id, spanner_lib.COMMIT_TIMESTAMP)], + ) + + mut.insert( + table="ClientRRGStartups", + columns=("ClientId", "CreationTime", "Startup"), + values=[(client_id, spanner_lib.COMMIT_TIMESTAMP, startup)], + ) + + try: + self.db.Mutate(Mutation) + except NotFound as error: + raise db_lib.UnknownClientError(client_id, cause=error) @db_utils.CallLogged @db_utils.CallAccounted @@ -152,8 +412,29 @@ def ReadClientRRGStartup( client_id: str, ) -> Optional[rrg_startup_pb2.Startup]: """Reads the latest RRG startup entry for the given client.""" + query = """ + SELECT su.Startup + FROM Clients AS c + LEFT JOIN ClientRRGStartups AS su + ON c.ClientId = su.ClientId + AND c.LastRRGStartupTime = su.CreationTime + WHERE c.ClientId = {client_id} + """ + params = { + "client_id": client_id, + } + + try: + (startup_bytes,) = self.db.ParamQuerySingle( + query, params, txn_tag="ReadClientRRGStartup" + ) + except iterator.NoYieldsError: + raise db_lib.UnknownClientError(client_id) # pylint: disable=raise-missing-from - return None + if startup_bytes is None: + return None + + return rrg_startup_pb2.Startup.FromString(startup_bytes) @db_utils.CallLogged @db_utils.CallAccounted @@ -163,7 +444,26 @@ def ReadClientStartupInfo( ) -> Optional[jobs_pb2.StartupInfo]: """Reads the latest client startup record for a single client.""" - return None + query = """ + SELECT su.CreationTime, su.Startup + FROM Clients AS c, ClientStartups AS su + WHERE c.ClientId = {client_id} + AND c.ClientId = su.ClientId + AND c.LastStartupTime = su.CreationTime + """ + params = {"client_id": client_id} + + try: + (creation_time, startup_bytes) = self.db.ParamQuerySingle( + query, params, txn_tag="ReadClientStartupInfo" + ) + except iterator.NoYieldsError: + return None + + startup = jobs_pb2.StartupInfo() + startup.ParseFromString(startup_bytes) + startup.timestamp = int(rdfvalue.RDFDatetime.FromDatetime(creation_time)) + return startup @db_utils.CallLogged @db_utils.CallAccounted @@ -174,6 +474,19 @@ def WriteClientCrashInfo( ) -> None: """Writes a new client crash record.""" + def Mutation(mut: spanner_utils.Mutation) -> None: + mut.update(table="Clients", + columns=["ClientId", "LastCrashTime"], + values=[[client_id, spanner_lib.COMMIT_TIMESTAMP]]) + + mut.insert(table="ClientCrashes", + columns=["ClientId", "CreationTime", "Crash"], + values=[[client_id, spanner_lib.COMMIT_TIMESTAMP, crash]]) + + try: + self.db.Mutate(Mutation, txn_tag="WriteClientCrashInfo") + except NotFound as error: + raise db_lib.UnknownClientError(client_id, cause=error) @db_utils.CallLogged @db_utils.CallAccounted @@ -183,7 +496,27 @@ def ReadClientCrashInfo( ) -> Optional[jobs_pb2.ClientCrash]: """Reads the latest client crash record for a single client.""" - return None + query = """ + SELECT cr.CreationTime, cr.Crash + FROM Clients AS c, ClientCrashes AS cr + WHERE c.ClientId = {client_id} + AND c.ClientId = cr.ClientId + AND c.LastCrashTime = cr.CreationTime + """ + params = {"client_id": client_id} + + try: + (creation_time, crash_bytes) = self.db.ParamQuerySingle( + 
query, params, txn_tag="ReadClientCrashInfo" + ) + except iterator.NoYieldsError: + return None + + crash = jobs_pb2.ClientCrash() + crash = crash.FromString(crash_bytes) + crash.timestamp = int(rdfvalue.RDFDatetime.FromDatetime(creation_time)) + + return crash @db_utils.CallLogged @db_utils.CallAccounted @@ -200,9 +533,21 @@ def ReadClientCrashInfoHistory( WHERE cr.ClientId = {client_id} ORDER BY cr.CreationTime DESC """ - return None + params = {"client_id": client_id} + + for row in self.db.ParamQuery( + query, params, txn_tag="ReadClientCrashInfoHistory" + ): + creation_time, crash_bytes = row + + crash = jobs_pb2.ClientCrash() + crash.ParseFromString(crash_bytes) + crash.timestamp = int(rdfvalue.RDFDatetime.FromDatetime(creation_time)) + + result.append(crash) + + return result - # TODO(b/196379916): Investigate whether we need to batch this call or not. @db_utils.CallLogged @db_utils.CallAccounted def MultiReadClientFullInfo( @@ -211,11 +556,117 @@ def MultiReadClientFullInfo( min_last_ping: Optional[rdfvalue.RDFDatetime] = None, ) -> Mapping[str, objects_pb2.ClientFullInfo]: """Reads full client information for a list of clients.""" - # Spanner is having issues with `UNNEST` on empty arrays so we exit early in - # such cases. + if not client_ids: + return {} result = {} + query = """ + SELECT c.ClientId, + c.Certificate, + c.FleetspeakValidationInfo, + c.FirstSeenTime, c.LastPingTime, c.LastForemanTime, + c.LastSnapshotTime, c.LastStartupTime, c.LastCrashTime, + ss_last.Snapshot, + su_last.Startup, su_last_snapshot.Startup, + rrg_su_last.Startup, + ARRAY(SELECT AS STRUCT + l.Owner, l.Label + FROM ClientLabels AS l + WHERE l.ClientId = c.ClientId) + FROM Clients AS c + LEFT JOIN ClientSnapshots AS ss_last + ON (c.ClientId = ss_last.ClientId) + AND (c.LastSnapshotTime = ss_last.CreationTime) + LEFT JOIN ClientStartups AS su_last + ON (c.ClientId = su_last.ClientId) + AND (c.LastStartupTime = su_last.CreationTime) + LEFT JOIN ClientRrgStartups AS rrg_su_last + ON (c.ClientId = rrg_su_last.ClientId) + AND (c.LastRrgStartupTime = rrg_su_last.CreationTime) + LEFT JOIN ClientStartups AS su_last_snapshot + ON (c.ClientId = su_last_snapshot.ClientId) + AND (c.LastSnapshotTime = su_last_snapshot.CreationTime) + WHERE c.ClientId IN UNNEST({client_ids}) + """ + params = {"client_ids": client_ids} + + if min_last_ping is not None: + query += " AND c.LastPingTime >= {min_last_ping_time}" + params["min_last_ping_time"] = min_last_ping.AsDatetime() + + for row in self.db.ParamQuery( + query, params, txn_tag="MultiReadClientFullInfo" + ): + client_id, certificate_bytes, *row = row + fleetspeak_validation_info_bytes, *row = row + first_seen_time, last_ping_time, *row = row + last_foreman_time, *row = row + last_snapshot_time, last_startup_time, last_crash_time, *row = row + last_snapshot_bytes, *row = row + last_startup_bytes, last_snapshot_startup_bytes, *row = row + last_rrg_startup_bytes, *row = row + (label_rows,) = row + + info = objects_pb2.ClientFullInfo() + + if last_startup_bytes is not None: + info.last_startup_info.ParseFromString(last_startup_bytes) + info.last_startup_info.timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(last_startup_time) + ) + + if last_snapshot_bytes is not None: + info.last_snapshot.ParseFromString(last_snapshot_bytes) + info.last_snapshot.timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(last_snapshot_time) + ) + + if last_snapshot_startup_bytes is not None: + info.last_snapshot.startup_info.ParseFromString( + last_snapshot_startup_bytes + ) + 
info.last_snapshot.startup_info.timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(last_snapshot_time) + ) + + if certificate_bytes is not None: + info.metadata.certificate = certificate_bytes + if fleetspeak_validation_info_bytes is not None: + info.metadata.last_fleetspeak_validation_info.ParseFromString( + fleetspeak_validation_info_bytes + ) + if first_seen_time is not None: + info.metadata.first_seen = int( + rdfvalue.RDFDatetime.FromDatetime(first_seen_time) + ) + if last_ping_time is not None: + info.metadata.ping = int( + rdfvalue.RDFDatetime.FromDatetime(last_ping_time) + ) + if last_foreman_time is not None: + info.metadata.last_foreman_time = int( + rdfvalue.RDFDatetime.FromDatetime(last_foreman_time) + ) + if last_startup_time is not None: + info.metadata.startup_info_timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(last_startup_time) + ) + if last_crash_time is not None: + info.metadata.last_crash_timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(last_crash_time) + ) + + info.last_snapshot.client_id = client_id + + for owner, label in label_rows: + info.labels.add(owner=owner, name=label) + + if last_rrg_startup_bytes is not None: + info.last_rrg_startup.ParseFromString(last_rrg_startup_bytes) + + result[client_id] = info + return result @db_utils.CallLogged @@ -227,8 +678,25 @@ def ReadClientLastPings( batch_size: int = 0, ) -> Iterator[Mapping[str, Optional[rdfvalue.RDFDatetime]]]: """Yields dicts of last-ping timestamps for clients in the DB.""" + if not batch_size: + batch_size = db_lib.CLIENT_IDS_BATCH_SIZE + + last_client_id = "0" + + while True: + client_last_pings_batch = self._ReadClientLastPingsBatch( + count=(batch_size or db_lib.CLIENT_IDS_BATCH_SIZE), + last_client_id=last_client_id, + min_last_ping_time=min_last_ping, + max_last_ping_time=max_last_ping, + ) + if client_last_pings_batch: + yield client_last_pings_batch + if len(client_last_pings_batch) < batch_size: + break + last_client_id = max(client_last_pings_batch.keys()) def _ReadClientLastPingsBatch( self, @@ -250,6 +718,38 @@ def _ReadClientLastPingsBatch( """ result = {} + query = """ + SELECT c.ClientId, c.LastPingTime + FROM Clients AS c + WHERE c.ClientId > {last_client_id} + """ + params = {"last_client_id": last_client_id} + + if min_last_ping_time is not None: + query += " AND c.LastPingTime >= {min_last_ping_time}" + params["min_last_ping_time"] = min_last_ping_time.AsDatetime() + if max_last_ping_time is not None: + query += ( + " AND (c.LastPingTime IS NULL OR " + " c.LastPingTime <= {max_last_ping_time})" + ) + params["max_last_ping_time"] = max_last_ping_time.AsDatetime() + + query += """ + ORDER BY ClientId + LIMIT {count} + """ + params["count"] = count + + for client_id, last_ping_time in self.db.ParamQuery( + query, params, txn_tag="ReadClientLastPingsBatch" + ): + + if last_ping_time is not None: + result[client_id] = rdfvalue.RDFDatetime.FromDatetime(last_ping_time) + else: + result[client_id] = None + return result @db_utils.CallLogged @@ -257,6 +757,21 @@ def _ReadClientLastPingsBatch( def DeleteClient(self, client_id: str) -> None: """Deletes a client with all associated metadata.""" + def Transaction(txn: spanner_utils.Transaction) -> None: + # It looks like Spanner does not raise exception if we attempt to delete + # a non-existing row, so we have to verify row existence ourself. 
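# A standalone sketch of the check-then-delete pattern described in the comment
# above, written against the public google-cloud-spanner transaction API rather
# than the spanner_utils wrappers; the instance and database ids are
# placeholders.
from google.cloud import spanner

def example_delete_client(instance_id: str, database_id: str, client_id: str) -> None:
  database = spanner.Client().instance(instance_id).database(database_id)

  def txn_fn(transaction):
    keyset = spanner.KeySet(keys=[[client_id]])
    rows = list(
        transaction.read(table="Clients", columns=("ClientId",), keyset=keyset)
    )
    if not rows:
      # Spanner deletes of missing keys succeed silently, so absence has to be
      # detected explicitly before issuing the delete.
      raise ValueError(f"unknown client: {client_id}")
    # Per the schema, child tables interleaved ON DELETE CASCADE are removed
    # together with the parent Clients row.
    transaction.delete(table="Clients", keyset=keyset)

  database.run_in_transaction(txn_fn)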
+ try: + txn.Read(table="Clients", key=(client_id,), cols=[]) + except NotFound as error: + raise db_lib.UnknownClientError(client_id, cause=error) + + with txn.Mutate() as mut: + mut.Delete(table="Clients", key=(client_id,)) + mut.DeleteWithPrefix(table="ClientSnapshots", key=(client_id,)) + mut.DeleteWithPrefix(table="ClientStartups", key=(client_id,)) + mut.DeleteWithPrefix(table="ClientCrashes", key=(client_id,)) + + self.db.Transact(Transaction, txn_tag="DeleteClient") @db_utils.CallLogged @db_utils.CallAccounted @@ -267,7 +782,22 @@ def MultiAddClientKeywords( ) -> None: """Associates the provided keywords with the specified clients.""" # Early return to avoid generating empty mutation. - + if not client_ids or not keywords: + return + + def Mutation(mut: spanner_utils.Mutation) -> None: + for client_id in client_ids: + rows = [] + for keyword in keywords: + row = [client_id, keyword, spanner_lib.COMMIT_TIMESTAMP] + rows.append(row) + columns = ["ClientId", "Keyword", "CreationTime"] + mut.insert_or_update(table="ClientKeywords", columns=columns, values=rows) + + try: + self.db.Mutate(Mutation, txn_tag="MultiAddClientKeywords") + except NotFound as error: + raise db_lib.AtLeastOneUnknownClientError(client_ids) from error @db_utils.CallLogged @db_utils.CallAccounted @@ -277,39 +807,43 @@ def ListClientsForKeywords( start_time: Optional[rdfvalue.RDFDatetime] = None, ) -> Mapping[str, Collection[str]]: """Lists the clients associated with keywords.""" + results = {keyword: [] for keyword in keywords} - return None - - @db_utils.CallLogged - @db_utils.CallAccounted - def RemoveClientKeyword(self, client_id: str, keyword: str) -> None: - """Removes the association of a particular client to a keyword.""" - + query = """ + SELECT k.Keyword, ARRAY_AGG(k.ClientId) + FROM ClientKeywords@{{FORCE_INDEX=ClientKeywordsByKeywordCreationTime}} AS k + WHERE k.Keyword IN UNNEST({keywords}) + """ + params = { + "keywords": spanner_lib.Array(str, keywords), + } + if start_time is not None: + query += " AND k.CreationTime >= {cutoff_time}" + params["cutoff_time"] = start_time.AsDatetime() -def IntClientID(client_id: str) -> int: - """Converts a client identifier to its integer representation. + query += " GROUP BY k.Keyword" - This function wraps the value in PySpanner's `UInt64` wrapper. It is needed - because by default PySpanner assumes that integers are `Int64` and this can - cause conversion errors for large values. + for keyword, client_ids in self.db.ParamQuery( + query, params, txn_tag="ListClientsForKeywords" + ): + results[keyword].extend(client_ids) - Args: - client_id: A client identifier to convert. + return results - Returns: - An integer representation of the given client identifier. - """ - return db_utils.ClientIDToInt(client_id) + @db_utils.CallLogged + @db_utils.CallAccounted + def RemoveClientKeyword(self, client_id: str, keyword: str) -> None: + """Removes the association of a particular client to a keyword.""" + self.db.Delete( + table="ClientKeywords", + key=(client_id, keyword), + txn_tag="RemoveClientKeyword", + ) _EPOCH = datetime.datetime.fromtimestamp(0, datetime.timezone.utc) -# TODO(b/196379916): The F1 implementation uses a single constant for all -# queries deleting a lot of data. We should probably follow this pattern here, -# this constant should be moved to a more appropriate module. Also, it might -# be worthwhile to use Spanner's partitioned DML feature [1]. 
-# # pylint: disable=line-too-long # [1]: https://g3doc.corp.google.com/spanner/g3doc/userguide/sqlv1/data-manipulation-language.md#a-note-about-locking # pylint: enable=line-too-long diff --git a/grr/server/grr_response_server/databases/spanner_clients_test.py b/grr/server/grr_response_server/databases/spanner_clients_test.py new file mode 100644 index 000000000..ccd0f5e1a --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_clients_test.py @@ -0,0 +1,35 @@ +from absl.testing import absltest + +from grr.server.grr_response_server.databases import db +from grr.server.grr_response_server.databases import db_clients_test +from grr.server.grr_response_server.databases import db_test_utils +from grr.server.grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseClientsTest( + db_clients_test.DatabaseTestClientsMixin, spanner_test_lib.TestCase +): + # Test methods are defined in the base mixin class. + + # TODO(b/196379916): Enforce this constraint in other database implementations + # and move this test to `DatabaseTestClientsMixin`. + def testLabelWriteToUnknownUser(self): + client_id = db_test_utils.InitializeClient(self.db) + + with self.assertRaises(db.UnknownGRRUserError) as ctx: + self.db.AddClientLabels(client_id, owner="foo", labels=["bar", "baz"]) + + self.assertIsInstance(ctx.exception, db.UnknownGRRUserError) + self.assertEqual(ctx.exception.username, "foo") + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_cron_jobs.py b/grr/server/grr_response_server/databases/spanner_cron_jobs.py index 127b7acd0..2c704725f 100644 --- a/grr/server/grr_response_server/databases/spanner_cron_jobs.py +++ b/grr/server/grr_response_server/databases/spanner_cron_jobs.py @@ -4,6 +4,8 @@ import datetime from typing import Any, Mapping, Optional, Sequence +from google.cloud import spanner as spanner_lib + from grr_response_core.lib import rdfvalue from grr_response_core.lib import utils from grr_response_proto import flows_pb2 @@ -29,7 +31,19 @@ def WriteCronJob(self, cronjob: flows_pb2.CronJob) -> None: cronjob: A flows_pb2.CronJob object. """ # We currently expect to reuse `created_at` if set. 
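The reuse-or-commit-timestamp choice that the comment above describes amounts to a simple conditional. The following is only a sketch of the intended behaviour, assuming `created_at == 0` means the field is unset; `COMMIT_TIMESTAMP` is the sentinel that asks Spanner to fill in the commit time server-side.

if cronjob.created_at:
  creation_time = rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch(
      cronjob.created_at
  ).AsDatetime()
else:
  creation_time = spanner_lib.COMMIT_TIMESTAMP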
- + rdf_created_at = rdfvalue.RDFDatetime().FromMicrosecondsSinceEpoch( + cronjob.created_at + ) + creation_time = rdf_created_at.AsDatetime() or spanner_lib.COMMIT_TIMESTAMP + + row = { + "JobId": cronjob.cron_job_id, + "Job": cronjob, + "Enabled": bool(cronjob.enabled), + "CreationTime": creation_time, + } + + self.db.InsertOrUpdate(table="CronJobs", row=row, txn_tag="WriteCronJob") @db_utils.CallLogged @db_utils.CallAccounted diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 957f27f09..720ddd71c 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -276,8 +276,8 @@ def WriteFlowObject( row["State"] = int(flow_obj.flow_state) row["NextRequestToProcess"] = flow_obj.next_request_to_process - row["CreationTime"] = spanner_lib.CommitTimestamp() - row["UpdateTime"] = spanner_lib.CommitTimestamp() + row["CreationTime"] = spanner_lib.COMMIT_TIMESTAMP + row["UpdateTime"] = spanner_lib.COMMIT_TIMESTAMP if flow_obj.HasField("client_crash_info"): row["Crash"] = flow_obj.client_crash_info @@ -299,8 +299,8 @@ def WriteFlowObject( row["Flow"] = flow_obj - row["ReplyCount"] = spanner_lib.UInt64(flow_obj.num_replies_sent) - row["NetworkBytesSent"] = spanner_lib.UInt64(flow_obj.network_bytes_sent) + row["ReplyCount"] = int(flow_obj.num_replies_sent) + row["NetworkBytesSent"] = int(flow_obj.network_bytes_sent) row["UserCpuTimeUsed"] = float(flow_obj.cpu_time_used.user_cpu_time) row["SystemCpuTimeUsed"] = float(flow_obj.cpu_time_used.system_cpu_time) @@ -311,9 +311,9 @@ def WriteFlowObject( ) else: self.db.Insert(table="Flows", row=row, txn_tag="WriteFlowObject_I") - except spanner_errors.AlreadyExistsError as error: + except AlreadyExists as error: raise db.FlowExistsError(client_id, flow_id) from error - except spanner_errors.RowNotFoundError as error: + except NotFound as error: if "Parent row is missing: Clients" in str(error): raise db.UnknownClientError(client_id) else: @@ -453,14 +453,14 @@ def UpdateFlow( row = { "ClientId": spanner_clients.IntClientID(client_id), "FlowId": IntFlowID(flow_id), - "UpdateTime": spanner_lib.CommitTimestamp(), + "UpdateTime": spanner_lib.COMMIT_TIMESTAMP, } if isinstance(flow_obj, flows_pb2.Flow): row["Flow"] = flow_obj row["State"] = int(flow_obj.flow_state) - row["ReplyCount"] = spanner_lib.UInt64(flow_obj.num_replies_sent) - row["NetworkBytesSent"] = spanner_lib.UInt64(flow_obj.network_bytes_sent) + row["ReplyCount"] = int(flow_obj.num_replies_sent) + row["NetworkBytesSent"] = int(flow_obj.network_bytes_sent) row["UserCpuTimeUsed"] = float(flow_obj.cpu_time_used.user_cpu_time) row["SystemCpuTimeUsed"] = float(flow_obj.cpu_time_used.system_cpu_time) if isinstance(flow_state, flows_pb2.Flow.FlowState.ValueType): @@ -482,7 +482,7 @@ def UpdateFlow( try: self.db.Update(table="Flows", row=row, txn_tag="UpdateFlow") - except spanner_errors.RowNotFoundError as error: + except NotFound as error: raise db.UnknownFlowError(client_id, flow_id, cause=error) @db_utils.CallLogged @@ -740,7 +740,7 @@ def _BuildFlowProcessingRequestWrites( key = ( spanner_clients.IntClientID(r.client_id), IntFlowID(r.flow_id), - spanner_lib.CommitTimestamp(), + spanner_lib.COMMIT_TIMESTAMP, ) ts = None @@ -886,7 +886,7 @@ def Txn(txn) -> None: "NextResponseId": r.next_response_id, "CallbackState": r.callback_state, "Payload": r, - "CreationTime": spanner_lib.CommitTimestamp(), + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, } if 
r.start_time: update_dict["StartTime"] = ( @@ -938,7 +938,7 @@ def Txn(txn) -> None: try: self.db.Transact(Txn, txn_tag="WriteFlowRequests") - except spanner_errors.RowNotFoundError as error: + except NotFound as error: if "Parent row is missing: Flows" in str(error): raise db.AtLeastOneUnknownFlowError(flow_keys, cause=error) else: @@ -1064,7 +1064,7 @@ def _BuildResponseWrites( "Response": None, "Status": None, "Iterator": None, - "CreationTime": spanner_lib.CommitTimestamp(), + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, } if isinstance(r, flows_pb2.FlowResponse): @@ -1601,7 +1601,7 @@ def Mutation(mut) -> None: try: self.db.Mutate(Mutation, txn_tag="DeleteFlowRequests") # TODO(b/196379916): Narrow the exception types (cl/450440276). - except spanner_errors.BadUsageError: + except Exception: if len(requests) == 1: # If there is only one request and we still hit Spanner limits it means # that the requests has a lot of responses. It should be extremely rare @@ -1765,7 +1765,7 @@ def WriteFlowLogEntry(self, entry: flows_pb2.FlowLogEntry) -> None: row = { "ClientId": spanner_clients.IntClientID(entry.client_id), "FlowId": IntFlowID(entry.flow_id), - "CreationTime": spanner_lib.CommitTimestamp(), + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, "Message": entry.message, } @@ -1774,7 +1774,7 @@ def WriteFlowLogEntry(self, entry: flows_pb2.FlowLogEntry) -> None: try: self.db.Insert(table="FlowLogEntries", row=row) - except spanner_errors.RowNotFoundError as error: + except NotFound as error: raise db.UnknownFlowError(entry.client_id, entry.flow_id) from error def ReadFlowLogEntries( @@ -1873,13 +1873,13 @@ def Mutation(mut) -> None: "LogLevel": log.level, "LogTime": log.timestamp.ToDatetime(), "LogMessage": log.message, - "CreationTime": spanner_lib.CommitTimestamp(), + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, } mut.Insert(table="FlowRRGLogs", row=row) try: self.db.Mutate(Mutation) - except spanner_errors.RowNotFoundError as error: + except NotFound as error: raise db.UnknownFlowError(client_id, flow_id, cause=error) from error @db_utils.CallLogged @@ -1942,7 +1942,7 @@ def WriteFlowOutputPluginLogEntry( "ClientId": spanner_clients.IntClientID(entry.client_id), "FlowId": IntFlowID(entry.flow_id), "OutputPluginId": IntOutputPluginID(entry.output_plugin_id), - "CreationTime": spanner_lib.CommitTimestamp(), + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, "Type": int(entry.log_entry_type), "Message": entry.message, } @@ -1952,7 +1952,7 @@ def WriteFlowOutputPluginLogEntry( try: self.db.Insert(table="FlowOutputPluginLogEntries", row=row) - except spanner_errors.RowNotFoundError as error: + except NotFound as error: raise db.UnknownFlowError(entry.client_id, entry.flow_id) from error @db_utils.CallLogged @@ -2072,7 +2072,7 @@ def WriteScheduledFlow( try: self.db.InsertOrUpdate(table="ScheduledFlows", row=row) - except spanner_errors.RowNotFoundError as error: + except NotFound as error: if "Parent row is missing: Clients" in str(error): raise db.UnknownClientError(scheduled_flow.client_id) from error elif "fk_creator_users_username" in str(error): @@ -2099,7 +2099,7 @@ def DeleteScheduledFlow( def Transaction(txn) -> None: try: txn.Read(table="ScheduledFlows", cols=["ScheduledFlowId"], key=key) - except spanner_errors.RowNotFoundError as e: + except NotFound as e: raise db.UnknownScheduledFlowError( client_id=client_id, creator=creator, @@ -2313,7 +2313,7 @@ def _ReadHuntState( try: row = txn.Read(table="Hunts", key=(IntHuntID(hunt_id),), cols=("State",)) return row["State"] - except 
spanner_errors.RowNotFoundError: + except NotFound: return None @db_utils.CallLogged @@ -2335,7 +2335,7 @@ def Txn(txn) -> flows_pb2.Flow: key=(int_client_id, int_flow_id), cols=_READ_FLOW_OBJECT_COLS, ) - except spanner_errors.RowNotFoundError as error: + except NotFound as error: raise db.UnknownFlowError(client_id, flow_id, cause=error) flow = _ParseReadFlowObjectRow(client_id, flow_id, row) @@ -2381,7 +2381,7 @@ def Txn(txn) -> flows_pb2.Flow: flow.processing_deadline ).AsDatetime() ), - "ProcessingStartTime": spanner_lib.CommitTimestamp(), + "ProcessingStartTime": spanner_lib.COMMIT_TIMESTAMP, }, ) @@ -2418,7 +2418,7 @@ def Txn(txn) -> bool: < rdfvalue.RDFDatetime.Now() ): return False - except spanner_errors.RowNotFoundError: + except NotFound: pass txn.Update( table="Flows", @@ -2431,16 +2431,16 @@ def Txn(txn) -> bool: "SystemCpuTimeUsed": float( flow_obj.cpu_time_used.system_cpu_time ), - "NetworkBytesSent": spanner_lib.UInt64( + "NetworkBytesSent": int( flow_obj.network_bytes_sent ), "ProcessingWorker": None, "ProcessingStartTime": None, "ProcessingEndTime": None, - "NextRequesttoProcess": spanner_lib.UInt64( + "NextRequesttoProcess": int( flow_obj.next_request_to_process ), - "UpdateTime": spanner_lib.CommitTimestamp(), + "UpdateTime": spanner_lib.COMMIT_TIMESTAMP, "ReplyCount": flow_obj.num_replies_sent, }, ) diff --git a/grr/server/grr_response_server/databases/spanner_hunts.py b/grr/server/grr_response_server/databases/spanner_hunts.py index f4645c32c..02d0443e2 100644 --- a/grr/server/grr_response_server/databases/spanner_hunts.py +++ b/grr/server/grr_response_server/databases/spanner_hunts.py @@ -3,6 +3,10 @@ from typing import AbstractSet, Callable, Collection, Iterable, List, Mapping, Optional, Sequence +from google.api_core.exceptions import AlreadyExists + +from google.cloud import spanner as spanner_lib + from google.protobuf import any_pb2 from grr_response_core.lib import rdfvalue from grr_response_core.lib.rdfvalues import client as rdf_client @@ -27,7 +31,30 @@ class HuntsMixin: @db_utils.CallAccounted def WriteHuntObject(self, hunt_obj: hunts_pb2.Hunt): """Writes a hunt object to the database.""" - + row = { + "HuntId": hunt_obj.hunt_id, + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, + "LastUpdateTime": spanner_lib.COMMIT_TIMESTAMP, + "Creator": hunt_obj.creator, + "DurationMicros": hunt_obj.duration * 10**6, + "Description": hunt_obj.description, + "ClientRate": float(hunt_obj.client_rate), + "ClientLimit": hunt_obj.client_limit, + "State": int(hunt_obj.hunt_state), + "StateReason": int(hunt_obj.hunt_state_reason), + "StateComment": hunt_obj.hunt_state_comment, + "InitStartTime": None, + "LastStartTime": None, + "ClientCountAtStartTime": hunt_obj.num_clients_at_start_time, + "Hunt": hunt_obj, + } + + try: + self.db.Insert(table="Hunts", row=row, txn_tag="WriteHuntObject") + except AlreadyExists as error: + raise abstract_db.DuplicatedHuntError( + hunt_id=hunt_obj.hunt_id, cause=error + ) @db_utils.CallLogged diff --git a/grr/server/grr_response_server/databases/spanner_users.py b/grr/server/grr_response_server/databases/spanner_users.py index 1f863eec7..a786ae038 100644 --- a/grr/server/grr_response_server/databases/spanner_users.py +++ b/grr/server/grr_response_server/databases/spanner_users.py @@ -3,8 +3,8 @@ import datetime import logging -import random import sys +import uuid from typing import Optional, Sequence, Tuple @@ -67,21 +67,21 @@ def DeleteGRRUser(self, username: str) -> None: def Transaction(txn) -> None: try: - txn.read(table="Users", 
columns=("Username",), keyset=keyset) + txn.read(table="Users", columns=("Username",), keyset=keyset).one() except NotFound: raise abstract_db.UnknownGRRUserError(username) query = f""" DELETE FROM ApprovalGrants@{{FORCE_INDEX=ApprovalGrantsByGrantor}} AS g - WHERE g.Grantor = {username} + WHERE g.Grantor = '{username}' """ txn.execute_sql(query) query = f""" DELETE FROM ScheduledFlows@{{FORCE_INDEX=ScheduledFlowsByCreator}} AS f - WHERE f.Creator = {username} + WHERE f.Creator = '{username}' """ txn.execute_sql(query) @@ -181,7 +181,7 @@ def CountGRRUsers(self) -> int: @db_utils.CallAccounted def WriteApprovalRequest(self, request: objects_pb2.ApprovalRequest) -> str: """Writes an approval request object.""" - approval_id = random.randint(0, sys.maxsize) + approval_id = str(uuid.uuid4()) row = { "Requestor": request.requestor_username, @@ -198,9 +198,9 @@ def WriteApprovalRequest(self, request: objects_pb2.ApprovalRequest) -> str: } if request.approval_type == _APPROVAL_TYPE_CLIENT: - row["SubjectClientId"] = db_utils.ClientIDToInt(request.subject_id) + row["SubjectClientId"] = request.subject_id elif request.approval_type == _APPROVAL_TYPE_HUNT: - row["SubjectHuntId"] = db_utils.HuntIDToInt(request.subject_id) + row["SubjectHuntId"] = request.subject_id elif request.approval_type == _APPROVAL_TYPE_CRON_JOB: row["SubjectCronJobId"] = request.subject_id else: @@ -210,7 +210,7 @@ def WriteApprovalRequest(self, request: objects_pb2.ApprovalRequest) -> str: table="ApprovalRequests", row=row, txn_tag="WriteApprovalRequest" ) - return _HexApprovalID(approval_id) + return approval_id @db_utils.CallLogged @db_utils.CallAccounted @@ -220,7 +220,6 @@ def ReadApprovalRequest( approval_id: str, ) -> objects_pb2.ApprovalRequest: """Reads an approval request object with a given id.""" - approval_id = _UnhexApprovalID(approval_id) query = """ SELECT r.SubjectClientId, r.SubjectHuntId, r.SubjectCronJobId, @@ -256,7 +255,7 @@ def ReadApprovalRequest( request = objects_pb2.ApprovalRequest( requestor_username=username, - approval_id=_HexApprovalID(approval_id), + approval_id=approval_id, reason=reason, timestamp=RDFDatetime(creation_time).AsMicrosecondsSinceEpoch(), expiration_time=RDFDatetime(expiration_time).AsMicrosecondsSinceEpoch(), @@ -265,10 +264,10 @@ def ReadApprovalRequest( ) if subject_client_id is not None: - request.subject_id = db_utils.IntToClientID(subject_client_id) + request.subject_id = subject_client_id request.approval_type = _APPROVAL_TYPE_CLIENT elif subject_hunt_id is not None: - request.subject_id = db_utils.IntToHuntID(subject_hunt_id) + request.subject_id = subject_hunt_id request.approval_type = _APPROVAL_TYPE_HUNT elif subject_cron_job_id is not None: request.subject_id = subject_cron_job_id @@ -327,13 +326,13 @@ def ReadApprovalRequests( query += " AND r.SubjectClientId IS NOT NULL" if subject_id is not None: query += " AND r.SubjectClientId = {{subject_client_id}}" - params["subject_client_id"] = db_utils.ClientIDToInt(subject_id) + params["subject_client_id"] = subject_id index = "ApprovalRequestsByRequestorSubjectClientId" elif typ == _APPROVAL_TYPE_HUNT: query += " AND r.SubjectHuntId IS NOT NULL" if subject_id is not None: query += " AND r.SubjectHuntId = {{subject_hunt_id}}" - params["subject_hunt_id"] = db_utils.HuntIDToInt(subject_id) + params["subject_hunt_id"] = subject_id index = "ApprovalRequestsByRequestorSubjectHuntId" elif typ == _APPROVAL_TYPE_CRON_JOB: query += " AND r.SubjectCronJobId IS NOT NULL" @@ -360,7 +359,7 @@ def ReadApprovalRequests( request = 
objects_pb2.ApprovalRequest( requestor_username=username, - approval_id=_HexApprovalID(approval_id), + approval_id=approval_id, reason=reason, timestamp=RDFDatetime(creation_time).AsMicrosecondsSinceEpoch(), expiration_time=RDFDatetime( @@ -371,10 +370,10 @@ def ReadApprovalRequests( ) if subject_client_id is not None: - request.subject_id = db_utils.IntToClientID(subject_client_id) + request.subject_id = subject_client_id request.approval_type = _APPROVAL_TYPE_CLIENT elif subject_hunt_id is not None: - request.subject_id = db_utils.IntToHuntID(subject_hunt_id) + request.subject_id = subject_hunt_id request.approval_type = _APPROVAL_TYPE_HUNT elif subject_cron_job_id is not None: request.subject_id = subject_cron_job_id @@ -406,10 +405,9 @@ def GrantApproval( """Grants approval for a given request using given username.""" row = { "Requestor": requestor_username, - "ApprovalId": _UnhexApprovalID(approval_id), + "ApprovalId": approval_id, "Grantor": grantor_username, - # TODO: Look into Spanner sequences to generate unique IDs. - "GrantId": random.randint(0, sys.maxsize), + "GrantId": str(uuid.uuid4()), "CreationTime": spanner_lib.COMMIT_TIMESTAMP, } @@ -424,8 +422,7 @@ def WriteUserNotification( """Writes a notification for a given user.""" row = { "Username": notification.username, - # TODO: Look into Spanner sequences to generate unique IDs. - "NotificationId": random.randint(0, sys.maxsize), + "NotificationId": str(uuid.uuid4()), "Type": int(notification.notification_type), "State": int(notification.state), "CreationTime": spanner_lib.COMMIT_TIMESTAMP, @@ -526,7 +523,7 @@ def UpdateUserNotifications( query = f""" UPDATE UserNotifications n SET n.State = {state} - WHERE n.Username = {username} + WHERE n.Username = '{username}' AND n.CreationTime IN ({param_placeholders}) """ From e1d267bf00dc1a41558ed3360ca4696e83e0fa1c Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Sun, 8 Jun 2025 17:23:14 +0000 Subject: [PATCH 018/168] Adds Clients table --- .../grr_response_server/databases/spanner.sdl | 42 ++-- .../databases/spanner_clients.py | 49 ++--- .../databases/spanner_clients_test.py | 10 +- .../databases/spanner_flows.py | 188 ++++++++---------- .../databases/spanner_foreman_rules.py | 6 +- .../databases/spanner_utils.py | 9 +- 6 files changed, 136 insertions(+), 168 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner.sdl b/grr/server/grr_response_server/databases/spanner.sdl index 7aaff306e..d202012d9 100644 --- a/grr/server/grr_response_server/databases/spanner.sdl +++ b/grr/server/grr_response_server/databases/spanner.sdl @@ -312,9 +312,9 @@ CREATE INDEX ClientKeywordsByKeywordCreationTime CREATE TABLE Flows( ClientId STRING(18) NOT NULL, - FlowId STRING(8) NOT NULL, - ParentFlowId STRING(8), - ParentHuntId STRING(8), + FlowId STRING(16) NOT NULL, + ParentFlowId STRING(16), + ParentHuntId STRING(16), LongFlowId STRING(256) NOT NULL, Creator STRING(256) NOT NULL, Name STRING(256) NOT NULL, @@ -348,8 +348,8 @@ CREATE INDEX FlowsByParentHuntIdFlowIdState CREATE TABLE FlowResults( ClientId STRING(18) NOT NULL, - FlowId STRING(8) NOT NULL, - HuntId STRING(8), + FlowId STRING(16) NOT NULL, + HuntId STRING(16), CreationTime TIMESTAMP NOT NULL, Payload `google.protobuf.Any`, RdfType STRING(MAX), @@ -366,8 +366,8 @@ CREATE INDEX FlowResultsByHuntIdFlowIdRdfTypeTagCreationTime CREATE TABLE FlowErrors( ClientId STRING(18) NOT NULL, - FlowId STRING(8) NOT NULL, - HuntId STRING(8) NOT NULL, + FlowId STRING(16) NOT NULL, + HuntId STRING(16) NOT NULL, CreationTime TIMESTAMP 
NOT NULL, Payload `google.protobuf.Any`, RdfType STRING(MAX), @@ -382,11 +382,11 @@ CREATE INDEX FlowErrorsByHuntIdFlowIdRdfTypeTagCreationTime CREATE TABLE FlowRequests( ClientId STRING(18) NOT NULL, - FlowId STRING(8) NOT NULL, - RequestId STRING(8) NOT NULL, + FlowId STRING(16) NOT NULL, + RequestId STRING(16) NOT NULL, NeedsProcessing BOOL, ExpectedResponseCount INT64, - NextResponseId STRING(8), + NextResponseId STRING(16), CallbackState STRING(256), Payload `grr.FlowRequest` NOT NULL, StartTime TIMESTAMP, @@ -397,9 +397,9 @@ CREATE TABLE FlowRequests( CREATE TABLE FlowResponses( ClientId STRING(18) NOT NULL, - FlowId STRING(8) NOT NULL, - RequestId STRING(8) NOT NULL, - ResponseId STRING(8) NOT NULL, + FlowId STRING(16) NOT NULL, + RequestId STRING(16) NOT NULL, + ResponseId STRING(16) NOT NULL, Response `grr.FlowResponse`, Status `grr.FlowStatus`, Iterator `grr.FlowIterator`, @@ -414,8 +414,8 @@ CREATE TABLE FlowResponses( CREATE TABLE FlowLogEntries( ClientId STRING(18) NOT NULL, - FlowId STRING(8) NOT NULL, - HuntId STRING(8), + FlowId STRING(16) NOT NULL, + HuntId STRING(16), CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), Message STRING(MAX) NOT NULL, ) PRIMARY KEY (ClientId, FlowId, CreationTime), @@ -426,9 +426,9 @@ CREATE INDEX FlowLogEntriesByHuntIdCreationTime CREATE TABLE FlowRRGLogs( ClientId STRING(18) NOT NULL, - FlowId STRING(8) NOT NULL, - RequestId STRING(8) NOT NULL, - ResponseId STRING(8) NOT NULL, + FlowId STRING(16) NOT NULL, + RequestId STRING(16) NOT NULL, + ResponseId STRING(16) NOT NULL, LogLevel `rrg.Log.Level` NOT NULL, LogTime TIMESTAMP NOT NULL, LogMessage STRING(MAX) NOT NULL, @@ -438,9 +438,9 @@ CREATE TABLE FlowRRGLogs( CREATE TABLE FlowOutputPluginLogEntries( ClientId STRING(18) NOT NULL, - FlowId STRING(8) NOT NULL, + FlowId STRING(16) NOT NULL, OutputPluginId STRING(8) NOT NULL, - HuntId STRING(8), + HuntId STRING(16), CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), Type `grr.FlowOutputPluginLogEntry.LogEntryType` NOT NULL, Message STRING(MAX) NOT NULL, @@ -453,7 +453,7 @@ CREATE INDEX FlowOutputPluginLogEntriesByHuntIdCreationTime CREATE TABLE ScheduledFlows( ClientId STRING(18) NOT NULL, Creator STRING(256) NOT NULL, - ScheduledFlowId STRING(8) NOT NULL, + ScheduledFlowId STRING(16) NOT NULL, FlowName STRING(256) NOT NULL, FlowArgs `google.protobuf.Any` NOT NULL, RunnerArgs `grr.FlowRunnerArgs` NOT NULL, diff --git a/grr/server/grr_response_server/databases/spanner_clients.py b/grr/server/grr_response_server/databases/spanner_clients.py index 8e6cc3d59..ab742fa96 100644 --- a/grr/server/grr_response_server/databases/spanner_clients.py +++ b/grr/server/grr_response_server/databases/spanner_clients.py @@ -159,24 +159,12 @@ def Mutation(mut) -> None: try: self.db.Mutate(Mutation, txn_tag="MultiAddClientLabels") - except NotFound as error: + except Exception as error: message = str(error) - if "Parent row is missing: Clients" in message: + if "Parent row for row [" in message: raise db_lib.AtLeastOneUnknownClientError(client_ids) from error elif "fk_client_label_owner_username" in message: - match = re.search( - r"\((?P\w+)\) in Users\(Username\)", - message, - ) - if match is not None: - username = match["username"] - else: - username = "" - logging.error( - "Couldn't extract username from foreign key constraint: %s", - message, - ) - raise db_lib.UnknownGRRUserError(username=username, cause=error) + raise db_lib.UnknownGRRUserError(username=owner, cause=error) else: raise @@ -221,9 +209,11 @@ def 
RemoveClientLabels( ) -> None: """Removes a list of user labels from a given client.""" - def Mutation(mut: spanner_utils.Mutation) -> None: + def Mutation(mut) -> None: + keys = [] for label in labels: - mut.Delete(table="ClientLabels", key=(client_id, owner, label)) + keys.append([client_id, owner, label]) + mut.delete(table="ClientLabels", keyset=spanner_lib.KeySet(keys=keys)) self.db.Mutate(Mutation, txn_tag="RemoveClientLabels") @@ -371,7 +361,7 @@ def Mutation(mut: spanner_utils.Mutation) -> None: mut.insert(table="ClientStartups", columns=["ClientId", "CreationTime", "Startup"], - values=[client_id ,spanner_lib.COMMIT_TIMESTAMP, startup]) + values=[[client_id ,spanner_lib.COMMIT_TIMESTAMP, startup]]) try: self.db.Mutate(Mutation, txn_tag="WriteClientStartupInfo") @@ -428,7 +418,7 @@ def ReadClientRRGStartup( (startup_bytes,) = self.db.ParamQuerySingle( query, params, txn_tag="ReadClientRRGStartup" ) - except iterator.NoYieldsError: + except NotFound: raise db_lib.UnknownClientError(client_id) # pylint: disable=raise-missing-from if startup_bytes is None: @@ -457,7 +447,7 @@ def ReadClientStartupInfo( (creation_time, startup_bytes) = self.db.ParamQuerySingle( query, params, txn_tag="ReadClientStartupInfo" ) - except iterator.NoYieldsError: + except NotFound: return None startup = jobs_pb2.StartupInfo() @@ -509,7 +499,7 @@ def ReadClientCrashInfo( (creation_time, crash_bytes) = self.db.ParamQuerySingle( query, params, txn_tag="ReadClientCrashInfo" ) - except iterator.NoYieldsError: + except NotFound: return None crash = jobs_pb2.ClientCrash() @@ -757,19 +747,17 @@ def _ReadClientLastPingsBatch( def DeleteClient(self, client_id: str) -> None: """Deletes a client with all associated metadata.""" - def Transaction(txn: spanner_utils.Transaction) -> None: + def Transaction(txn) -> None: # It looks like Spanner does not raise exception if we attempt to delete # a non-existing row, so we have to verify row existence ourself. 
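+      # Note: assuming ClientId is the entire primary key of Clients, the
+      # range-based key set built below addresses the same single row that
+      # spanner_lib.KeySet(keys=[[client_id]]) would.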
+ keyrange = spanner_lib.KeyRange(start_closed=[client_id], end_closed=[client_id]) + keyset = spanner_lib.KeySet(ranges=[keyrange]) try: - txn.Read(table="Clients", key=(client_id,), cols=[]) + txn.read(table="Clients", keyset=keyset, columns=["ClientId"]).one() except NotFound as error: raise db_lib.UnknownClientError(client_id, cause=error) - with txn.Mutate() as mut: - mut.Delete(table="Clients", key=(client_id,)) - mut.DeleteWithPrefix(table="ClientSnapshots", key=(client_id,)) - mut.DeleteWithPrefix(table="ClientStartups", key=(client_id,)) - mut.DeleteWithPrefix(table="ClientCrashes", key=(client_id,)) + txn.delete(table="Clients", keyset=keyset) self.db.Transact(Transaction, txn_tag="DeleteClient") @@ -815,7 +803,7 @@ def ListClientsForKeywords( WHERE k.Keyword IN UNNEST({keywords}) """ params = { - "keywords": spanner_lib.Array(str, keywords), + "keywords": list(keywords), } if start_time is not None: @@ -844,7 +832,4 @@ def RemoveClientKeyword(self, client_id: str, keyword: str) -> None: _EPOCH = datetime.datetime.fromtimestamp(0, datetime.timezone.utc) -# pylint: disable=line-too-long -# [1]: https://g3doc.corp.google.com/spanner/g3doc/userguide/sqlv1/data-manipulation-language.md#a-note-about-locking -# pylint: enable=line-too-long _DELETE_BATCH_SIZE = 5_000 diff --git a/grr/server/grr_response_server/databases/spanner_clients_test.py b/grr/server/grr_response_server/databases/spanner_clients_test.py index ccd0f5e1a..450b8e295 100644 --- a/grr/server/grr_response_server/databases/spanner_clients_test.py +++ b/grr/server/grr_response_server/databases/spanner_clients_test.py @@ -1,9 +1,9 @@ from absl.testing import absltest -from grr.server.grr_response_server.databases import db -from grr.server.grr_response_server.databases import db_clients_test -from grr.server.grr_response_server.databases import db_test_utils -from grr.server.grr_response_server.databases import spanner_test_lib +from grr_response_server.databases import db +from grr_response_server.databases import db_clients_test +from grr_response_server.databases import db_test_utils +from grr_response_server.databases import spanner_test_lib def setUpModule() -> None: @@ -19,8 +19,6 @@ class SpannerDatabaseClientsTest( ): # Test methods are defined in the base mixin class. - # TODO(b/196379916): Enforce this constraint in other database implementations - # and move this test to `DatabaseTestClientsMixin`. 
def testLabelWriteToUnknownUser(self): client_id = db_test_utils.InitializeClient(self.db) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 720ddd71c..f21576831 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -261,15 +261,15 @@ def WriteFlowObject( flow_id = flow_obj.flow_id row = { - "ClientId": spanner_clients.IntClientID(client_id), - "FlowId": IntFlowID(flow_id), + "ClientId": client_id, + "FlowId": flow_id, "LongFlowId": flow_obj.long_flow_id, } if flow_obj.parent_flow_id: - row["ParentFlowId"] = IntFlowID(flow_obj.parent_flow_id) + row["ParentFlowId"] = flow_obj.parent_flow_id if flow_obj.parent_hunt_id: - row["ParentHuntId"] = IntHuntID(flow_obj.parent_hunt_id) + row["ParentHuntId"] = flow_obj.parent_hunt_id row["Creator"] = flow_obj.creator row["Name"] = flow_obj.flow_class_name @@ -327,13 +327,11 @@ def ReadFlowObject( flow_id: str, ) -> flows_pb2.Flow: """Reads a flow object from the database.""" - int_client_id = spanner_clients.IntClientID(client_id) - int_flow_id = IntFlowID(flow_id) try: row = self.db.Read( table="Flows", - key=(int_client_id, int_flow_id), + key=[client_id, flow_id], cols=_READ_FLOW_OBJECT_COLS, ) except NotFound as error: @@ -738,8 +736,8 @@ def _BuildFlowProcessingRequestWrites( for r in requests: key = ( - spanner_clients.IntClientID(r.client_id), - IntFlowID(r.flow_id), + r.client_id, + r.flow_id, spanner_lib.COMMIT_TIMESTAMP, ) @@ -765,10 +763,11 @@ def WriteFlowProcessingRequests( ) -> None: """Writes a list of flow processing requests to the database.""" - def Mutation(mut) -> None: - self._BuildFlowProcessingRequestWrites(mut, requests) - - self.db.BufferedMutate(Mutation, txn_tag="WriteFlowProcessingRequests") + ### TODO ### + #def Mutation(mut) -> None: + # self._BuildFlowProcessingRequestWrites(mut, requests) + # + #self.db.Mutate(Mutation, txn_tag="WriteFlowProcessingRequests") @db_utils.CallLogged @db_utils.CallAccounted @@ -870,59 +869,51 @@ def WriteFlowRequests( def Txn(txn) -> None: needs_processing = {} - with txn.Mutate() as mut: - for r in requests: - if r.needs_processing: - needs_processing.setdefault((r.client_id, r.flow_id), []).append(r) - - client_id_int = spanner_clients.IntClientID(r.client_id) - flow_id_int = IntFlowID(r.flow_id) - - update_dict = { - "ClientId": client_id_int, - "FlowId": flow_id_int, - "RequestId": r.request_id, - "NeedsProcessing": r.needs_processing, - "NextResponseId": r.next_response_id, - "CallbackState": r.callback_state, - "Payload": r, - "CreationTime": spanner_lib.COMMIT_TIMESTAMP, - } - if r.start_time: - update_dict["StartTime"] = ( - rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( - r.start_time - ).AsDatetime() - ) + columns = ["ClientId", + "FlowId", + "RequestId", + "NeedsProcessing", + "NextResponseId", + "CallbackState", + "Payload", + "CreationTime", + "StartTime"] + rows = [] + for r in requests: + if r.needs_processing: + needs_processing.setdefault((r.client_id, r.flow_id), []).append(r) + + if r.start_time: + start_time = rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch(r.start_time).AsDatetime() + else: + start_time = rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch(0).AsDatetime() - mut.InsertOrUpdate("FlowRequests", update_dict) + rows.append([r.client_id, r.flow_id, str(r.request_id), r.needs_processing, str(r.next_response_id), + r.callback_state, r, spanner_lib.COMMIT_TIMESTAMP, start_time]) + 
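+      # spanner_lib.COMMIT_TIMESTAMP in the row values is a client-side
+      # placeholder; Spanner substitutes the actual commit time when the
+      # transaction commits (the target column must be declared with
+      # allow_commit_timestamp = true).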
txn.insert_or_update(table="FlowRequests", columns=columns, values=rows) if needs_processing: flow_processing_requests = [] - rows = spanner_lib.RowSet() + keys = [] # Note on linting: adding .keys() triggers a warning that # .keys() should be omitted. Omitting keys leads to a # mistaken warning that .items() was not called. for client_id, flow_id in needs_processing: # pylint: disable=dict-iter-missing-items - rows.Add( - spanner_lib.Key( - spanner_clients.IntClientID(client_id), IntFlowID(flow_id) - ) - ) + keys.Add((client_id, flow_id)) - cols = ( - "ClientId", - "FlowId", - "NextRequestToProcess", + columns = ( + "ClientId", + "FlowId", + "NextRequestToProcess", ) - for row in txn.ReadSet(table="Flows", rows=rows, cols=cols): - client_id = db_utils.IntToClientID(row["ClientId"]) - flow_id = db_utils.IntToFlowID(row["FlowId"]) + for row in txn.read(table="Flows", keyset=spanner_lib.KeySet(keys=keys), columns=columns): + client_id = row[0] + flow_id = row[1] candidate_requests = needs_processing.get((client_id, flow_id), []) for r in candidate_requests: - if row["NextRequestToProcess"] == r.request_id or r.start_time: + if row[2] == r.request_id or r.start_time: req = flows_pb2.FlowProcessingRequest( client_id=client_id, flow_id=flow_id ) @@ -930,11 +921,11 @@ def Txn(txn) -> None: req.delivery_time = r.start_time flow_processing_requests.append(req) - if flow_processing_requests: - with txn.BufferedMutate() as mut: - self._BuildFlowProcessingRequestWrites( - mut, flow_processing_requests - ) + #if flow_processing_requests: + # with txn.BufferedMutate() as mut: + # self._BuildFlowProcessingRequestWrites( + # mut, flow_processing_requests + # ) try: self.db.Transact(Txn, txn_tag="WriteFlowRequests") @@ -989,20 +980,14 @@ def _ReadRequestsInfo( # Callback states by request. callback_state_by_request = {} - req_rows = spanner_lib.RowSet() + keys = [] for r in responses: - req_rows.Add( - spanner_lib.Key( - spanner_clients.IntClientID(r.client_id), - IntFlowID(r.flow_id), - r.request_id, - ) - ) + keys.append([r.client_id, r.flow_id, str(r.request_id)]) - for row in txn.ReadSet( + for row in txn.read( table="FlowRequests", - rows=req_rows, - cols=[ + keyset=spanner_lib.KeySet(keys=keys), + columns=[ "ClientID", "FlowID", "RequestID", @@ -1012,17 +997,17 @@ def _ReadRequestsInfo( ): request_key = _RequestKey( - db_utils.IntToClientID(row["ClientID"]), - db_utils.IntToFlowID(row["FlowID"]), - row["RequestID"], + row[0], + row[1], + int(row[2]), ) currently_available_requests.add(request_key) - callback_state: str = row["CallbackState"] + callback_state: str = row[3] if callback_state: callback_state_by_request[request_key] = callback_state - responses_expected: int = row["ExpectedResponseCount"] + responses_expected: int = row[4] if responses_expected: responses_expected_by_request[request_key] = responses_expected @@ -1053,31 +1038,32 @@ def _BuildResponseWrites( TypeError: if responses have objects other than FlowResponse, FlowStatus or FlowIterator. """ + columns = ["ClientId", + "FlowId", + "RequestId", + "ResponseId", + "Response", + "Status", + "Iterator", + "CreationTime"] + rows = [] + for r in responses: + response = None + status = None + iterator = None + if isinstance(r, flows_pb2.FlowResponse): + response = r + elif isinstance(r, flows_pb2.FlowStatus): + status = r + elif isinstance(r, flows_pb2.FlowIterator): + iterator = r + else: + # This can't really happen due to DB validator type checking. 
+ raise TypeError(f"Got unexpected response type: {type(r)} {r}") + rows.append([r.client_id, r.flow_id, str(r.request_id), str(r.response_id), + response,status,iterator,spanner_lib.COMMIT_TIMESTAMP]) - with txn.Mutate() as mut: - for r in responses: - row = { - "ClientId": spanner_clients.IntClientID(r.client_id), - "FlowId": IntFlowID(r.flow_id), - "RequestId": r.request_id, - "ResponseId": r.response_id, - "Response": None, - "Status": None, - "Iterator": None, - "CreationTime": spanner_lib.COMMIT_TIMESTAMP, - } - - if isinstance(r, flows_pb2.FlowResponse): - row["Response"] = r - elif isinstance(r, flows_pb2.FlowStatus): - row["Status"] = r - elif isinstance(r, flows_pb2.FlowIterator): - row["Iterator"] = r - else: - # This can't really happen due to DB validator type checking. - raise TypeError(f"Got unexpected response type: {type(r)} {r}") - - mut.InsertOrUpdate("FlowResponses", row) + txn.insert_or_update(table="FlowResponses", columns=columns, values=rows) def _BuildExpectedUpdates( self, updates: dict[_RequestKey, int], txn @@ -1133,9 +1119,7 @@ def _WriteFlowResponsesAndExpectedUpdates( if not responses: return ({}, {}) - def Txn( - txn: spanner_utils.Transaction, - ) -> tuple[dict[_RequestKey, int], dict[_RequestKey, str]]: + def Txn(txn) -> tuple[dict[_RequestKey, int], dict[_RequestKey, str]]: ( responses_expected_by_request, callback_state_by_request, @@ -1212,9 +1196,9 @@ def Txn( return responses_expected_by_request, callback_state_by_request - return self.db.Transact( + return tuple(self.db.Transact( Txn, txn_tag="WriteFlowResponsesAndExpectedUpdates" - ).value + )) def _GetFlowResponsesPerRequestCounts( self, diff --git a/grr/server/grr_response_server/databases/spanner_foreman_rules.py b/grr/server/grr_response_server/databases/spanner_foreman_rules.py index f21f96eac..5c61300e8 100644 --- a/grr/server/grr_response_server/databases/spanner_foreman_rules.py +++ b/grr/server/grr_response_server/databases/spanner_foreman_rules.py @@ -18,9 +18,8 @@ class ForemanRulesMixin: @db_utils.CallAccounted def WriteForemanRule(self, rule: jobs_pb2.ForemanCondition) -> None: """Writes a foreman rule to the database.""" - hunt_id_int = db_utils.HuntIDToInt(rule.hunt_id) row = { - "HuntId": hunt_id_int, + "HuntId": rule.hunt_id_int, "ExpirationTime": ( rdfvalue.RDFDatetime() .FromMicrosecondsSinceEpoch(rule.expiration_time) @@ -36,9 +35,8 @@ def WriteForemanRule(self, rule: jobs_pb2.ForemanCondition) -> None: @db_utils.CallAccounted def RemoveForemanRule(self, hunt_id: str) -> None: """Removes a foreman rule from the database.""" - hunt_id_int = db_utils.HuntIDToInt(hunt_id) self.db.Delete( - table="ForemanRules", key=(hunt_id_int), txn_tag="RemoveForemanRule" + table="ForemanRules", key=[hunt_id], txn_tag="RemoveForemanRule" ) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index c9b116ab6..547a28c11 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -153,6 +153,11 @@ def _get_param_type(self, value): return param_types.TIMESTAMP elif py_type is decimal.Decimal: return param_types.NUMERIC + elif py_type is list: + if len(value) > 0: + return param_types.Array(self._get_param_type(value[0])) + else: + raise TypeError(f"Empty value for Python type: {py_type.__name__} for Spanner type conversion.") else: # Potentially raise an error for unsupported types or return None # For a generic solution, raising an error for 
unknown types is often safer. @@ -490,15 +495,13 @@ def Read( Returns: A mapping from columns to values of the read row. """ - range = KeyRange(start_closed=key, end_closed=key) - keyset = KeySet(ranges=[range]) + keyset = KeySet(keys=[key]) with self._pyspanner.snapshot() as snapshot: results = snapshot.read( table=table, columns=cols, keyset=keyset ) - return results.one() def ReadSet( From c23ebc4b908e614a890db0687be6bd605eb29388 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Mon, 9 Jun 2025 18:49:36 +0000 Subject: [PATCH 019/168] EoD 20250609 --- grr/proto/grr_response_proto/flows.proto | 1 + .../grr_response_server/databases/spanner.sdl | 12 +- .../databases/spanner_flows.py | 690 ++++++++---------- .../databases/spanner_utils.py | 2 +- 4 files changed, 299 insertions(+), 406 deletions(-) diff --git a/grr/proto/grr_response_proto/flows.proto b/grr/proto/grr_response_proto/flows.proto index 52b6cadff..5586a0577 100644 --- a/grr/proto/grr_response_proto/flows.proto +++ b/grr/proto/grr_response_proto/flows.proto @@ -2287,6 +2287,7 @@ message FlowProcessingRequest { optional uint64 creation_time = 4 [(sem_type) = { type: "RDFDatetime", }]; + optional string ack_id = 5; } message FlowRequest { diff --git a/grr/server/grr_response_server/databases/spanner.sdl b/grr/server/grr_response_server/databases/spanner.sdl index d202012d9..e5cbd04dd 100644 --- a/grr/server/grr_response_server/databases/spanner.sdl +++ b/grr/server/grr_response_server/databases/spanner.sdl @@ -322,7 +322,7 @@ CREATE TABLE Flows( CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), UpdateTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), Crash `grr.ClientCrash`, - NextRequestToProcess STRING(8), + NextRequestToProcess STRING(16), ProcessingWorker STRING(MAX), ProcessingStartTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), ProcessingEndTime TIMESTAMP, @@ -439,7 +439,7 @@ CREATE TABLE FlowRRGLogs( CREATE TABLE FlowOutputPluginLogEntries( ClientId STRING(18) NOT NULL, FlowId STRING(16) NOT NULL, - OutputPluginId STRING(8) NOT NULL, + OutputPluginId STRING(16) NOT NULL, HuntId STRING(16), CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), Type `grr.FlowOutputPluginLogEntry.LogEntryType` NOT NULL, @@ -542,7 +542,7 @@ CREATE TABLE YaraSignatureReferences( ) PRIMARY KEY (BlobId); CREATE TABLE Hunts( - HuntId STRING(8) NOT NULL, + HuntId STRING(16) NOT NULL, CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), LastUpdateTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), Creator STRING(256) NOT NULL, @@ -579,8 +579,8 @@ CREATE INDEX HuntsByCreationTime CREATE INDEX HuntsByCreator ON Hunts(Creator); CREATE TABLE HuntOutputPlugins( - HuntId STRING(8) NOT NULL, - OutputPluginId STRING(8) NOT NULL, + HuntId STRING(16) NOT NULL, + OutputPluginId STRING(16) NOT NULL, Name STRING(256) NOT NULL, Args `google.protobuf.Any`, State `google.protobuf.Any` NOT NULL, @@ -588,7 +588,7 @@ CREATE TABLE HuntOutputPlugins( INTERLEAVE IN PARENT Hunts ON DELETE CASCADE; CREATE TABLE ForemanRules( - HuntId STRING(8) NOT NULL, + HuntId STRING(16) NOT NULL, ExpirationTime TIMESTAMP NOT NULL, Payload `grr.ForemanCondition`, diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index f21576831..adef1ec1d 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -4,6 +4,7 @@ import dataclasses import 
datetime import logging +from time import sleep from typing import Any, Callable, Collection, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union from google.api_core.exceptions import AlreadyExists, NotFound @@ -87,8 +88,8 @@ def _BuildReadFlowResultsErrorsConditions( WHERE t.ClientId = {client_id} AND t.FlowId = {flow_id} """ - params["client_id"] = spanner_clients.IntClientID(client_id) - params["flow_id"] = IntFlowID(flow_id) + params["client_id"] = client_id + params["flow_id"] = flow_id if with_tag is not None: query += " AND t.Tag = {tag} " @@ -132,8 +133,8 @@ def _BuildCountFlowResultsErrorsConditions( WHERE t.ClientId = {client_id} AND t.FlowId = {flow_id} """ - params["client_id"] = spanner_clients.IntClientID(client_id) - params["flow_id"] = IntFlowID(flow_id) + params["client_id"] = client_id + params["flow_id"] = flow_id if with_tag is not None: query += " AND t.Tag = {tag} " @@ -171,51 +172,51 @@ def _ParseReadFlowObjectRow( ) -> flows_pb2.Flow: """Parses a row fetched with _READ_FLOW_OBJECT_COLS.""" result = flows_pb2.Flow() - result.ParseFromString(row["Flow"]) + result.ParseFromString(row[13]) - creation_time = rdfvalue.RDFDatetime.FromDatetime(row["CreationTime"]) - update_time = rdfvalue.RDFDatetime.FromDatetime(row["UpdateTime"]) + creation_time = rdfvalue.RDFDatetime.FromDatetime(row[6]) + update_time = rdfvalue.RDFDatetime.FromDatetime(row[7]) # We treat column values as the source of truth for values, not the message # in the database itself. At least this is what the F1 implementation does. result.client_id = client_id result.flow_id = flow_id - result.long_flow_id = row["LongFlowId"] - - if row["ParentFlowId"] is not None: - result.parent_flow_id = db_utils.IntToFlowID(row["ParentFlowId"]) - if row["ParentHuntId"] is not None: - result.parent_hunt_id = db_utils.IntToHuntID(row["ParentHuntId"]) - - if row["Name"] is not None: - result.flow_class_name = row["Name"] - if row["Creator"] is not None: - result.creator = row["Creator"] - if row["State"] not in [None, flows_pb2.Flow.FlowState.UNSET]: - result.flow_state = row["State"] - if row["NextRequestToProcess"]: - result.next_request_to_process = row["NextRequestToProcess"] + result.long_flow_id = row[0] + + if row[1] is not None: + result.parent_flow_id = row[1] + if row[2] is not None: + result.parent_hunt_id = row[2] + + if row[4] is not None: + result.flow_class_name = row[4] + if row[3] is not None: + result.creator = row[3] + if row[5] not in [None, flows_pb2.Flow.FlowState.UNSET]: + result.flow_state = row[5] + if row[12]: + result.next_request_to_process = int(row[12]) result.create_time = int(creation_time) result.last_update_time = int(update_time) - if row["Crash"] is not None: + if row[8] is not None: client_crash = jobs_pb2.ClientCrash() - client_crash.ParseFromString(row["Crash"]) + client_crash.ParseFromString(row[8]) result.client_crash_info.CopyFrom(client_crash) result.ClearField("processing_on") - if row["ProcessingWorker"] is not None: - result.processing_on = row["ProcessingWorker"] + if row[9] is not None: + result.processing_on = row[9] result.ClearField("processing_since") - if row["ProcessingStartTime"] is not None: + if row[10] is not None: result.processing_since = int( - rdfvalue.RDFDatetime.FromDatetime(row["ProcessingStartTime"]) + rdfvalue.RDFDatetime.FromDatetime(row[10]) ) result.ClearField("processing_deadline") - if row["ProcessingEndTime"] is not None: + if row[11] is not None: result.processing_deadline = int( - 
rdfvalue.RDFDatetime.FromDatetime(row["ProcessingEndTime"]) + rdfvalue.RDFDatetime.FromDatetime(row[11]) ) return result @@ -368,10 +369,10 @@ def ReadAllFlowObjects( conds = [] if client_id is not None: - params["client_id"] = spanner_clients.IntClientID(client_id) + params["client_id"] = client_id conds.append("f.ClientId = {client_id}") if parent_flow_id is not None: - params["parent_flow_id"] = IntFlowID(parent_flow_id) + params["parent_flow_id"] = parent_flow_id conds.append("f.ParentFlowId = {parent_flow_id}") if min_create_time is not None: params["min_creation_time"] = min_create_time.AsDatetime() @@ -382,30 +383,30 @@ def ReadAllFlowObjects( if not include_child_flows: conds.append("f.ParentFlowId IS NULL") if not_created_by is not None: - params["not_created_by"] = spanner_lib.Array(str, not_created_by) + params["not_created_by"] = list(not_created_by) conds.append("f.Creator NOT IN UNNEST({not_created_by})") if conds: query += f" WHERE {' AND '.join(conds)}" for row in self.db.ParamQuery(query, params, txn_tag="ReadAllFlowObjects"): - int_client_id, int_flow_id, long_flow_id, *row = row - int_parent_flow_id, int_parent_hunt_id, *row = row + client_id, flow_id, long_flow_id, *row = row + parent_flow_id, parent_hunt_id, *row = row creator, name, state, *row = row creation_time, update_time, *row = row crash_bytes, next_request_to_process, flow_bytes = row flow = flows_pb2.Flow() flow.ParseFromString(flow_bytes) - flow.client_id = db_utils.IntToClientID(int_client_id) - flow.flow_id = db_utils.IntToFlowID(int_flow_id) + flow.client_id = client_id + flow.flow_id = flow_id flow.long_flow_id = long_flow_id - flow.next_request_to_process = next_request_to_process + flow.next_request_to_process = int(next_request_to_process) - if int_parent_flow_id is not None: - flow.parent_flow_id = db_utils.IntToFlowID(int_parent_flow_id) - if int_parent_hunt_id is not None: - flow.parent_hunt_id = db_utils.IntToHuntID(int_parent_hunt_id) + if parent_flow_id is not None: + flow.parent_flow_id = parent_flow_id + if parent_hunt_id is not None: + flow.parent_hunt_id = parent_hunt_id flow.creator = creator flow.flow_state = state @@ -449,8 +450,8 @@ def UpdateFlow( """Updates flow objects in the database.""" row = { - "ClientId": spanner_clients.IntClientID(client_id), - "FlowId": IntFlowID(flow_id), + "ClientId": client_id, + "FlowId": flow_id, "UpdateTime": spanner_lib.COMMIT_TIMESTAMP, } @@ -489,18 +490,20 @@ def WriteFlowResults(self, results: Sequence[flows_pb2.FlowResult]) -> None: """Writes flow results for a given flow.""" def Mutation(mut) -> None: + rows = [] + columns = ["ClientId", "FlowId", "HuntId", "CreationTime", + "Tag", "RdfType", "Payload"] for r in results: - row = { - "ClientId": spanner_clients.IntClientID(r.client_id), - "FlowId": IntFlowID(r.flow_id), - "HuntId": IntHuntID(r.hunt_id) if r.hunt_id else 0, - "CreationTime": rdfvalue.RDFDatetime.Now().AsDatetime(), - "Tag": r.tag, - "RdfType": db_utils.TypeURLToRDFTypeName(r.payload.type_url), - "Payload": r.payload, - } - - mut.Insert("FlowResults", row) + rows.append([ + r.client_id, + r.flow_id, + r.hunt_id if r.hunt_id else "0", + rdfvalue.RDFDatetime.Now().AsDatetime(), + r.tag, + db_utils.TypeURLToRDFTypeName(r.payload.type_url), + r.payload, + ]) + mut.insert(table="FlowResults", columns=columns, values=rows) self.db.Mutate(Mutation, txn_tag="WriteFlowResults") @@ -510,17 +513,19 @@ def WriteFlowErrors(self, errors: Sequence[flows_pb2.FlowError]) -> None: """Writes flow errors for a given flow.""" def Mutation(mut) -> None: + 
rows = [] + columns = ["ClientId", "FlowId", "HuntId", + "CreationTime", "Payload", "RdfType", "Tag"] for r in errors: - row = { - "ClientId": spanner_clients.IntClientID(r.client_id), - "FlowId": IntFlowID(r.flow_id), - "HuntId": IntHuntID(r.hunt_id) if r.hunt_id else 0, - "CreationTime": rdfvalue.RDFDatetime.Now().AsDatetime(), - "Payload": r.payload, - "RdfType": db_utils.TypeURLToRDFTypeName(r.payload.type_url), - "Tag": r.tag, - } - mut.Insert("FlowErrors", row) + rows.append([r.client_id, + r.flow_id, + r.hunt_id if r.hunt_id else "0", + rdfvalue.RDFDatetime.Now().AsDatetime(), + r.payload, + db_utils.TypeURLToRDFTypeName(r.payload.type_url), + r.tag, + ]) + mut.insert(table="FlowErrors", columns=columns, values=rows) self.db.Mutate(Mutation, txn_tag="WriteFlowErrors") @@ -564,7 +569,7 @@ def ReadFlowResults( result.payload.ParseFromString(payload_bytes) if hunt_id is not None: - result.hunt_id = db_utils.IntToHuntID(hunt_id) + result.hunt_id = hunt_id if tag is not None: result.tag = tag @@ -626,7 +631,7 @@ def ReadFlowErrors( error.payload.Pack(unrecognized) if hunt_id is not None: - error.hunt_id = db_utils.IntToHuntID(hunt_id) + error.hunt_id = hunt_id if tag is not None: error.tag = tag @@ -688,8 +693,8 @@ def CountFlowResultsByType( """ params = { - "client_id": spanner_clients.IntClientID(client_id), - "flow_id": IntFlowID(flow_id), + "client_id": client_id, + "flow_id": flow_id, } result = {} @@ -715,8 +720,8 @@ def CountFlowErrorsByType( """ params = { - "client_id": spanner_clients.IntClientID(client_id), - "flow_id": IntFlowID(flow_id), + "client_id": client_id, + "flow_id": flow_id, } result = {} @@ -729,31 +734,15 @@ def CountFlowErrorsByType( def _BuildFlowProcessingRequestWrites( self, - mut: spanner_utils.Mutation, requests: Iterable[flows_pb2.FlowProcessingRequest], ) -> None: - """Builds db writes for a list of FlowProcessingRequests.""" + """Writes a list of FlowProcessingRequests to the queue.""" + flowProcessingRequests = [] + for request in requests: + flowProcessingRequests.append(request.SerializeToString()) - for r in requests: - key = ( - r.client_id, - r.flow_id, - spanner_lib.COMMIT_TIMESTAMP, - ) + self.db.PublishFlowProcessingRequests(flowProcessingRequests) - ts = None - if r.delivery_time: - ts = rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( - r.delivery_time - ).AsDatetimeUTC() - - mut.Send( - queue="FlowProcessingRequestsQueue", - key=key, - value=r, - column="Payload", - deliver_time=ts, - ) @db_utils.CallLogged @db_utils.CallAccounted @@ -763,11 +752,7 @@ def WriteFlowProcessingRequests( ) -> None: """Writes a list of flow processing requests to the database.""" - ### TODO ### - #def Mutation(mut) -> None: - # self._BuildFlowProcessingRequestWrites(mut, requests) - # - #self.db.Mutate(Mutation, txn_tag="WriteFlowProcessingRequests") + self._BuildFlowProcessingRequestWrites(requests) @db_utils.CallLogged @db_utils.CallAccounted @@ -775,17 +760,14 @@ def ReadFlowProcessingRequests( self, ) -> Sequence[flows_pb2.FlowProcessingRequest]: """Reads all flow processing requests from the database.""" - query = """ - SELECT t.Payload, t.CreationTime FROM FlowProcessingRequestsQueue AS t - """ - results = [] - for payload, creation_time in self.db.ParamQuery( - query, {}, txn_tag="ReadFlowProcessingRequests" - ): + for result in self.db.ReadFlowProcessingRequests(): req = flows_pb2.FlowProcessingRequest() - req.ParseFromString(payload) - req.creation_time = int(rdfvalue.RDFDatetime.FromDatetime(creation_time)) + req.ParseFromString(result["payload"]) + 
req.creation_time = int( + rdfvalue.RDFDatetime.FromDatetime(result["publish_time"]) + ) + req.ack_id = result["ack_id"] results.append(req) return results @@ -800,8 +782,8 @@ def AckFlowProcessingRequests( def Mutation(mut) -> None: for r in requests: key = ( - spanner_clients.IntClientID(r.client_id), - IntFlowID(r.flow_id), + r.client_id, + r.flow_id, rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( r.creation_time ).AsDatetime(), @@ -825,25 +807,27 @@ def RegisterFlowProcessingHandler( """Registers a handler to receive flow processing messages.""" self.UnregisterFlowProcessingHandler() - def Callback(expanded_key: Sequence[Any], payload: bytes): + def Callback(payload: bytes, msg_id: str, ack_id: str, publish_time): try: req = flows_pb2.FlowProcessingRequest() req.ParseFromString(payload) + req.ack_id = ack_id req.creation_time = int( - rdfvalue.RDFDatetime.FromDatetime(expanded_key[2]) + rdfvalue.RDFDatetime.FromDatetime(publish_time) ) handler(req) except Exception as e: # pylint: disable=broad-except - logging.exception("Exception raised during Flow processing: %s", e) + logging.exception( + "Exception raised during MessageHandlerRequest processing: %s", e + ) - receiver = self.db.NewQueueReceiver( - "FlowProcessingRequestsQueue", + receiver = self.db.NewRequestQueue( + "FlowProcessing", Callback, receiver_max_keepalive_seconds=3000, receiver_max_active_callbacks=50, receiver_max_messages_per_callback=1, ) - receiver.Receive() self._flow_processing_request_receiver = receiver def UnregisterFlowProcessingHandler( @@ -900,7 +884,7 @@ def Txn(txn) -> None: # .keys() should be omitted. Omitting keys leads to a # mistaken warning that .items() was not called. for client_id, flow_id in needs_processing: # pylint: disable=dict-iter-missing-items - keys.Add((client_id, flow_id)) + keys.append([client_id, flow_id]) columns = ( "ClientId", @@ -921,11 +905,8 @@ def Txn(txn) -> None: req.delivery_time = r.start_time flow_processing_requests.append(req) - #if flow_processing_requests: - # with txn.BufferedMutate() as mut: - # self._BuildFlowProcessingRequestWrites( - # mut, flow_processing_requests - # ) + if flow_processing_requests: + self._BuildFlowProcessingRequestWrites(flow_processing_requests) try: self.db.Transact(Txn, txn_tag="WriteFlowRequests") @@ -1074,16 +1055,15 @@ def _BuildExpectedUpdates( updates: dict mapping requests to the number of expected responses. txn: transaction to use for the writes. """ - - with txn.Mutate() as mut: - for r_key, num_responses_expected in updates.items(): - row = { - "ClientId": spanner_clients.IntClientID(r_key.client_id), - "FlowId": IntFlowID(r_key.flow_id), - "RequestId": r_key.request_id, - "ExpectedResponseCount": num_responses_expected, - } - mut.Update("FlowRequests", row) + rows = [] + columns = ["ClientId", "FlowId", "RequestId", "ExpectedResponseCount"] + for r_key, num_responses_expected in updates.items(): + rows.append([r_key.client_id, + r_key.flow_id, + r_key.request_id, + num_responses_expected, + ]) + txn.update(table="FlowRequests", columns=columns, values=rows) def _WriteFlowResponsesAndExpectedUpdates( self, @@ -1202,8 +1182,7 @@ def Txn(txn) -> tuple[dict[_RequestKey, int], dict[_RequestKey, str]]: def _GetFlowResponsesPerRequestCounts( self, - request_keys: Iterable[_RequestKey], - txn, + request_keys: Iterable[_RequestKey] ) -> dict[_RequestKey, int]: """Gets counts of already received responses for given requests. 
@@ -1231,9 +1210,9 @@ def _GetFlowResponsesPerRequestCounts( fr.RequestId = {{request_id_{i}}}) """) - params[f"client_id_{i}"] = db_utils.ClientIDToInt(req_key.client_id) - params[f"flow_id_{i}"] = db_utils.FlowIDToInt(req_key.flow_id) - params[f"request_id_{i}"] = req_key.request_id + params[f"client_id_{i}"] = req_key.client_id + params[f"flow_id_{i}"] = req_key.flow_id + params[f"request_id_{i}"] = str(req_key.request_id) query = f""" SELECT fr.ClientId, fr.FlowId, fr.RequestId, COUNT(*) AS ResponseCount @@ -1243,12 +1222,12 @@ def _GetFlowResponsesPerRequestCounts( """ result = {} - for row in txn.ParamQuery(query, params): - client_id_int, flow_id_int, request_id, count = row + for row in self.db.ParamQuery(query, params): + client_id, flow_id, request_id, count = row req_key = _RequestKey( - db_utils.IntToClientID(client_id_int), - db_utils.IntToFlowID(flow_id_int), + client_id, + flow_id, request_id, ) result[req_key] = count @@ -1282,66 +1261,62 @@ def _ReadFlowRequestsNotYetMarkedForProcessing( have to be notified of incoming responses. start_time in the tuple corresponds to the intended notification delivery time. """ - flow_rows = spanner_lib.RowSet() - req_rows = spanner_lib.RowSet() + flow_keys = [] + req_keys = [] unique_flow_keys = set() for req_key in set(requests) | set(callback_states): - client_id_int = spanner_clients.IntClientID(req_key.client_id) - flow_id_int = IntFlowID(req_key.flow_id) + req_keys.append([req_key.client_id, req_key.flow_id, req_key.request_id]) + unique_flow_keys.add((req_key.client_id, req_key.flow_id)) - req_rows.AddPrefixRange( - spanner_lib.Key(client_id_int, flow_id_int, req_key.request_id) - ) - unique_flow_keys.add((client_id_int, flow_id_int)) - - for client_id_int, flow_id_int in unique_flow_keys: - flow_rows.AddPrefixRange(spanner_lib.Key(client_id_int, flow_id_int)) + for client_id, flow_id in unique_flow_keys: + flow_keys.append([client_id, flow_id]) next_request_to_process_by_flow = {} - flow_cols = ( + flow_cols = [ "ClientId", "FlowId", "NextRequestToProcess", - ) - for row in txn.ReadSet("Flows", flow_rows, flow_cols): - client_id_int: int = row["ClientId"] - flow_id_int: int = row["FlowId"] - next_request_id: int = row["NextRequestToProcess"] - next_request_to_process_by_flow[(client_id_int, flow_id_int)] = ( + ] + for row in txn.read(table="Flows", + keyset=spanner_lib.KeySet(keys=flow_keys), + columns=flow_cols): + client_id: int = row[0] + flow_id: int = row[1] + next_request_id: int = row[2] + next_request_to_process_by_flow[(client_id, flow_id)] = ( next_request_id ) requests_to_mark = set() requests_to_notify = set() - req_cols = ( + req_cols = [ "ClientId", "FlowId", "RequestId", "NeedsProcessing", "StartTime", - ) - for row in txn.ReadSet("FlowRequests", req_rows, req_cols): - client_id_int: int = row["ClientId"] - flow_id_int: int = row["FlowId"] - request_id: int = row["RequestId"] - np: bool = row["NeedsProcessing"] + ] + for row in txn.read(table="FlowRequests", + keyset=spanner_lib.KeySet(keys=req_keys), + columns=req_cols): + client_id: str = row[0] + flow_id: str = row[1] + request_id: str = row[2] + np: bool = row[3] start_time: Optional[rdfvalue.RDFDatetime] = None - if row["StartTime"] is not None: - start_time = rdfvalue.RDFDatetime.FromDatetime(row["StartTime"]) + if row[4] is not None: + start_time = rdfvalue.RDFDatetime.FromDatetime(row[4]) if not np: - client_id = db_utils.IntToClientID(client_id_int) - flow_id = db_utils.IntToFlowID(flow_id_int) req_key = _RequestKey(client_id, flow_id, request_id) if 
req_key in requests: requests_to_mark.add(req_key) if ( - next_request_to_process_by_flow[(client_id_int, flow_id_int)] - == request_id + next_request_to_process_by_flow[(client_id, flow_id)] == request_id ): requests_to_notify.add((_FlowKey(client_id, flow_id), start_time)) @@ -1356,16 +1331,15 @@ def _BuildNeedsProcessingUpdates( requests: keys of requests to be updated. txn: transaction to use. """ - - with txn.Mutate() as mut: - for req_key in requests: - row = { - "ClientId": spanner_clients.IntClientID(req_key.client_id), - "FlowId": IntFlowID(req_key.flow_id), - "RequestId": req_key.request_id, - "NeedsProcessing": True, - } - mut.Update("FlowRequests", row) + rows = [] + columns = ["ClientId", "FlowId", "RequestId", "NeedsProcessing"] + for req_key in requests: + rows.append([req_key.client_id, + req_key.flow_id, + req_key.request_id, + True, + ]) + txn.update(table="FlowRequests", columns=columns, values=rows) def _UpdateNeedsProcessingAndWriteFlowProcessingRequests( self, @@ -1407,8 +1381,7 @@ def _UpdateNeedsProcessingAndWriteFlowProcessingRequests( fpr.delivery_time = int(start_time) flow_processing_requests.append(fpr) - with txn.BufferedMutate() as mut: - self._BuildFlowProcessingRequestWrites(mut, flow_processing_requests) + self._BuildFlowProcessingRequestWrites(flow_processing_requests) @db_utils.CallLogged @db_utils.CallAccounted @@ -1438,9 +1411,8 @@ def WriteFlowResponses( return # Get actual per-request responses counts using a separate transaction. - read_txn = self.db.Snapshot() counts = self._GetFlowResponsesPerRequestCounts( - responses_expected_by_request, read_txn + responses_expected_by_request ) requests_ready_for_processing = set() @@ -1467,11 +1439,9 @@ def DeleteAllFlowRequestsAndResponses( flow_id: str, ) -> None: """Deletes all requests and responses for a given flow from the database.""" - int_client_id = spanner_clients.IntClientID(client_id) - int_flow_id = IntFlowID(flow_id) self.db.DeleteWithPrefix( "FlowRequests", - (int_client_id, int_flow_id), + (client_id, flow_id), txn_tag="DeleteAllFlowRequestsAndResponses", ) @@ -1495,64 +1465,51 @@ def ReadAllFlowRequestsAndResponses( ] ]: """Reads all requests and responses for a given flow from the database.""" - - txn = self.db.Snapshot() - - req_rows = spanner_lib.RowSet() - req_rows.AddPrefixRange( - spanner_lib.Key( - spanner_clients.IntClientID(client_id), IntFlowID(flow_id) - ) - ) - req_cols = ( + rowrange = spanner_lib.KeyRange(start_closed=[client_id, flow_id], end_closed=[client_id, flow_id]) + rows = spanner_lib.KeySet(ranges=[rowrange]) + req_cols = [ "Payload", "NeedsProcessing", "ExpectedResponseCount", "CallbackState", "NextResponseId", "CreationTime", - ) + ] requests = [] - for row in txn.ReadSet(table="FlowRequests", rows=req_rows, cols=req_cols): + for row in self.db.ReadSet(table="FlowRequests", rows=rows, cols=req_cols): request = flows_pb2.FlowRequest() - request.ParseFromString(row["Payload"]) - request.needs_processing = row["NeedsProcessing"] - if row["ExpectedResponseCount"] is not None: - request.nr_responses_expected = row["ExpectedResponseCount"] - request.callback_state = row["CallbackState"] - request.next_response_id = row["NextResponseId"] + request.ParseFromString(row[0]) + request.needs_processing = row[1] + if row[2] is not None: + request.nr_responses_expected = row[2] + request.callback_state = row[3] + request.next_response_id = int(row[4]) request.timestamp = int( - rdfvalue.RDFDatetime.FromDatetime(row["CreationTime"]) + rdfvalue.RDFDatetime.FromDatetime(row[5]) ) 
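+      # Positional indices above follow req_cols: 0=Payload,
+      # 1=NeedsProcessing, 2=ExpectedResponseCount, 3=CallbackState,
+      # 4=NextResponseId, 5=CreationTime.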
requests.append(request) - resp_rows = spanner_lib.RowSet() - resp_rows.AddPrefixRange( - spanner_lib.Key( - spanner_clients.IntClientID(client_id), IntFlowID(flow_id) - ) - ) - resp_cols = ( + resp_cols = [ "Response", "Status", "Iterator", "CreationTime", - ) + ] responses = {} - for row in txn.ReadSet( - table="FlowResponses", rows=resp_rows, cols=resp_cols + for row in self.db.ReadSet( + table="FlowResponses", rows=rows, cols=resp_cols ): - if row["Status"] is not None: + if row[1] is not None: response = flows_pb2.FlowStatus() - response.ParseFromString(row["Status"]) - elif row["Iterator"] is not None: + response.ParseFromString(row[1]) + elif row[2] is not None: response = flows_pb2.FlowIterator() - response.ParseFromString(row["Iterator"]) + response.ParseFromString(row[2]) else: response = flows_pb2.FlowResponse() - response.ParseFromString(row["Response"]) + response.ParseFromString(row[0]) response.timestamp = int( - rdfvalue.RDFDatetime.FromDatetime(row["CreationTime"]) + rdfvalue.RDFDatetime.FromDatetime(row[3]) ) responses.setdefault(response.request_id, {})[ response.response_id @@ -1576,15 +1533,14 @@ def DeleteFlowRequests( def Mutation(mut) -> None: for request in requests: key = [ - spanner_clients.IntClientID(request.client_id), - IntFlowID(request.flow_id), + request.client_id, + request.flow_id, request.request_id, ] mut.Delete(table="FlowRequests", key=key) try: self.db.Mutate(Mutation, txn_tag="DeleteFlowRequests") - # TODO(b/196379916): Narrow the exception types (cl/450440276). except Exception: if len(requests) == 1: # If there is only one request and we still hit Spanner limits it means @@ -1636,14 +1592,8 @@ def ReadFlowRequests( ], ]: """Reads all requests for a flow that can be processed by the worker.""" - - txn = self.db.Snapshot() - rows = spanner_lib.RowSet() - rows.AddPrefixRange( - spanner_lib.Key( - spanner_clients.IntClientID(client_id), IntFlowID(flow_id) - ) - ) + rowrange = spanner_lib.KeyRange(start_closed=[client_id, flow_id], end_closed=[client_id, flow_id]) + rows = spanner_lib.KeySet(ranges=[rowrange]) responses: dict[ int, @@ -1655,24 +1605,24 @@ def ReadFlowRequests( ] ], ] = {} - resp_cols = ( + resp_cols = [ "Response", "Status", "Iterator", "CreationTime", - ) - for row in txn.ReadSet(table="FlowResponses", rows=rows, cols=resp_cols): - if row["Status"]: + ] + for row in self.db.ReadSet(table="FlowResponses", rows=rows, cols=resp_cols): + if row[1]: response = flows_pb2.FlowStatus() - response.ParseFromString(row["Status"]) - elif row["Iterator"]: + response.ParseFromString(row[1]) + elif row[2]: response = flows_pb2.FlowIterator() - response.ParseFromString(row["Iterator"]) + response.ParseFromString(row[2]) else: response = flows_pb2.FlowResponse() - response.ParseFromString(row["Response"]) + response.ParseFromString(row[0]) response.timestamp = int( - rdfvalue.RDFDatetime.FromDatetime(row["CreationTime"]) + rdfvalue.RDFDatetime.FromDatetime(row[3]) ) responses.setdefault(response.request_id, []).append(response) @@ -1689,24 +1639,24 @@ def ReadFlowRequests( ], ], ] = {} - req_cols = ( + req_cols = [ "Payload", "NeedsProcessing", "ExpectedResponseCount", "NextResponseId", "CallbackState", "CreationTime", - ) - for row in txn.ReadSet(table="FlowRequests", rows=rows, cols=req_cols): + ] + for row in self.db.ReadSet(table="FlowRequests", rows=rows, cols=req_cols): request = flows_pb2.FlowRequest() - request.ParseFromString(row["Payload"]) - request.needs_processing = row["NeedsProcessing"] - if row["ExpectedResponseCount"] is not None: - 
request.nr_responses_expected = row["ExpectedResponseCount"] - request.callback_state = row["CallbackState"] - request.next_response_id = row["NextResponseId"] + request.ParseFromString(row[0]) + request.needs_processing = row[1] + if row[2] is not None: + request.nr_responses_expected = row[2] + request.callback_state = row[4] + request.next_response_id = int(row[3]) request.timestamp = int( - rdfvalue.RDFDatetime.FromDatetime(row["CreationTime"]) + rdfvalue.RDFDatetime.FromDatetime(row[5]) ) requests[request.request_id] = ( request, @@ -1726,17 +1676,14 @@ def UpdateIncrementalFlowRequests( ) -> None: """Updates next response ids of given requests.""" - int_client_id = spanner_clients.IntClientID(client_id) - int_flow_id = IntFlowID(flow_id) - with self.db.MutationPool() as mp: for request_id, response_id in next_response_id_updates.items(): with mp.Apply() as mut: mut.Update( table="FlowRequests", row={ - "ClientId": int_client_id, - "FlowId": int_flow_id, + "ClientId": client_id, + "FlowId": flow_id, "RequestId": request_id, "NextResponseId": response_id, }, @@ -1747,14 +1694,14 @@ def UpdateIncrementalFlowRequests( def WriteFlowLogEntry(self, entry: flows_pb2.FlowLogEntry) -> None: """Writes a single flow log entry to the database.""" row = { - "ClientId": spanner_clients.IntClientID(entry.client_id), - "FlowId": IntFlowID(entry.flow_id), + "ClientId": entry.client_id, + "FlowId": entry.flow_id, "CreationTime": spanner_lib.COMMIT_TIMESTAMP, "Message": entry.message, } if entry.hunt_id: - row["HuntId"] = IntHuntID(entry.hunt_id) + row["HuntId"] = entry.hunt_id try: self.db.Insert(table="FlowLogEntries", row=row) @@ -1781,8 +1728,8 @@ def ReadFlowLogEntries( AND l.FlowId = {flow_id} """ params = { - "client_id": spanner_clients.IntClientID(client_id), - "flow_id": IntFlowID(flow_id), + "client_id": client_id, + "flow_id": flow_id, } if with_substring is not None: @@ -1797,14 +1744,14 @@ def ReadFlowLogEntries( params["count"] = count for row in self.db.ParamQuery(query, params): - int_hunt_id, creation_time, message = row + hunt_id, creation_time, message = row result = flows_pb2.FlowLogEntry() result.client_id = client_id result.flow_id = flow_id - if int_hunt_id is not None: - result.hunt_id = db_utils.IntToHuntID(int_hunt_id) + if hunt_id is not None: + result.hunt_id = hunt_id result.timestamp = rdfvalue.RDFDatetime.FromDatetime( creation_time @@ -1826,8 +1773,8 @@ def CountFlowLogEntries(self, client_id: str, flow_id: str) -> int: AND l.FlowId = {flow_id} """ params = { - "client_id": spanner_clients.IntClientID(client_id), - "flow_id": IntFlowID(flow_id), + "client_id": client_id, + "flow_id": flow_id, } (count,) = self.db.ParamQuerySingle(query, params) @@ -1848,18 +1795,20 @@ def WriteFlowRRGLogs( return def Mutation(mut) -> None: + rows = [] + columns = ["ClientId", "FlowId", "RequestId", "ResponseId", + "LogLevel", "LogTime", "LogMessage", "CreationTime"] for response_id, log in logs.items(): - row = { - "ClientId": db_utils.ClientIDToInt(client_id), - "FlowId": db_utils.FlowIDToInt(flow_id), - "RequestId": request_id, - "ResponseId": response_id, - "LogLevel": log.level, - "LogTime": log.timestamp.ToDatetime(), - "LogMessage": log.message, - "CreationTime": spanner_lib.COMMIT_TIMESTAMP, - } - mut.Insert(table="FlowRRGLogs", row=row) + rows.append([client_id, + flow_id, + request_id, + response_id, + log.level, + log.timestamp.ToDatetime(), + log.message, + spanner_lib.COMMIT_TIMESTAMP + ]) + mut.insert(table="FlowRRGLogs", columns=columns, values=rows) try: 
self.db.Mutate(Mutation) @@ -1891,8 +1840,8 @@ def ReadFlowRRGLogs( {offset} """ params = { - "client_id": db_utils.ClientIDToInt(client_id), - "flow_id": db_utils.FlowIDToInt(flow_id), + "client_id": client_id, + "flow_id": flow_id, "offset": offset, "count": count, } @@ -1923,16 +1872,16 @@ def WriteFlowOutputPluginLogEntry( entry: An output plugin flow entry to write. """ row = { - "ClientId": spanner_clients.IntClientID(entry.client_id), - "FlowId": IntFlowID(entry.flow_id), - "OutputPluginId": IntOutputPluginID(entry.output_plugin_id), + "ClientId": entry.client_id, + "FlowId": entry.flow_id, + "OutputPluginId": entry.output_plugin_id, "CreationTime": spanner_lib.COMMIT_TIMESTAMP, "Type": int(entry.log_entry_type), "Message": entry.message, } if entry.hunt_id: - row["HuntId"] = IntHuntID(entry.hunt_id) + row["HuntId"] = entry.hunt_id try: self.db.Insert(table="FlowOutputPluginLogEntries", row=row) @@ -1965,9 +1914,9 @@ def ReadFlowOutputPluginLogEntries( AND l.OutputPluginId = {output_plugin_id} """ params = { - "client_id": spanner_clients.IntClientID(client_id), - "flow_id": IntFlowID(flow_id), - "output_plugin_id": IntOutputPluginID(output_plugin_id), + "client_id": client_id, + "flow_id": flow_id, + "output_plugin_id": output_plugin_id, } if with_type is not None: @@ -1982,15 +1931,15 @@ def ReadFlowOutputPluginLogEntries( params["count"] = count for row in self.db.ParamQuery(query, params): - int_hunt_id, creation_time, int_type, message = row + hunt_id, creation_time, int_type, message = row result = flows_pb2.FlowOutputPluginLogEntry() result.client_id = client_id result.flow_id = flow_id result.output_plugin_id = output_plugin_id - if int_hunt_id is not None: - result.hunt_id = db_utils.IntToHuntID(int_hunt_id) + if hunt_id is not None: + result.hunt_id = hunt_id result.timestamp = rdfvalue.RDFDatetime.FromDatetime( creation_time @@ -2022,9 +1971,9 @@ def CountFlowOutputPluginLogEntries( AND l.OutputPluginId = {output_plugin_id} """ params = { - "client_id": spanner_clients.IntClientID(client_id), - "flow_id": IntFlowID(flow_id), - "output_plugin_id": IntOutputPluginID(output_plugin_id), + "client_id": client_id, + "flow_id": flow_id, + "output_plugin_id": output_plugin_id, } if with_type is not None: @@ -2042,9 +1991,9 @@ def WriteScheduledFlow( ) -> None: """Inserts or updates the ScheduledFlow in the database.""" row = { - "ClientID": spanner_clients.IntClientID(scheduled_flow.client_id), + "ClientId": scheduled_flow.client_id, "Creator": scheduled_flow.creator, - "ScheduledFlowId": IntFlowID(scheduled_flow.scheduled_flow_id), + "ScheduledFlowId": scheduled_flow.scheduled_flow_id, "FlowName": scheduled_flow.flow_name, "FlowArgs": scheduled_flow.flow_args, "RunnerArgs": scheduled_flow.runner_args, @@ -2056,8 +2005,8 @@ def WriteScheduledFlow( try: self.db.InsertOrUpdate(table="ScheduledFlows", row=row) - except NotFound as error: - if "Parent row is missing: Clients" in str(error): + except Exception as error: + if "Parent row for row [" in str(error): raise db.UnknownClientError(scheduled_flow.client_id) from error elif "fk_creator_users_username" in str(error): raise db.UnknownGRRUserError(scheduled_flow.creator) from error @@ -2073,16 +2022,11 @@ def DeleteScheduledFlow( scheduled_flow_id: str, ) -> None: """Deletes the ScheduledFlow from the database.""" - - key = ( - spanner_clients.IntClientID(client_id), - creator, - IntFlowID(scheduled_flow_id), - ) + keyset = spanner_lib.KeySet(keys=[[client_id, creator, scheduled_flow_id]]) def Transaction(txn) -> None: try: - 
txn.Read(table="ScheduledFlows", cols=["ScheduledFlowId"], key=key) + txn.read(table="ScheduledFlows", columns=["ScheduledFlowId"], keyset=keyset).one() except NotFound as e: raise db.UnknownScheduledFlowError( client_id=client_id, @@ -2090,8 +2034,7 @@ def Transaction(txn) -> None: scheduled_flow_id=scheduled_flow_id, ) from e - with txn.Mutate() as mut: - mut.Delete(table="ScheduledFlows", key=key) + txn.delete(table="ScheduledFlows", keyset=keyset) self.db.Transact(Transaction) @@ -2103,12 +2046,10 @@ def ListScheduledFlows( creator: str, ) -> Sequence[flows_pb2.ScheduledFlow]: """Lists all ScheduledFlows for the client and creator.""" - rows = spanner_lib.RowSet() - rows.AddPrefixRange( - spanner_lib.Key(spanner_clients.IntClientID(client_id), creator) - ) + range = spanner_lib.KeyRange(start_closed=[client_id, creator], end_closed=[client_id, creator]) + rows = spanner_lib.KeySet(ranges=[range]) - cols = ( + cols = [ "ClientId", "Creator", "ScheduledFlowId", @@ -2117,21 +2058,21 @@ def ListScheduledFlows( "RunnerArgs", "CreationTime", "Error", - ) + ] results = [] for row in self.db.ReadSet("ScheduledFlows", rows, cols): sf = flows_pb2.ScheduledFlow() - sf.client_id = db_utils.IntToClientID(row["ClientId"]) - sf.creator = row["Creator"] - sf.scheduled_flow_id = db_utils.IntToFlowID(row["ScheduledFlowId"]) - sf.flow_name = row["FlowName"] - sf.flow_args.ParseFromString(row["FlowArgs"]) - sf.runner_args.ParseFromString(row["RunnerArgs"]) + sf.client_id = row[0] + sf.creator = row[1] + sf.scheduled_flow_id = row[2] + sf.flow_name = row[3] + sf.flow_args.ParseFromString(row[4]) + sf.runner_args.ParseFromString(row[5]) sf.create_time = int( - rdfvalue.RDFDatetime.FromDatetime(row["CreationTime"]) + rdfvalue.RDFDatetime.FromDatetime(row[6]) ) - sf.error = row["Error"] + sf.error = row[7] results.append(sf) @@ -2161,9 +2102,9 @@ def ReadMessageHandlerRequests( for result in self.db.ReadMessageHandlerRequests(): req = objects_pb2.MessageHandlerRequest() req.ParseFromString(result["payload"]) - req.timestamp = rdfvalue.RDFDatetime.FromDatetime( - result["publish_time"] - ).AsMicrosecondsSinceEpoch() + req.creation_time = int( + rdfvalue.RDFDatetime.FromDatetime(result["publish_time"]) + ) req.ack_id = result["ack_id"] results.append(req) @@ -2295,8 +2236,8 @@ def _ReadHuntState( self, txn, hunt_id: str ) -> Optional[int]: try: - row = txn.Read(table="Hunts", key=(IntHuntID(hunt_id),), cols=("State",)) - return row["State"] + row = txn.read(table="Hunts", keyset=spanner_lib.KeySet[[hunt_id]], columns=["State",]).one() + return row[0] except NotFound: return None @@ -2309,16 +2250,14 @@ def LeaseFlowForProcessing( processing_time: rdfvalue.Duration, ) -> flows_pb2.Flow: """Marks a flow as being processed on this worker and returns it.""" - int_client_id = spanner_clients.IntClientID(client_id) - int_flow_id = IntFlowID(flow_id) def Txn(txn) -> flows_pb2.Flow: try: - row = txn.Read( + row = txn.read( table="Flows", - key=(int_client_id, int_flow_id), - cols=_READ_FLOW_OBJECT_COLS, - ) + keyset=spanner_lib.KeySet[[client_id, flow_id]], + columns=_READ_FLOW_OBJECT_COLS, + ).one() except NotFound as error: raise db.UnknownFlowError(client_id, flow_id, cause=error) @@ -2357,8 +2296,8 @@ def Txn(txn) -> flows_pb2.Flow: txn.Update( table="Flows", row={ - "ClientId": int_client_id, - "FlowId": int_flow_id, + "ClientId": client_id, + "FlowId": flow_id, "ProcessingWorker": flow.processing_on, "ProcessingEndTime": ( rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( @@ -2383,18 +2322,16 @@ def Txn(txn) -> 
flows_pb2.Flow: @db_utils.CallAccounted def ReleaseProcessedFlow(self, flow_obj: flows_pb2.Flow) -> bool: """Releases a flow that the worker was processing to the database.""" - int_client_id = spanner_clients.IntClientID(flow_obj.client_id) - int_flow_id = IntFlowID(flow_obj.flow_id) def Txn(txn) -> bool: try: - row = txn.Read( + row = txn.read( table="FlowRequests", - key=(int_client_id, int_flow_id, flow_obj.next_request_to_process), - cols=("NeedsProcessing", "StartTime"), - ) - if row["NeedsProcessing"]: - start_time = row["StartTime"] + keyset=spanner_lib.KeySet[[flow_obj.client_id, flow_obj.flow_id, flow_obj.next_request_to_process]], + columns=["NeedsProcessing", "StartTime"], + ).one() + if row[0]: + start_time = row[1] if start_time is None: return False elif ( @@ -2404,67 +2341,22 @@ def Txn(txn) -> bool: return False except NotFound: pass - txn.Update( + txn.update( table="Flows", - row={ - "ClientId": int_client_id, - "FlowId": int_flow_id, - "Flow": flow_obj, - "State": int(flow_obj.flow_state), - "UserCpuTimeUsed": float(flow_obj.cpu_time_used.user_cpu_time), - "SystemCpuTimeUsed": float( - flow_obj.cpu_time_used.system_cpu_time - ), - "NetworkBytesSent": int( - flow_obj.network_bytes_sent - ), - "ProcessingWorker": None, - "ProcessingStartTime": None, - "ProcessingEndTime": None, - "NextRequesttoProcess": int( - flow_obj.next_request_to_process - ), - "UpdateTime": spanner_lib.COMMIT_TIMESTAMP, - "ReplyCount": flow_obj.num_replies_sent, - }, + columns=["ClientId", "FlowId", "Flow", "State", "UserCpuTimeUsed", + "SystemCpuTimeUsed", "NetworkBytesSent", "ProcessingWorker", + "ProcessingStartTime", "ProcessingEndTime", "NextRequesttoProcess", + "UpdateTime", "ReplyCount"], + values=[[flow_obj.client_id, flow_obj.flow_id,flow_obj, + int(flow_obj.flow_state), float(flow_obj.cpu_time_used.user_cpu_time), + float(flow_obj.cpu_time_used.system_cpu_time), + int(flow_obj.network_bytes_sent), None, None, None, + flow_obj.next_request_to_process, + spanner_lib.COMMIT_TIMESTAMP, + flow_obj.num_replies_sent, + ]], ) - return True return self.db.Transact(Txn).value - -def IntFlowID(flow_id: str) -> int: - """Converts a flow identifier to its integer representation. - - Args: - flow_id: A flow identifier to convert. - - Returns: - An integer representation of the given flow identifier. - """ - return db_utils.FlowIDToInt(flow_id) - - -def IntHuntID(hunt_id: str) -> int: - """Converts a hunt identifier to its integer representation. - - Args: - hunt_id: A hunt identifier to convert. - - Returns: - An integer representation of the given hunt identifier. - """ - return db_utils.HuntIDToInt(hunt_id) - - -def IntOutputPluginID(output_plugin_id: str) -> int: - """Converts an output plugin identifier to its integer representation. - - Args: - output_plugin_id: An output plugin identifier to convert. - - Returns: - An integer representation of the given output plugin identifier. 
- """ - return db_utils.OutputPluginIDToInt(output_plugin_id) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 547a28c11..de399915a 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -100,7 +100,7 @@ def __init__(self, pyspanner: spanner_lib.database, project_id: str, self.project_id = project_id self.publisher = pubsub_v1.PublisherClient() self.subscriber = pubsub_v1.SubscriberClient() - self.flow_proccessing_sub_path = self.subscriber.subscription_path(project_id, flow_processing_sub_id) + self.flow_processing_sub_path = self.subscriber.subscription_path(project_id, flow_processing_sub_id) self.flow_processing_top_path = self.publisher.topic_path(project_id, flow_processing_top_id) self.message_handler_sub_path = self.subscriber.subscription_path(project_id, msg_handler_sub_id) self.message_handler_top_path = self.publisher.topic_path(project_id, msg_handler_top_id) From de568f5f77f3cd9f8c281d21f570f07a2a7b00ff Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Tue, 10 Jun 2025 19:20:35 +0000 Subject: [PATCH 020/168] EoD 20250610 --- .../databases/spanner_flows.py | 58 ++++++++--------- .../databases/spanner_foreman_rules.py | 2 +- .../databases/spanner_message_handler_test.py | 1 + .../databases/spanner_test_lib.py | 65 ++++++++++++++++--- .../databases/spanner_utils.py | 2 +- 5 files changed, 86 insertions(+), 42 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index adef1ec1d..237d73301 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -4,7 +4,6 @@ import dataclasses import datetime import logging -from time import sleep from typing import Any, Callable, Collection, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union from google.api_core.exceptions import AlreadyExists, NotFound @@ -1676,18 +1675,18 @@ def UpdateIncrementalFlowRequests( ) -> None: """Updates next response ids of given requests.""" - with self.db.MutationPool() as mp: + def Txn(txn) -> None: + rows = [] + columns = ["ClientId", "FlowId", "RequestId", "NextResponseId"] for request_id, response_id in next_response_id_updates.items(): - with mp.Apply() as mut: - mut.Update( - table="FlowRequests", - row={ - "ClientId": client_id, - "FlowId": flow_id, - "RequestId": request_id, - "NextResponseId": response_id, - }, - ) + rows.append([client_id, flow_id, request_id, response_id]) + txn.update( + table="FlowRequests", + columns=columns, + values=rows, + ) + + self.db.Transact(Txn) @db_utils.CallLogged @db_utils.CallAccounted @@ -2102,7 +2101,7 @@ def ReadMessageHandlerRequests( for result in self.db.ReadMessageHandlerRequests(): req = objects_pb2.MessageHandlerRequest() req.ParseFromString(result["payload"]) - req.creation_time = int( + req.timestamp = int( rdfvalue.RDFDatetime.FromDatetime(result["publish_time"]) ) req.ack_id = result["ack_id"] @@ -2236,7 +2235,7 @@ def _ReadHuntState( self, txn, hunt_id: str ) -> Optional[int]: try: - row = txn.read(table="Hunts", keyset=spanner_lib.KeySet[[hunt_id]], columns=["State",]).one() + row = txn.read(table="Hunts", keyset=spanner_lib.KeySet(keys=[[hunt_id]]), columns=["State",]).one() return row[0] except NotFound: return None @@ -2255,7 +2254,7 @@ def Txn(txn) -> flows_pb2.Flow: try: row = txn.read( 
table="Flows", - keyset=spanner_lib.KeySet[[client_id, flow_id]], + keyset=spanner_lib.KeySet(keys=[[client_id, flow_id]]), columns=_READ_FLOW_OBJECT_COLS, ).one() except NotFound as error: @@ -2293,28 +2292,23 @@ def Txn(txn) -> flows_pb2.Flow: flow.processing_on = utils.ProcessIdString() flow.processing_deadline = int(now + processing_time) - txn.Update( + txn.update( table="Flows", - row={ - "ClientId": client_id, - "FlowId": flow_id, - "ProcessingWorker": flow.processing_on, - "ProcessingEndTime": ( - rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + columns = ["ClientId", "FlowId", "ProcessingWorker", + "ProcessingEndTime","ProcessingStartTime"], + values=[[client_id, flow_id, flow.processing_on, + rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( flow.processing_deadline - ).AsDatetime() - ), - "ProcessingStartTime": spanner_lib.COMMIT_TIMESTAMP, - }, + ).AsDatetime(), + spanner_lib.COMMIT_TIMESTAMP, + ]] ) return flow - result = self.db.Transact(Txn) - - leased_flow = result.value + leased_flow, commit_timestamp = self.db.Transact(Txn) leased_flow.processing_since = int( - rdfvalue.RDFDatetime.FromDatetime(result.commit_time) + rdfvalue.RDFDatetime.FromDatetime(commit_timestamp) ) return leased_flow @@ -2327,7 +2321,7 @@ def Txn(txn) -> bool: try: row = txn.read( table="FlowRequests", - keyset=spanner_lib.KeySet[[flow_obj.client_id, flow_obj.flow_id, flow_obj.next_request_to_process]], + keyset=spanner_lib.KeySet(keys=[[flow_obj.client_id, flow_obj.flow_id, flow_obj.next_request_to_process]]), columns=["NeedsProcessing", "StartTime"], ).one() if row[0]: @@ -2358,5 +2352,5 @@ def Txn(txn) -> bool: ) return True - return self.db.Transact(Txn).value + return self.db.Transact(Txn) diff --git a/grr/server/grr_response_server/databases/spanner_foreman_rules.py b/grr/server/grr_response_server/databases/spanner_foreman_rules.py index 5c61300e8..9b6950fe3 100644 --- a/grr/server/grr_response_server/databases/spanner_foreman_rules.py +++ b/grr/server/grr_response_server/databases/spanner_foreman_rules.py @@ -19,7 +19,7 @@ class ForemanRulesMixin: def WriteForemanRule(self, rule: jobs_pb2.ForemanCondition) -> None: """Writes a foreman rule to the database.""" row = { - "HuntId": rule.hunt_id_int, + "HuntId": rule.hunt_id, "ExpirationTime": ( rdfvalue.RDFDatetime() .FromMicrosecondsSinceEpoch(rule.expiration_time) diff --git a/grr/server/grr_response_server/databases/spanner_message_handler_test.py b/grr/server/grr_response_server/databases/spanner_message_handler_test.py index 088dd6b0c..1702ff178 100644 --- a/grr/server/grr_response_server/databases/spanner_message_handler_test.py +++ b/grr/server/grr_response_server/databases/spanner_message_handler_test.py @@ -71,6 +71,7 @@ def testMessageHandlerRequests(self): self.assertCountEqual(requests[2:4], read) + def testMessageHandlerLeaseManagement(self): ######################## # Lease Management tests ######################## diff --git a/grr/server/grr_response_server/databases/spanner_test_lib.py b/grr/server/grr_response_server/databases/spanner_test_lib.py index c84db358b..f1fa76c5f 100644 --- a/grr/server/grr_response_server/databases/spanner_test_lib.py +++ b/grr/server/grr_response_server/databases/spanner_test_lib.py @@ -1,11 +1,13 @@ """A library with utilities for testing the Spanner database implementation.""" import os import unittest +import uuid from typing import Optional from absl.testing import absltest +from google.cloud import pubsub_v1 from google.cloud import spanner_v1 as spanner_lib from google.cloud import 
spanner_admin_database_v1 from google.cloud.spanner import Client, KeySet @@ -108,25 +110,72 @@ class TestCase(absltest.TestCase): This class takes care of setting up a clean database for every test method. It is intended to be used with database test suite mixins. """ + msg_handler_top = None + msg_handler_sub = None + flow_processing_top = None + flow_processing_sub = None + + project_id = None def setUp(self): super().setUp() - project_id = _GetEnvironOrSkip("PROJECT_ID") - msg_handler_top_id = _GetEnvironOrSkip("MESSAGE_HANDLER_TOPIC_ID") - msg_handler_sub_id = _GetEnvironOrSkip("MESSAGE_HANDLER_SUBSCRIPTION_ID") - flow_processing_top_id = _GetEnvironOrSkip("FLOW_PROCESSING_TOPIC_ID") - flow_processing_sub_id = _GetEnvironOrSkip("FLOW_PROCESSING_SUBSCRIPTION_ID") + self.project_id = _GetEnvironOrSkip("PROJECT_ID") + + msg_uuid=str(uuid.uuid4()) + flow_uuid=str(uuid.uuid4()) + self.msg_handler_top_id = "msg-top"+msg_uuid + self.msg_handler_sub_id = "msg-sub"+msg_uuid + + self.flow_processing_top_id = "flow-top"+flow_uuid + self.flow_processing_sub_id = "flow-sub"+flow_uuid + + publisher = pubsub_v1.PublisherClient() + msg_handler_top_path = publisher.topic_path(self.project_id, self.msg_handler_top_id) + flow_processing_top_path = publisher.topic_path(self.project_id, self.flow_processing_top_id) + message_handler_top = publisher.create_topic(request={"name": msg_handler_top_path}) + flow_processing_top = publisher.create_topic(request={"name": flow_processing_top_path}) + + subscriber = pubsub_v1.SubscriberClient() + msg_handler_sub_path = subscriber.subscription_path(self.project_id, self.msg_handler_sub_id) + flow_processing_sub_path = subscriber.subscription_path(self.project_id, self.flow_processing_sub_id) + message_handler_sub = subscriber.create_subscription(request={"name": msg_handler_sub_path, + "topic": msg_handler_top_path} + ) + flow_processing_sub = subscriber.create_subscription(request={"name": flow_processing_sub_path, + "topic": flow_processing_top_path} + ) + _clean_database() - self.raw_db = spanner_utils.Database(_TEST_DB, project_id, - msg_handler_top_id, msg_handler_sub_id, - flow_processing_top_id, flow_processing_sub_id) + self.raw_db = spanner_utils.Database(_TEST_DB, self.project_id, + self.msg_handler_top_id, self.msg_handler_sub_id, + self.flow_processing_top_id, self.flow_processing_sub_id) spannerDB = spanner_db.SpannerDB(self.raw_db) self.db = abstract_db.DatabaseValidationWrapper(spannerDB) + def tearDown(self): + subscriber = pubsub_v1.SubscriberClient() + msg_handler_sub_path = subscriber.subscription_path(self.project_id, self.msg_handler_sub_id) + flow_processing_sub_path = subscriber.subscription_path(self.project_id, self.flow_processing_sub_id) + + # Wrap the subscriber in a 'with' block to automatically call close() to + # close the underlying gRPC channel when done. 
+ with subscriber: + subscriber.delete_subscription(request={"subscription": msg_handler_sub_path}) + subscriber.delete_subscription(request={"subscription": flow_processing_sub_path}) + + publisher = pubsub_v1.PublisherClient() + msg_handler_top_path = publisher.topic_path(self.project_id, self.msg_handler_top_id) + flow_processing_top_path = publisher.topic_path(self.project_id, self.flow_processing_top_id) + + publisher.delete_topic(request={"topic": msg_handler_top_path}) + publisher.delete_topic(request={"topic": flow_processing_top_path}) + + super().tearDown() + def _get_table_names(db): with db.snapshot() as snapshot: diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index de399915a..9ab78c0e0 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -176,7 +176,7 @@ def Transact( txn_tag: Transaction tag to apply. Returns: - The result the transaction function returned. + The result of the transaction function executed. """ return self._pyspanner.run_in_transaction(func) From ae87797cb1abd8e53ad58ad9f0427e3a3ced352a Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Wed, 11 Jun 2025 17:12:20 +0000 Subject: [PATCH 021/168] EoD 20250611 --- .../databases/spanner_cron_jobs_test.py | 22 +++++ .../databases/spanner_flows.py | 83 ++++++++----------- .../databases/spanner_hunts_test.py | 25 ++++++ .../databases/spanner_paths_test.py | 58 +++++++++++++ .../databases/spanner_utils.py | 54 +++++++++++- 5 files changed, 191 insertions(+), 51 deletions(-) create mode 100644 grr/server/grr_response_server/databases/spanner_cron_jobs_test.py create mode 100644 grr/server/grr_response_server/databases/spanner_hunts_test.py create mode 100644 grr/server/grr_response_server/databases/spanner_paths_test.py diff --git a/grr/server/grr_response_server/databases/spanner_cron_jobs_test.py b/grr/server/grr_response_server/databases/spanner_cron_jobs_test.py new file mode 100644 index 000000000..3c37a2ed5 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_cron_jobs_test.py @@ -0,0 +1,22 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_cronjob_test +from grr_response_server.databases.local.spanner import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseCronJobsTest( + db_cronjob_test.DatabaseTestCronJobMixin, spanner_test_lib.TestCase +): + pass # Test methods are defined in the base mixin class. 
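+  # The inherited spanner_test_lib.TestCase does the setup work: its setUp()
+  # cleans the shared Spanner test database and creates fresh Pub/Sub topics
+  # and subscriptions, and its tearDown() deletes them again, so every test
+  # method provided by the mixin runs against a clean backend.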
+ + +if __name__ == "__main__": + absltest.main() \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 237d73301..950e5dcfd 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -313,8 +313,8 @@ def WriteFlowObject( self.db.Insert(table="Flows", row=row, txn_tag="WriteFlowObject_I") except AlreadyExists as error: raise db.FlowExistsError(client_id, flow_id) from error - except NotFound as error: - if "Parent row is missing: Clients" in str(error): + except Exception as error: + if "Parent row for row [" in str(error): raise db.UnknownClientError(client_id) else: raise @@ -738,6 +738,7 @@ def _BuildFlowProcessingRequestWrites( """Writes a list of FlowProcessingRequests to the queue.""" flowProcessingRequests = [] for request in requests: + request.creation_time=self.db.Now().AsMicrosecondsSinceEpoch() flowProcessingRequests.append(request.SerializeToString()) self.db.PublishFlowProcessingRequests(flowProcessingRequests) @@ -777,28 +778,17 @@ def AckFlowProcessingRequests( self, requests: Iterable[flows_pb2.FlowProcessingRequest] ) -> None: """Acknowledges and deletes flow processing requests.""" - - def Mutation(mut) -> None: - for r in requests: - key = ( - r.client_id, - r.flow_id, - rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( - r.creation_time - ).AsDatetime(), - ) - mut.Ack("FlowProcessingRequestsQueue", key) - - self.db.BufferedMutate(Mutation, txn_tag="AckFlowProcessingRequests") + ack_ids = [] + for r in requests: + ack_ids.append(r.ack_id) + + self.db.AckFlowProcessingRequests(ack_ids) @db_utils.CallLogged @db_utils.CallAccounted def DeleteAllFlowProcessingRequests(self) -> None: """Deletes all flow processing requests from the database.""" - query = """ - DELETE FROM FlowProcessingRequestsQueue WHERE true - """ - self.db.ParamExecute(query, {}, txn_tag="DeleteAllFlowProcessingRequests") + self.db.DeleteAllFlowProcessingRequests() def RegisterFlowProcessingHandler( self, handler: Callable[[flows_pb2.FlowProcessingRequest], None] @@ -810,14 +800,26 @@ def Callback(payload: bytes, msg_id: str, ack_id: str, publish_time): try: req = flows_pb2.FlowProcessingRequest() req.ParseFromString(payload) - req.ack_id = ack_id - req.creation_time = int( - rdfvalue.RDFDatetime.FromDatetime(publish_time) - ) - handler(req) + date_time_now = rdfvalue.RDFDatetime.Now() + epoch_now = date_time_now.AsMicrosecondsSinceEpoch() + epoch_in_ten = epoch_now + 10 * 1000000 + if req.delivery_time > epoch_now: + ack_ids = [] + ack_ids.append(ack_id) + # figure out when we reach the delivery time, and push it out (max 10 mins allowed by PubSub) + ack_deadline = req.delivery_time if req.delivery_time <= epoch_in_ten else epoch_in_ten + # PubSub wants the deadline in seconds from now + ack_deadline = int((ack_deadline - epoch_now)/1000000) + self.db.LeaseFlowProcessingRequests(ack_ids, ack_deadline) + else: + #req.creation_time = int( + # rdfvalue.RDFDatetime.FromDatetime(publish_time) + #) + req.ack_id = ack_id + handler(req) except Exception as e: # pylint: disable=broad-except logging.exception( - "Exception raised during MessageHandlerRequest processing: %s", e + "Exception raised during FlowProcessingRequest processing: %s", e ) receiver = self.db.NewRequestQueue( @@ -910,7 +912,7 @@ def Txn(txn) -> None: try: self.db.Transact(Txn, txn_tag="WriteFlowRequests") except NotFound as error: - if "Parent row is 
missing: Flows" in str(error): + if "Parent row for row [" in str(error): raise db.AtLeastOneUnknownFlowError(flow_keys, cause=error) else: raise @@ -1441,7 +1443,6 @@ def DeleteAllFlowRequestsAndResponses( self.db.DeleteWithPrefix( "FlowRequests", (client_id, flow_id), - txn_tag="DeleteAllFlowRequestsAndResponses", ) @db_utils.CallLogged @@ -2115,21 +2116,11 @@ def _BuildDeleteMessageHandlerRequestWrites( requests: Iterable[objects_pb2.MessageHandlerRequest], ) -> None: """Deletes given requests within a given transaction.""" - req_rows = spanner_lib.RowSet() + ack_ids = [] for r in requests: - req_rows.AddPrefixRange(spanner_lib.Key(r.handler_name, r.request_id)) - - to_delete = [] - req_cols = ("HandlerName", "RequestId", "CreationTime") - for row in txn.ReadSet("MessageHandlerRequestsQueue", req_rows, req_cols): - handler_name: str = row["HandlerName"] - request_id: int = row["RequestId"] - creation_time: datetime.datetime = row["CreationTime"] - to_delete.append((handler_name, request_id, creation_time)) - - with txn.BufferedMutate() as mut: - for td_key in to_delete: - mut.Ack("MessageHandlerRequestsQueue", td_key) + ack_ids.append(r.ack_id) + + self.db.AckMessageHandlerRequests(ack_ids) @db_utils.CallLogged @db_utils.CallAccounted @@ -2137,7 +2128,6 @@ def DeleteMessageHandlerRequests( self, requests: Iterable[objects_pb2.MessageHandlerRequest] ) -> None: """Deletes a list of message handler requests from the database.""" - ack_ids = [] for request in requests: ack_ids.append(request.ack_id) @@ -2306,10 +2296,10 @@ def Txn(txn) -> flows_pb2.Flow: return flow - leased_flow, commit_timestamp = self.db.Transact(Txn) - leased_flow.processing_since = int( - rdfvalue.RDFDatetime.FromDatetime(commit_timestamp) - ) + leased_flow = self.db.Transact(Txn) + #leased_flow.processing_since = int( + # rdfvalue.RDFDatetime.FromDatetime(commit_stats.commit_timestamp) + #) return leased_flow @db_utils.CallLogged @@ -2353,4 +2343,3 @@ def Txn(txn) -> bool: return True return self.db.Transact(Txn) - diff --git a/grr/server/grr_response_server/databases/spanner_hunts_test.py b/grr/server/grr_response_server/databases/spanner_hunts_test.py new file mode 100644 index 000000000..28f412d5c --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_hunts_test.py @@ -0,0 +1,25 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_hunts_test +from grr_response_server.databases import db_test_utils +from grr_response_server.databases.local.spanner import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseHuntsTest( + db_hunts_test.DatabaseTestHuntMixin, + db_test_utils.QueryTestHelpersMixin, + spanner_test_lib.TestCase, +): + pass + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_paths_test.py b/grr/server/grr_response_server/databases/spanner_paths_test.py new file mode 100644 index 000000000..0b108c698 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_paths_test.py @@ -0,0 +1,58 @@ +from collections.abc import Sequence + +from absl.testing import absltest + +from grr_response_server.databases import db_paths_test +from grr_response_server.databases.local.spanner import paths as spanner_paths +from grr_response_server.databases.local.spanner import spanner_test_lib + + +def setUpModule() -> None: + 
spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabasePathsTest( + db_paths_test.DatabaseTestPathsMixin, spanner_test_lib.TestCase +): + # Test methods are defined in the base mixin class. + pass + + +class EncodePathComponentsTest(absltest.TestCase): + + def testEmptyComponent(self): + with self.assertRaises(ValueError): + spanner_paths.EncodePathComponents(("foo", "", "bar")) + + def testSlashComponent(self): + with self.assertRaises(ValueError): + spanner_paths.EncodePathComponents(("foo", "bar/baz", "quux")) + + +class EncodeDecodePathComponentsTest(absltest.TestCase): + + def testEmpty(self): + self._testComponents(()) + + def testSingle(self): + self._testComponents(("foo",)) + + def testMultiple(self): + self._testComponents(("foo", "bar", "baz", "quux")) + + def testUnicode(self): + self._testComponents(("zażółć", "gęślą", "jaźń")) + + def _testComponents(self, components: Sequence[str]): # pylint: disable=invalid-name + encoded = spanner_paths.EncodePathComponents(components) + decoded = spanner_paths.DecodePathComponents(encoded) + + self.assertSequenceEqual(components, decoded) + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 9ab78c0e0..5d881dac9 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -3,6 +3,7 @@ import contextlib import datetime import decimal +import pytz import re import time @@ -30,6 +31,7 @@ from google.rpc.code_pb2 import OK +from grr_response_core.lib import rdfvalue from grr_response_core.lib.util import collection from grr_response_core.lib.util import iterator @@ -105,6 +107,24 @@ def __init__(self, pyspanner: spanner_lib.database, project_id: str, self.message_handler_sub_path = self.subscriber.subscription_path(project_id, msg_handler_sub_id) self.message_handler_top_path = self.publisher.topic_path(project_id, msg_handler_top_id) + def Now(self) -> rdfvalue.RDFDatetime: + """Retrieves current time as reported by the database.""" + try: + with self._pyspanner.snapshot() as snapshot: + query = "SELECT CURRENT_TIMESTAMP()" + results = snapshot.execute_sql(query) + # Get the first (and only) row + # and the first (and only) column from that row. + timestamp = next(results)[0] + return rdfvalue.RDFDatetime.FromDatetime(timestamp) + except Exception as e: + print(f"Error executing query: {e}") + return None + + def MinTimestamp(self) -> rdfvalue.RDFDatetime: + """Returns minimal timestamp allowed by the DB.""" + return rdfvalue.RDFDatetime.FromSecondsSinceEpoch(0) + def _parametrize(self, query: str, names: Iterable[str]) -> str: match = self._PYSPANNER_PARAM_REGEX.search(query) if match is not None: @@ -167,6 +187,7 @@ def Transact( self, func: Callable[["Transaction"], _T], txn_tag: Optional[str] = None, + log_commit_stats: Optional[bool] = False ) -> List[Any]: """Execute the given callback function in a Spanner transaction. 
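Transact forwards the callback to Cloud Spanner's run_in_transaction, so the
callback receives the live transaction object and its return value is handed
straight back to the caller. A minimal caller sketch, mirroring the closures
used in spanner_flows.py and spanner_cron_jobs.py (table, column and variable
names here are purely illustrative):

    def Txn(txn) -> int:
      row = txn.read(
          table="CronJobs",
          keyset=spanner_lib.KeySet(keys=[[job_id]]),
          columns=["LastRunStatus"],
      ).one()
      return row[0]

    last_status = db.Transact(Txn, txn_tag="ReadLastRunStatus")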
@@ -545,8 +566,14 @@ def AckMessageHandlerRequests(self, ack_ids: [str]) -> None: self.AckRequests(ack_ids, self.message_handler_sub_path) def AckFlowProcessingRequests(self, ack_ids: [str]) -> None: - self.AckRequests(ack_ids, self.flow_proccessing_sub_path) - + self.AckRequests(ack_ids, self.flow_processing_sub_path) + + def DeleteAllMessageHandlerRequests(self) -> None: + self.DeleteAllRequests(self.message_handler_sub_path) + + def DeleteAllFlowProcessingRequests(self) -> None: + self.DeleteAllRequests(self.flow_processing_sub_path) + def LeaseMessageHandlerRequests(self, ack_ids: [str], ack_deadline: int) -> None: self.subscriber.modify_ack_deadline( request={ @@ -556,6 +583,15 @@ def LeaseMessageHandlerRequests(self, ack_ids: [str], ack_deadline: int) -> None } ) + def LeaseFlowProcessingRequests(self, ack_ids: [str], ack_deadline: int) -> None: + self.subscriber.modify_ack_deadline( + request={ + "subscription": self.flow_processing_sub_path, + "ack_ids": ack_ids, + "ack_deadline_seconds": ack_deadline, + } + ) + def PublishRequests(self, requests: [str], top_path: str) -> None: for req in requests: self.publisher.publish(top_path, req) @@ -565,16 +601,26 @@ def AckRequests(self, ack_ids: [str], sub_path: str) -> None: request={"subscription": sub_path, "ack_ids": ack_ids} ) + def DeleteAllRequests(self, sub_path: str) -> None: + client = pubsub_v1.SubscriberClient() + # Initialize request argument(s) + request = { + "subscription": sub_path, + "time": datetime.datetime.now(pytz.utc) + datetime.timedelta(days=30) + } + # Make the request + response = client.seek(request=request) + def ReadRequests(self, sub_path: str): # Make the request start_time = time.time() results = {} - while time.time() - start_time < 10: + while time.time() - start_time < 2: time.sleep(0.1) response = self.subscriber.pull( - request={ + request = { "subscription": sub_path, "max_messages": 10000, }, From 140488ea09a2619943a4b2c1b2e20bb9160826ed Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Thu, 12 Jun 2025 19:33:17 +0000 Subject: [PATCH 022/168] EoD 20250612 --- .../databases/spanner_cron_jobs.py | 509 +++++++- .../databases/spanner_cron_jobs_test.py | 2 +- .../databases/spanner_flows.py | 14 +- .../databases/spanner_hunts.py | 1061 ++++++++++++++++- .../databases/spanner_hunts_test.py | 2 +- .../databases/spanner_paths.py | 543 ++++++++- .../databases/spanner_paths_test.py | 4 +- .../databases/spanner_utils.py | 19 +- 8 files changed, 2097 insertions(+), 57 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_cron_jobs.py b/grr/server/grr_response_server/databases/spanner_cron_jobs.py index 2c704725f..c3f833777 100644 --- a/grr/server/grr_response_server/databases/spanner_cron_jobs.py +++ b/grr/server/grr_response_server/databases/spanner_cron_jobs.py @@ -4,6 +4,7 @@ import datetime from typing import Any, Mapping, Optional, Sequence +from google.api_core.exceptions import NotFound from google.cloud import spanner as spanner_lib from grr_response_core.lib import rdfvalue @@ -63,8 +64,23 @@ def ReadCronJobs( UnknownCronJobError: A cron job for at least one of the given ids does not exist. 
""" + where_ids = "" + params = {} + if cronjob_ids: + where_ids = " WHERE cj.JobId IN UNNEST(@cronjob_ids)" + params["cronjob_ids"] = cronjob_ids - return None + def Transaction(txn) -> Sequence[flows_pb2.CronJob]: + return self._SelectCronJobsWith(txn, where_ids, params) + + res = self.db.Transact(Transaction, txn_tag="ReadCronJobs") + + if cronjob_ids and len(res) != len(cronjob_ids): + missing = set(cronjob_ids) - set([c.cron_job_id for c in res]) + raise db.UnknownCronJobError( + "CronJob(s) with id(s) %s not found." % missing + ) + return res @db_utils.CallLogged @db_utils.CallAccounted @@ -93,7 +109,24 @@ def UpdateCronJob( # pytype: disable=annotation-type-mismatch Raises: UnknownCronJobError: A cron job with the given id does not exist. """ - + row = {"JobId": cronjob_id} + if last_run_status is not _UNCHANGED: + row["LastRunStatus"] = int(last_run_status) + if last_run_time is not _UNCHANGED: + row["LastRunTime"] = ( + last_run_time.AsDatetime() if last_run_time else last_run_time + ) + if current_run_id is not _UNCHANGED: + row["CurrentRunId"] = current_run_id + if state is not _UNCHANGED: + row["State"] = state if state is not None else None + if forced_run_requested is not _UNCHANGED: + row["ForcedRunRequested"] = forced_run_requested + + try: + self.db.Update("CronJobs", row=row, txn_tag="UpdateCronJob") + except NotFound as error: + raise db.UnknownCronJobError(cronjob_id) from error @db_utils.CallLogged @db_utils.CallAccounted @@ -106,6 +139,14 @@ def EnableCronJob(self, cronjob_id: str) -> None: Raises: UnknownCronJobError: A cron job with the given id does not exist. """ + row = { + "JobId": cronjob_id, + "Enabled": True, + } + try: + self.db.Update("CronJobs", row=row, txn_tag="EnableCronJob") + except NotFound as error: + raise db.UnknownCronJobError(cronjob_id) from error @db_utils.CallLogged @@ -119,6 +160,14 @@ def DisableCronJob(self, cronjob_id: str) -> None: Raises: UnknownCronJobError: A cron job with the given id does not exist. """ + row = { + "JobId": cronjob_id, + "Enabled": False, + } + try: + self.db.Update("CronJobs", row=row, txn_tag="DisableCronJob") + except NotFound as error: + raise db.UnknownCronJobError(cronjob_id) from error @db_utils.CallLogged @db_utils.CallAccounted @@ -131,6 +180,18 @@ def DeleteCronJob(self, cronjob_id: str) -> None: Raises: UnknownCronJobError: A cron job with the given id does not exist. """ + def Transaction(txn) -> None: + # Spanner does not raise if we attept to delete a non-existing row so + # we check it exists ourselves. + keyset = spanner_lib.KeySet(keys=[[cronjob_id]]) + try: + txn.read(table="CronJobs", keyset=keyset, columns=['JobId']).one() + except NotFound as error: + raise db.UnknownCronJobError(cronjob_id) from error + + txn.delete(table="CronJobs", keyset=keyset) + + self.db.Transact(Transaction, txn_tag="DeleteCronJob") @db_utils.CallLogged @@ -152,7 +213,70 @@ def LeaseCronJobs( A list of cronjobs.CronJob objects that were leased. """ - return [] + # We can't simply Update the rows because `UPDATE ... SET` will not return + # the affected rows. So, this transaction is broken up in three parts: + # 1. Identify the rows that will be updated + # 2. Update these rows with the new lease information + # 3. 
Read back the affected rows and return them + def Transaction(txn) -> Sequence[flows_pb2.CronJob]: + now = rdfvalue.RDFDatetime.Now() + lease_end_time = now + lease_time + lease_owner = utils.ProcessIdString() + + # --------------------------------------------------------------------- + # Query IDs to be updated on this transaction + # --------------------------------------------------------------------- + query_ids_to_update = """ + SELECT cj.JobId + FROM CronJobs as cj + WHERE (cj.LeaseEndTime IS NULL OR cj.LeaseEndTime < @now) + """ + params_ids_to_update = {"now": now.AsDatetime()} + + if cronjob_ids: + query_ids_to_update += "AND cj.JobId IN UNNEST(@cronjob_ids)" + params_ids_to_update["cronjob_ids"] = cronjob_ids + + response = txn.execute_sql(sql=query_ids_to_update, params=params_ids_to_update) + + ids_to_update = [] + for (job_id,) in response: + ids_to_update.append(job_id) + + if not ids_to_update: + return [] + + # --------------------------------------------------------------------- + # Effectively update them with this process as owner + # --------------------------------------------------------------------- + update_query = """ + UPDATE CronJobs as cj + SET cj.LeaseEndTime = @lease_end_time, cj.LeaseOwner = @lease_owner + WHERE cj.JobId IN UNNEST(@ids_to_update) + """ + update_params = { + "lease_end_time": lease_end_time.AsDatetime(), + "lease_owner": lease_owner, + "ids_to_update": ids_to_update, + } + + txn.execute_sql(sql=update_query, params=update_params) + + # --------------------------------------------------------------------- + # Query (and return) jobs that were updated + # --------------------------------------------------------------------- + where_updated = """ + WHERE (cj.LeaseOwner = @lease_owner) + AND cj.JobId IN UNNEST(@updated_ids) + """ + updated_params = { + "lease_owner": lease_owner, + "updated_ids": ids_to_update, + } + + return self._SelectCronJobsWith(txn, where_updated, updated_params) + + return self.db.Transact(Transaction, txn_tag="LeaseCronJobs") @db_utils.CallLogged @db_utils.CallAccounted @@ -165,7 +289,111 @@ def ReturnLeasedCronJobs(self, jobs: Sequence[flows_pb2.CronJob]) -> None: Raises: ValueError: If not all of the cronjobs are leased. """ - + if not jobs: + return + + # Identify jobs that are not lease (cannot be returned). If there are any + # leased jobs that need returning, then we'll go ahead and try to update + # them anyway. + unleased_jobs = [] + conditions = [] + jobs_to_update_args = {} + for i, job in enumerate(jobs): + if not job.leased_by or not job.leased_until: + unleased_jobs.append(job) + continue + + conditions.append( + "(cj.JobId={job_%d} AND " + "cj.LeaseEndTime={ld_%d} AND " + "cj.LeaseOwner={lo_%d})" % (i, i, i) + ) + dt_leased_until = ( + rdfvalue.RDFDatetime() + .FromMicrosecondsSinceEpoch(job.leased_until) + .AsDatetime() + ) + jobs_to_update_args[i] = ( + job.cron_job_id, + dt_leased_until, + job.leased_by, + ) + + if not conditions: # all jobs are unleased. + raise ValueError("CronJobs to return are not leased: %s" % unleased_jobs) + + # We can't simply Update the rows because `UPDATE ... SET` will not return + # the affected rows. We need both the _already_ disabled jobs and the + # updated rows in order to raise the appropriate exceptions. + # 1. Identify the rows that need to be updated + # 2. Update the relevant rows with the new unlease information + # 3. 
Read back the affected rows and return them + def Transaction(txn) -> Sequence[flows_pb2.CronJob]: + # --------------------------------------------------------------------- + # Query IDs to be updated on this transaction + # --------------------------------------------------------------------- + query_job_ids_to_return = """ + SELECT cj.JobId + FROM CronJobs as cj + """ + params_job_ids_to_return = {} + query_job_ids_to_return += "WHERE" + " OR ".join(conditions) + for i, (job_id, ld, lo) in jobs_to_update_args.items(): + params_job_ids_to_return["job_%d" % i] = job_id + params_job_ids_to_return["ld_%d" % i] = ld + params_job_ids_to_return["lo_%d" % i] = lo + + response = txn.ParamQuery( + query_job_ids_to_return, params_job_ids_to_return + ) + + ids_to_return = [] + for (job_id,) in response: + ids_to_return.append(job_id) + + if not ids_to_return: + return [] + + # --------------------------------------------------------------------- + # Effectively update them, removing owners + # --------------------------------------------------------------------- + update_query = """ + UPDATE CronJobs as cj + SET cj.LeaseEndTime = NULL, cj.LeaseOwner = NULL + WHERE cj.JobId IN UNNEST({ids_to_return}) + """ + update_params = { + "ids_to_return": ids_to_return, + } + + txn.ParamExecute(update_query, update_params) + + # --------------------------------------------------------------------- + # Query (and return) jobs that were updated + # --------------------------------------------------------------------- + where_returned = """ + WHERE cj.JobId IN UNNEST({updated_ids}) + """ + returned_params = { + "updated_ids": ids_to_return, + } + + returned_jobs = self._SelectCronJobsWith( + txn, where_returned, returned_params + ) + + return returned_jobs + + returned_jobs = self.db.Transact( + Transaction, txn_tag="ReturnLeasedCronJobs" + ).value + if unleased_jobs: + raise ValueError("CronJobs to return are not leased: %s" % unleased_jobs) + if len(returned_jobs) != len(jobs): + raise ValueError( + "%d cronjobs in %s could not be returned. Successfully returned: %s" + % ((len(jobs) - len(returned_jobs)), jobs, returned_jobs) + ) @db_utils.CallLogged @db_utils.CallAccounted @@ -175,7 +403,47 @@ def WriteCronJobRun(self, run_object: flows_pb2.CronJobRun) -> None: Args: run_object: A flows_pb2.CronJobRun object to store. """ + # If created_at is set, we use that instead of the commit timestamp. + # This is important for overriding timestamps in testing. Ideally, we would + # have a better/easier way to mock CommitTimestamp instead. + if run_object.started_at: + creation_time = ( + rdfvalue.RDFDatetime() + .FromMicrosecondsSinceEpoch(run_object.started_at) + .AsDatetime() + ) + else: + creation_time = spanner_lib.COMMIT_TIMESTAMP + row = { + "JobId": run_object.cron_job_id, + "RunId": run_object.run_id, + "CreationTime": creation_time, + "Payload": run_object, + "Status": int(run_object.status) or 0, + } + if run_object.finished_at: + row["FinishTime"] = ( + rdfvalue.RDFDatetime() + .FromMicrosecondsSinceEpoch(run_object.finished_at) + .AsDatetime() + ) + if run_object.log_message: + row["LogMessage"] = run_object.log_message + if run_object.backtrace: + row["Backtrace"] = run_object.backtrace + + try: + self.db.InsertOrUpdate( + table="CronJobRuns", row=row, txn_tag="WriteCronJobRun" + ) + except Exception as error: + if "Parent row for row [" in str(error): + # This error can be raised only when the parent cron job does not exist. + message = f"Cron job with id '{run_object.cron_job_id}' not found." 
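+ # (Assumed schema detail: CronJobRuns is interleaved in the parent
+ # CronJobs table, so a "Parent row for row [...]" error can only mean
+ # that the referenced cron job row is absent.)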
+ raise db.UnknownCronJobError(message) from error + else: + raise @db_utils.CallLogged @db_utils.CallAccounted @@ -188,8 +456,35 @@ def ReadCronJobRuns(self, job_id: str) -> Sequence[flows_pb2.CronJobRun]: Returns: A list of flows_pb2.CronJobRun objects. """ - - return [] + cols = [ + "Payload", + "JobId", + "RunId", + "CreationTime", + "FinishTime", + "Status", + "LogMessage", + "Backtrace", + ] + rowrange = spanner_lib.KeyRange(start_closed=[job_id], end_closed=[job_id]) + rows = spanner_lib.KeySet(ranges=[rowrange]) + + res = [] + for row in self.db.ReadSet(table="CronJobRuns", rows=rows, cols=cols): + res.append( + _CronJobRunFromRow( + job_run=row[0], + job_id=row[1], + run_id=row[2], + creation_time=row[3], + finish_time=row[4], + status=row[5], + log_message=row[6], + backtrace=row[7], + ) + ) + + return sorted(res, key=lambda run: run.started_at or 0, reverse=True) @db_utils.CallLogged @db_utils.CallAccounted @@ -203,8 +498,33 @@ def ReadCronJobRun(self, job_id: str, run_id: str) -> flows_pb2.CronJobRun: Returns: An flows_pb2.CronJobRun object. """ - - return None + cols = [ + "Payload", + "JobId", + "RunId", + "CreationTime", + "FinishTime", + "Status", + "LogMessage", + "Backtrace", + ] + try: + row = self.db.Read(table="CronJobRuns", key=(job_id, run_id), cols=cols) + except NotFound as error: + raise db.UnknownCronJobRunError( + "Run with job id %s and run id %s not found." % (job_id, run_id) + ) from error + + return _CronJobRunFromRow( + job_run=row[0], + job_id=row[1], + run_id=row[2], + creation_time=row[3], + finish_time=row[4], + status=row[5], + log_message=row[6], + backtrace=row[7], + ) @db_utils.CallLogged @db_utils.CallAccounted @@ -218,6 +538,177 @@ def DeleteOldCronJobRuns(self, cutoff_timestamp: rdfvalue.RDFDatetime) -> int: Returns: The number of deleted runs. """ + query = """ + SELECT cjr.JobId, cjr.RunId + FROM CronJobRuns AS cjr + WHERE cjr.CreationTime < @cutoff_timestamp + """ + params = {"cutoff_timestamp": cutoff_timestamp.AsDatetime()} + + def Transaction(txn) -> int: + rows = list(txn.execute_sql(sql=query, params=params)) + + for job_id, run_id in rows: + keyset = spanner_lib.KeySet(keys=[[job_id, run_id]]) + + txn.delete(table="CronJobRuns", keyset=keyset) + + return len(rows) + + return self.db.Transact(Transaction, txn_tag="DeleteOldCronJobRuns").value + + def _SelectCronJobsWith( + self, + txn, + where_clause: str, + params: Mapping[str, Any], + ) -> Sequence[flows_pb2.CronJob]: + """Reads rows within the transaction and converts results into CronJobs. - return 0 + Args: + txn: a transaction that will param query and return a cursor for the rows. + where_clause: where clause for filtering the rows. + params: params to be applied to the where_clause. + + Returns: + A list of CronJobs read from the database. 
+ """ + + query = """ + SELECT cj.Job, cj.JobId, cj.CreationTime, cj.Enabled, + cj.ForcedRunRequested, cj.LastRunStatus, cj.LastRunTime, + cj.CurrentRunId, cj.State, cj.LeaseEndTime, cj.LeaseOwner + FROM CronJobs as cj + """ + query += where_clause + + response = txn.execute_sql(sql=query, params=params) + + res = [] + for row in response: + ( + job, + job_id, + creation_time, + enabled, + forced_run_requested, + last_run_status, + last_run_time, + current_run_id, + state, + lease_end_time, + lease_owner, + ) = row + res.append( + _CronJobFromRow( + job=job, + job_id=job_id, + creation_time=creation_time, + enabled=enabled, + forced_run_requested=forced_run_requested, + last_run_status=last_run_status, + last_run_time=last_run_time, + current_run_id=current_run_id, + state=state, + lease_end_time=lease_end_time, + lease_owner=lease_owner, + ) + ) + + return res + + +def _CronJobFromRow( + job: Optional[bytes] = None, + job_id: Optional[str] = None, + creation_time: Optional[datetime.datetime] = None, + enabled: Optional[bool] = None, + forced_run_requested: Optional[bool] = None, + last_run_status: Optional[int] = None, + last_run_time: Optional[datetime.datetime] = None, + current_run_id: Optional[str] = None, + state: Optional[bytes] = None, + lease_end_time: Optional[datetime.datetime] = None, + lease_owner: Optional[str] = None, +) -> flows_pb2.CronJob: + """Creates a CronJob object from a database result row.""" + + if job is not None: + parsed = flows_pb2.CronJob() + parsed.ParseFromString(job) + job = parsed + else: + job = flows_pb2.CronJob( + created_at=rdfvalue.RDFDatetime.Now().AsMicrosecondsSinceEpoch(), + ) + if job_id is not None: + job.cron_job_id = job_id + if current_run_id is not None: + job.current_run_id = current_run_id + if enabled is not None: + job.enabled = enabled + if forced_run_requested is not None: + job.forced_run_requested = forced_run_requested + if last_run_status is not None: + job.last_run_status = last_run_status + if last_run_time is not None: + job.last_run_time = rdfvalue.RDFDatetime.FromDatetime( + last_run_time + ).AsMicrosecondsSinceEpoch() + if state is not None: + read_state = jobs_pb2.AttributedDict() + read_state.ParseFromString(state) + job.state.CopyFrom(read_state) + if creation_time is not None: + job.created_at = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + if lease_end_time is not None: + job.leased_until = rdfvalue.RDFDatetime.FromDatetime( + lease_end_time + ).AsMicrosecondsSinceEpoch() + if lease_owner is not None: + job.leased_by = lease_owner + + return job + +def _CronJobRunFromRow( + job_run: Optional[bytes] = None, + job_id: Optional[str] = None, + run_id: Optional[str] = None, + creation_time: Optional[datetime.datetime] = None, + finish_time: Optional[datetime.datetime] = None, + status: Optional[int] = None, + log_message: Optional[str] = None, + backtrace: Optional[str] = None, +) -> flows_pb2.CronJobRun: + """Creates a CronJobRun object from a database result row.""" + + if job_run is not None: + parsed = flows_pb2.CronJobRun() + parsed.ParseFromString(job_run) + job_run = parsed + else: + job_run = flows_pb2.CronJobRun() + + if job_id is not None: + job_run.cron_job_id = job_id + if run_id is not None: + job_run.run_id = run_id + if creation_time is not None: + job_run.created_at = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + if finish_time is not None: + job_run.finished_at = rdfvalue.RDFDatetime.FromDatetime( + finish_time + 
).AsMicrosecondsSinceEpoch() + if status is not None: + job_run.status = status + if log_message is not None: + job_run.log_message = log_message + if backtrace is not None: + job_run.backtrace = backtrace + + return job_run \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_cron_jobs_test.py b/grr/server/grr_response_server/databases/spanner_cron_jobs_test.py index 3c37a2ed5..9c4d52a0d 100644 --- a/grr/server/grr_response_server/databases/spanner_cron_jobs_test.py +++ b/grr/server/grr_response_server/databases/spanner_cron_jobs_test.py @@ -1,7 +1,7 @@ from absl.testing import absltest from grr_response_server.databases import db_cronjob_test -from grr_response_server.databases.local.spanner import spanner_test_lib +from grr_response_server.databases import spanner_test_lib def setUpModule() -> None: diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 950e5dcfd..f2cc50329 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -898,7 +898,7 @@ def Txn(txn) -> None: candidate_requests = needs_processing.get((client_id, flow_id), []) for r in candidate_requests: - if row[2] == r.request_id or r.start_time: + if int(row[2]) == r.request_id or r.start_time: req = flows_pb2.FlowProcessingRequest( client_id=client_id, flow_id=flow_id ) @@ -1229,7 +1229,7 @@ def _GetFlowResponsesPerRequestCounts( req_key = _RequestKey( client_id, flow_id, - request_id, + int(request_id), ) result[req_key] = count @@ -1304,7 +1304,7 @@ def _ReadFlowRequestsNotYetMarkedForProcessing( columns=req_cols): client_id: str = row[0] flow_id: str = row[1] - request_id: str = row[2] + request_id: int = row[2] np: bool = row[3] start_time: Optional[rdfvalue.RDFDatetime] = None if row[4] is not None: @@ -1337,7 +1337,7 @@ def _BuildNeedsProcessingUpdates( for req_key in requests: rows.append([req_key.client_id, req_key.flow_id, - req_key.request_id, + str(req_key.request_id), True, ]) txn.update(table="FlowRequests", columns=columns, values=rows) @@ -1396,7 +1396,7 @@ def WriteFlowResponses( ], ], ) -> None: - """Writes Flow ressages and updates corresponding requests.""" + """Writes Flow messages and updates corresponding requests.""" responses_expected_by_request = {} callback_state_by_request = {} for batch in collection.Batch(responses, self._write_rows_batch_size): @@ -1535,7 +1535,7 @@ def Mutation(mut) -> None: key = [ request.client_id, request.flow_id, - request.request_id, + str(request.request_id), ] mut.Delete(table="FlowRequests", key=key) @@ -1680,7 +1680,7 @@ def Txn(txn) -> None: rows = [] columns = ["ClientId", "FlowId", "RequestId", "NextResponseId"] for request_id, response_id in next_response_id_updates.items(): - rows.append([client_id, flow_id, request_id, response_id]) + rows.append([client_id, flow_id, str(request_id), str(response_id)]) txn.update( table="FlowRequests", columns=columns, diff --git a/grr/server/grr_response_server/databases/spanner_hunts.py b/grr/server/grr_response_server/databases/spanner_hunts.py index 02d0443e2..3de610334 100644 --- a/grr/server/grr_response_server/databases/spanner_hunts.py +++ b/grr/server/grr_response_server/databases/spanner_hunts.py @@ -1,9 +1,9 @@ #!/usr/bin/env python """A module with hunt methods of the Spanner database implementation.""" -from typing import AbstractSet, Callable, Collection, Iterable, List, Mapping, Optional, Sequence +from typing import 
AbstractSet, Callable, Collection, Iterable, List, Mapping, Optional, Sequence, Set -from google.api_core.exceptions import AlreadyExists +from google.api_core.exceptions import AlreadyExists, NotFound from google.cloud import spanner as spanner_lib @@ -20,6 +20,71 @@ from grr_response_server.databases import spanner_utils from grr_response_server.models import hunts as models_hunts +def _HuntOutputPluginStateFromRow( + plugin_name: str, + plugin_args: Optional[bytes], + plugin_state: bytes, +) -> output_plugin_pb2.OutputPluginState: + """Creates OutputPluginState from the corresponding table's row data.""" + plugin_args_any = None + if plugin_args is not None: + plugin_args_any = any_pb2.Any() + plugin_args_any.ParseFromString(plugin_args) + + plugin_descriptor = output_plugin_pb2.OutputPluginDescriptor( + plugin_name=plugin_name, args=plugin_args_any + ) + + plugin_state_any = any_pb2.Any() + plugin_state_any.ParseFromString(plugin_state) + # TODO(b/120464908): currently AttributedDict is used to store output + # plugins' states. This is suboptimal and unsafe and should be refactored + # towards using per-plugin protos. + attributed_dict = jobs_pb2.AttributedDict() + plugin_state_any.Unpack(attributed_dict) + + return output_plugin_pb2.OutputPluginState( + plugin_descriptor=plugin_descriptor, plugin_state=attributed_dict + ) + + +def _BinsToQuery(bins: list[int], column_name: str) -> str: + """Builds an SQL query part to fetch counts corresponding to given bins.""" + result = [] + # With the current StatsHistogram implementation the last bin simply + # takes all the values that are greater than range_max_value of + # the one-before-the-last bin. range_max_value of the last bin + # is thus effectively ignored. + for prev_b, next_b in zip([0] + bins[:-1], bins[:-1] + [None]): + query = f"COUNT(CASE WHEN {column_name} >= {prev_b}" + if next_b is not None: + query += f" AND {column_name} < {next_b}" + + query += " THEN 1 END)" + + result.append(query) + + return ", ".join(result) + + +_HUNT_FLOW_CONDITION_TO_FLOW_STATE_MAPPING = { + abstract_db.HuntFlowsCondition.FAILED_FLOWS_ONLY: ( + flows_pb2.Flow.FlowState.ERROR, + ), + abstract_db.HuntFlowsCondition.SUCCEEDED_FLOWS_ONLY: ( + flows_pb2.Flow.FlowState.FINISHED, + ), + abstract_db.HuntFlowsCondition.COMPLETED_FLOWS_ONLY: ( + flows_pb2.Flow.FlowState.ERROR, + flows_pb2.Flow.FlowState.FINISHED, + ), + abstract_db.HuntFlowsCondition.FLOWS_IN_PROGRESS_ONLY: ( + flows_pb2.Flow.FlowState.RUNNING, + ), + abstract_db.HuntFlowsCondition.CRASHED_FLOWS_ONLY: ( + flows_pb2.Flow.FlowState.CRASHED, + ), +} class HuntsMixin: """A Spanner database mixin with implementation of flow methods.""" @@ -57,6 +122,74 @@ def WriteHuntObject(self, hunt_obj: hunts_pb2.Hunt): ) + def _UpdateHuntObject( + self, + txn, + hunt_id: str, + duration: Optional[rdfvalue.Duration] = None, + client_rate: Optional[float] = None, + client_limit: Optional[int] = None, + hunt_state: Optional[hunts_pb2.Hunt.HuntState.ValueType] = None, + hunt_state_reason: Optional[ + hunts_pb2.Hunt.HuntStateReason.ValueType + ] = None, + hunt_state_comment: Optional[str] = None, + start_time: Optional[rdfvalue.RDFDatetime] = None, + num_clients_at_start_time: Optional[int] = None, + ): + """Updates the hunt object within a given transaction.""" + params = { + "hunt_id": hunt_id, + } + assignments = ["h.LastUpdateTime = PENDING_COMMIT_TIMESTAMP()"] + + if duration is not None: + assignments.append("h.DurationMicros = @duration_micros") + params["duration_micros"] = int(duration.microseconds) + + if 
client_rate is not None: + assignments.append("h.ClientRate = @client_rate") + params["client_rate"] = float(client_rate) + + if client_limit is not None: + assignments.append("h.ClientLimit = @client_limit") + params["client_limit"] = int(client_limit) + + if hunt_state is not None: + assignments.append("h.State = @hunt_state") + params["hunt_state"] = int(hunt_state) + + if hunt_state_reason is not None: + assignments.append("h.StateReason = @hunt_state_reason") + params["hunt_state_reason"] = int(hunt_state_reason) + + if hunt_state_comment is not None: + assignments.append("h.StateComment = @hunt_state_comment") + params["hunt_state_comment"] = hunt_state_comment + + if start_time is not None: + assignments.append( + "h.InitStartTime = IFNULL(h.InitStartTime, @start_time)" + ) + assignments.append("h.LastStartTime = @start_time") + params["start_time"] = start_time.AsDatetime() + + if num_clients_at_start_time is not None: + assignments.append( + "h.ClientCountAtStartTime = @client_count_at_start_time" + ) + params["client_count_at_start_time"] = spanner_lib.UInt64( + num_clients_at_start_time + ) + + query = f""" + UPDATE Hunts AS h + SET {", ".join(assignments)} + WHERE h.HuntId = @hunt_id + """ + + txn.execute_sql(sql=query, params=params) + @db_utils.CallLogged @db_utils.CallAccounted def UpdateHuntObject( @@ -75,19 +208,95 @@ def UpdateHuntObject( ): """Updates the hunt object.""" + def Txn(txn) -> None: + # Make sure the hunt is there. + try: + keyset = spanner_lib.KeySet(keys=[[hunt_id]]) + txn.read(table="Hunts", keyset=keyset, columns=["HuntId",]) + except NotFound as e: + raise abstract_db.UnknownHuntError(hunt_id) from e + + # Then update the hunt. + self._UpdateHuntObject( + txn, + hunt_id, + duration=duration, + client_rate=client_rate, + client_limit=client_limit, + hunt_state=hunt_state, + hunt_state_reason=hunt_state_reason, + hunt_state_comment=hunt_state_comment, + start_time=start_time, + num_clients_at_start_time=num_clients_at_start_time, + ) + + self.db.Transact(Txn, txn_tag="UpdateHuntObject") @db_utils.CallLogged @db_utils.CallAccounted def DeleteHuntObject(self, hunt_id: str) -> None: """Deletes a hunt object with a given id.""" - + keyset = spanner_lib.KeySet(keys=[[hunt_id]]) + self.db.Delete( + table="Hunts", keyset=keyset, txn_tag="DeleteHuntObject" + ) @db_utils.CallLogged @db_utils.CallAccounted def ReadHuntObject(self, hunt_id: str) -> hunts_pb2.Hunt: """Reads a hunt object from the database.""" - return None + cols = [ + "Hunt", + "CreationTime", + "LastUpdateTime", + "DurationMicros", + "Description", + "Creator", + "ClientRate", + "ClientLimit", + "State", + "StateReason", + "StateComment", + "InitStartTime", + "LastStartTime", + "ClientCountAtStartTime", + ] + + try: + row = self.db.Read(table="Hunts", key=[hunt_id], cols=cols) + except NotFound as e: + raise abstract_db.UnknownHuntError(hunt_id) from e + + hunt_obj = hunts_pb2.Hunt() + hunt_obj.ParseFromString(row[0]) + hunt_obj.create_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[1]) + ) + hunt_obj.last_update_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[2]) + ) + hunt_obj.duration = rdfvalue.DurationSeconds.From( + row[3], rdfvalue.MICROSECONDS + ).ToInt(rdfvalue.SECONDS) + hunt_obj.description = row[4] + hunt_obj.creator = row[5] + hunt_obj.client_rate = row[6] + hunt_obj.client_limit = row[7] + hunt_obj.hunt_state = row[8] + hunt_obj.hunt_state_reason = row[9] + hunt_obj.hunt_state_comment = row[10] + if row[11] is not None: + hunt_obj.init_start_time = int( + 
rdfvalue.RDFDatetime.FromDatetime(row[11]) + ) + if row[12] is not None: + hunt_obj.last_start_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[12]) + ) + hunt_obj.num_clients_at_start_time = row[13] + + return hunt_obj @db_utils.CallLogged @db_utils.CallAccounted @@ -98,15 +307,107 @@ def ReadHuntObjects( with_creator: Optional[str] = None, created_after: Optional[rdfvalue.RDFDatetime] = None, with_description_match: Optional[str] = None, - created_by: Optional[AbstractSet[str]] = None, - not_created_by: Optional[AbstractSet[str]] = None, - with_states: Optional[ - Collection[hunts_pb2.Hunt.HuntState.ValueType] - ] = None, - ) -> List[hunts_pb2.Hunt]: + created_by: Optional[Set[str]] = None, + not_created_by: Optional[Set[str]] = None, + #with_states: Optional[Collection[hunts_pb2.Hunt.HuntState]] = None, # TODO + ) -> list[hunts_pb2.Hunt]: """Reads hunt objects from the database.""" + conditions = [] + params = { + "limit": count, + "offset": offset, + } + + if with_creator is not None: + conditions.append("h.Creator = {creator}") + params["creator"] = with_creator + + if created_by is not None: + conditions.append("h.Creator IN UNNEST({created_by})") + params["created_by"] = list(created_by) + + if not_created_by is not None: + conditions.append("h.Creator NOT IN UNNEST({not_created_by})") + params["not_created_by"] = list(not_created_by) + + if created_after is not None: + conditions.append("h.CreationTime > {creation_time}") + params["creation_time"] = created_after.AsDatetime() + + if with_description_match is not None: + conditions.append("h.Description LIKE {description}") + params["description"] = f"%{with_description_match}%" + + if with_states is not None: + if not with_states: + return [] + ors = [] + for i, state in enumerate(with_states): + ors.append(f"h.State = {{state_{i}}}") + params[f"state_{i}"] = int(state) + conditions.append("(" + " OR ".join(ors) + ")") + + query = """ + SELECT + h.HuntId, + h.CreationTime, + h.LastUpdateTime, + h.Creator, + h.DurationMicros, + h.Description, + h.ClientRate, + h.ClientLimit, + h.State, + h.StateReason, + h.StateComment, + h.InitStartTime, + h.LastStartTime, + h.ClientCountAtStartTime, + h.Hunt + FROM Hunts AS h + """ + + if conditions: + query += " WHERE " + " AND ".join(conditions) + + query += """ + ORDER BY h.CreationTime DESC + LIMIT {limit} OFFSET {offset} + """ + result = [] + for row in self.db.ParamQuery(query, params, txn_tag="ReadHuntObjects"): + hunt_obj = hunts_pb2.Hunt() + hunt_obj.ParseFromString(row[14]) + + hunt_obj.create_time = int(rdfvalue.RDFDatetime.FromDatetime(row[1])) + hunt_obj.last_update_time = int(rdfvalue.RDFDatetime.FromDatetime(row[2])) + hunt_obj.creator = row[3] + hunt_obj.duration = rdfvalue.DurationSeconds.From( + row[4], rdfvalue.MICROSECONDS + ).ToInt(rdfvalue.SECONDS) + + hunt_obj.description = row[5] + hunt_obj.client_rate = row[6] + hunt_obj.client_limit = row[7] + hunt_obj.hunt_state = row[8] + hunt_obj.hunt_state_reason = row[9] + hunt_obj.hunt_state_comment = row[10] + + if row[11] is not None: + hunt_obj.init_start_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[11]) + ) + + if row[12] is not None: + hunt_obj.last_start_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[12]) + ) + + hunt_obj.num_clients_at_start_time = row[13] + + result.append(hunt_obj) return result @@ -127,7 +428,100 @@ def ListHuntObjects( ) -> Iterable[hunts_pb2.HuntMetadata]: """Reads metadata for hunt objects from the database.""" + conditions = [] + params = { + "limit": count, + "offset": offset, + } + + 
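+ # Same filtering options as ReadHuntObjects above, but only metadata
+ # columns are selected below (the serialized Hunt payload is not fetched).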
if with_creator is not None: + conditions.append("h.Creator = {creator}") + params["creator"] = with_creator + + if created_by is not None: + conditions.append("h.Creator IN UNNEST({created_by})") + params["created_by"] = list(created_by) + + if not_created_by is not None: + conditions.append("h.Creator NOT IN UNNEST({not_created_by})") + params["not_created_by"] = list(not_created_by) + + if created_after is not None: + conditions.append("h.CreationTime > {creation_time}") + params["creation_time"] = created_after.AsDatetime() + + if with_description_match is not None: + conditions.append("h.Description LIKE {description}") + params["description"] = f"%{with_description_match}%" + + if with_states is not None: + if not with_states: + return [] + ors = [] + for i, state in enumerate(with_states): + ors.append(f"h.State = {{state_{i}}}") + params[f"state_{i}"] = int(state) + conditions.append("(" + " OR ".join(ors) + ")") + + query = """ + SELECT + h.HuntId, + h.CreationTime, + h.LastUpdateTime, + h.Creator, + h.DurationMicros, + h.Description, + h.ClientRate, + h.ClientLimit, + h.State, + h.StateComment, + h.InitStartTime, + h.LastStartTime, + FROM Hunts AS h + """ + + if conditions: + query += " WHERE " + " AND ".join(conditions) + + query += """ + ORDER BY h.CreationTime DESC + LIMIT {limit} OFFSET {offset} + """ + result = [] + for row in self.db.ParamQuery(query, params, txn_tag="ListHuntObjects"): + hunt_mdata = hunts_pb2.HuntMetadata() + hunt_mdata.hunt_id = row[0] + + hunt_mdata.create_time = int(rdfvalue.RDFDatetime.FromDatetime(row[1])) + hunt_mdata.last_update_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[2]) + ) + hunt_mdata.creator = row[3] + hunt_mdata.duration = int( + rdfvalue.Duration.From(row[4], rdfvalue.MICROSECONDS).ToInt( + rdfvalue.SECONDS + ) + ) + if row[5]: + hunt_mdata.description = row[5] + hunt_mdata.client_rate = row[6] + hunt_mdata.client_limit = row[7] + hunt_mdata.hunt_state = row[8] + if row[9]: + hunt_mdata.hunt_state_comment = row[9] + + if row[10] is not None: + hunt_mdata.init_start_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[10]) + ) + + if row[11] is not None: + hunt_mdata.last_start_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[11]) + ) + + result.append(hunt_mdata) return result @@ -144,8 +538,61 @@ def ReadHuntResults( with_timestamp: Optional[rdfvalue.RDFDatetime] = None, ) -> Iterable[flows_pb2.FlowResult]: """Reads hunt results of a given hunt using given query options.""" + params = { + "hunt_id": hunt_id, + "offset": offset, + "count": count, + } + + query = """ + SELECT t.Payload, t.CreationTime, t.Tag, t.ClientId, t.FlowId + FROM FlowResults AS t + WHERE t.HuntId = {hunt_id} + AND t.FlowId = {hunt_id} + """ + + if with_tag is not None: + query += " AND t.Tag = {tag} " + params["tag"] = with_tag + + if with_type is not None: + query += " AND t.RdfType = {type}" + params["type"] = with_type + + if with_substring is not None: + query += ( + " AND STRPOS(SAFE_CONVERT_BYTES_TO_STRING(t.Payload.value), " + "{substring}) != 0" + ) + params["substring"] = with_substring + + if with_timestamp is not None: + query += " AND t.CreationTime = {creation_time}" + params["creation_time"] = with_timestamp.AsDatetime() + + query += """ + ORDER BY t.CreationTime ASC LIMIT {count} OFFSET {offset} + """ results = [] + for ( + payload_bytes, + creation_time, + tag, + client_id, + flow_id, + ) in self.db.ParamQuery(query, params, txn_tag="ReadHuntResults"): + result = flows_pb2.FlowResult() + result.hunt_id = hunt_id + result.client_id = 
client_id + result.flow_id = flow_id + result.timestamp = int(rdfvalue.RDFDatetime.FromDatetime(creation_time)) + result.payload.ParseFromString(payload_bytes) + + if tag is not None: + result.tag = tag + + results.append(result) return results @@ -158,15 +605,49 @@ def CountHuntResults( with_type: Optional[str] = None, ) -> int: """Counts hunt results of a given hunt using given query options.""" + params = { + "hunt_id": hunt_id, + } + + query = """ + SELECT COUNT(*) + FROM FlowResults AS t + WHERE t.HuntId = {hunt_id} + """ + + if with_tag is not None: + query += " AND t.Tag = {tag} " + params["tag"] = with_tag + + if with_type is not None: + query += " AND t.RdfType = {type}" + params["type"] = with_type - return 0 + (count,) = self.db.ParamQuerySingle( + query, params, txn_tag="CountHuntResults" + ) + return count @db_utils.CallLogged @db_utils.CallAccounted def CountHuntResultsByType(self, hunt_id: str) -> Mapping[str, int]: """Returns counts of items in hunt results grouped by type.""" + params = { + "hunt_id": hunt_id, + } + + query = """ + SELECT t.RdfType, COUNT(*) + FROM FlowResults AS t + WHERE t.HuntId = {hunt_id} + GROUP BY t.RdfType + """ result = {} + for type_name, count in self.db.ParamQuery( + query, params, txn_tag="CountHuntResultsByType" + ): + result[type_name] = count return result @@ -180,8 +661,47 @@ def ReadHuntLogEntries( with_substring: Optional[str] = None, ) -> Sequence[flows_pb2.FlowLogEntry]: """Reads hunt log entries of a given hunt using given query options.""" + params = { + "hunt_id": hunt_id, + "offset": offset, + "count": count, + } + + query = """ + SELECT l.ClientId, + l.FlowId, + l.CreationTime, + l.Message + FROM FlowLogEntries AS l + WHERE l.HuntId = {hunt_id} + AND l.FlowId = {hunt_id} + """ + + if with_substring is not None: + query += " AND STRPOS(l.Message, {substring}) != 0" + params["substring"] = with_substring + + query += """ + LIMIT {count} + OFFSET {offset} + """ + params["offset"] = offset + params["count"] = count results = [] + for row in self.db.ParamQuery(query, params, txn_tag="ReadHuntLogEntries"): + client_id, flow_id, creation_time, message = row + + result = flows_pb2.FlowLogEntry() + result.hunt_id = hunt_id + result.client_id = client_id + result.flow_id = flow_id + result.timestamp = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + result.message = message + + results.append(result) return results @@ -189,8 +709,21 @@ def ReadHuntLogEntries( @db_utils.CallAccounted def CountHuntLogEntries(self, hunt_id: str) -> int: """Returns number of hunt log entries of a given hunt.""" + params = { + "hunt_id": hunt_id, + } - return 0 + query = """ + SELECT COUNT(*) + FROM FlowLogEntries AS l + WHERE l.HuntId = {hunt_id} + AND l.FlowId = {hunt_id} + """ + + (count,) = self.db.ParamQuerySingle( + query, params, txn_tag="CountHuntLogEntries" + ) + return count @db_utils.CallLogged @db_utils.CallAccounted @@ -204,7 +737,49 @@ def ReadHuntOutputPluginLogEntries( ) -> Sequence[flows_pb2.FlowOutputPluginLogEntry]: """Reads hunt output plugin log entries.""" + query = """ + SELECT l.ClientId, + l.FlowId, + l.CreationTime, + l.Type, l.Message + FROM FlowOutputPluginLogEntries AS l + WHERE l.HuntId = {hunt_id} + AND l.OutputPluginId = {output_plugin_id} + """ + params = { + "hunt_id": hunt_id, + "output_plugin_id": output_plugin_id, + } + + if with_type is not None: + query += " AND l.Type = {type}" + params["type"] = int(with_type) + + query += """ + LIMIT {count} + OFFSET {offset} + """ + params["offset"] = 
offset + params["count"] = count + results = [] + for row in self.db.ParamQuery( + query, params, txn_tag="ReadHuntOutputPluginLogEntries" + ): + client_id, flow_id, creation_time, int_type, message = row + + result = flows_pb2.FlowOutputPluginLogEntry() + result.hunt_id = hunt_id + result.client_id = client_id + result.flow_id = flow_id + result.output_plugin_id = output_plugin_id + result.timestamp = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + result.log_entry_type = int_type + result.message = message + + results.append(result) return results @@ -217,17 +792,56 @@ def CountHuntOutputPluginLogEntries( with_type: Optional[str] = None, ) -> int: """Returns number of hunt output plugin log entries of a given hunt.""" + query = """ + SELECT COUNT(*) + FROM FlowOutputPluginLogEntries AS l + WHERE l.HuntId = {hunt_id} + AND l.OutputPluginId = {output_plugin_id} + """ + params = { + "hunt_id": hunt_id, + "output_plugin_id": output_plugin_id, + } - return 0 + if with_type is not None: + query += " AND l.Type = {type}" + params["type"] = int(with_type) + + (count,) = self.db.ParamQuerySingle( + query, params, txn_tag="CountHuntOutputPluginLogEntries" + ) + return count @db_utils.CallLogged @db_utils.CallAccounted def ReadHuntOutputPluginsStates( self, hunt_id: str - ) -> List[output_plugin_pb2.OutputPluginState]: + ) -> list[output_plugin_pb2.OutputPluginState]: """Reads all hunt output plugins states of a given hunt.""" + # Make sure the hunt is there. + try: + self.db.Read(table="Hunts", key=[hunt_id,], cols=("HuntId",)) + except NotFound as e: + raise abstract_db.UnknownHuntError(hunt_id) from e + + params = { + "hunt_id": hunt_id, + } + + query = """ + SELECT s.Name, + s.Args, + s.State + FROM HuntOutputPlugins AS s + WHERE s.HuntId = {hunt_id} + """ results = [] + for row in self.db.ParamQuery( + query, params, txn_tag="ReadHuntOutputPluginsStates" + ): + name, args, state = row + results.append(_HuntOutputPluginStateFromRow(name, args, state)) return results @@ -240,6 +854,26 @@ def WriteHuntOutputPluginsStates( ) -> None: """Writes hunt output plugin states for a given hunt.""" + def Mutation(mut) -> None: + for index, state in enumerate(states): + state_any = any_pb2.Any() + state_any.Pack(state.plugin_state) + row = { + "HuntId": hunt_id, + "OutputPluginId": index, + "Name": state.plugin_descriptor.plugin_name, + "State": state_any.SerializeToString(), + } + + if state.plugin_descriptor.HasField("args"): + row["Args"] = state.plugin_descriptor.args.SerializeToString() + + mut.InsertOrUpdate(table="HuntOutputPlugins", row=row) + + try: + self.db.Mutate(Mutation, txn_tag="WriteHuntOutputPluginsStates") + except NotFound as e: + raise abstract_db.UnknownHuntError(hunt_id) from e @db_utils.CallLogged @db_utils.CallAccounted @@ -251,6 +885,30 @@ def UpdateHuntOutputPluginState( ) -> None: """Updates hunt output plugin state for a given output plugin.""" + def Txn(txn) -> None: + row = txn.read( + table="HuntOutputPlugins", + keyset=spanner_lib.KeySet(keys=[hunt_id, state_index]), + columns=["Name", "Args", "State"], + ) + state = _HuntOutputPluginStateFromRow( + row["Name"], row["Args"], row["State"] + ) + + modified_plugin_state = update_fn(state.plugin_state) + modified_plugin_state_any = any_pb2.Any() + modified_plugin_state_any.Pack(modified_plugin_state) + row_update = { + "HuntId": int_hunt_id, + "OutputPluginId": state_index, + "State": modified_plugin_state_any.SerializeToString(), + } + txn.Update("HuntOutputPlugins", row_update) + + try: + 
self.db.Transact(Txn, txn_tag="UpdateHuntOutputPluginState") + except NotFound as e: + raise abstract_db.UnknownHuntError(hunt_id) from e @db_utils.CallLogged @db_utils.CallAccounted @@ -263,7 +921,90 @@ def ReadHuntFlows( # pytype: disable=annotation-type-mismatch ) -> Sequence[flows_pb2.Flow]: """Reads hunt flows matching given conditions.""" + params = { + "hunt_id": hunt_id, + } + + query = """ + SELECT + f.ClientId, + f.Creator, + f.Name, + f.State, + f.CreationTime, + f.UpdateTime, + f.Crash, + f.ProcessingWorker, + f.ProcessingStartTime, + f.ProcessingEndTime, + f.NextRequestToProcess, + f.Flow + FROM Flows AS f + WHERE f.ParentHuntId = {hunt_id} + AND f.FlowId = {hunt_id} + """ + + if filter_condition != abstract_db.HuntFlowsCondition.UNSET: + states = _HUNT_FLOW_CONDITION_TO_FLOW_STATE_MAPPING[filter_condition] # pytype: disable=unsupported-operands + conditions = (f"f.State = {{state_{i}}}" for i in range(len(states))) + query += "AND (" + " OR ".join(conditions) + ")" + for i, state in enumerate(states): + params[f"state_{i}"] = int(state) + + query += """ + ORDER BY f.UpdateTime ASC LIMIT {count} OFFSET {offset} + """ + params["offset"] = offset + params["count"] = count + results = [] + for row in self.db.ParamQuery(query, params, txn_tag="ReadHuntFlows"): + ( + client_id, + creator, + name, + state, + creation_time, + update_time, + crash, + processing_worker, + processing_start_time, + processing_end_time, + next_request_to_process, + flow_payload_bytes, + ) = row + + result = flows_pb2.Flow() + result.ParseFromString(flow_payload_bytes) + result.client_id = client_id + result.parent_hunt_id = hunt_id + result.creator = creator + result.flow_class_name = name + result.flow_state = state + result.next_request_to_process = int(next_request_to_process) + result.create_time = int(rdfvalue.RDFDatetime.FromDatetime(creation_time)) + result.last_update_time = int( + rdfvalue.RDFDatetime.FromDatetime(update_time) + ) + + if crash is not None: + client_crash = jobs_pb2.ClientCrash() + client_crash.ParseFromString(crash) + result.client_crash_info.CopyFrom(client_crash) + + if processing_worker: + result.processing_on = processing_worker + + if processing_start_time is not None: + result.processing_since = int( + rdfvalue.RDFDatetime.FromDatetime(processing_start_time) + ) + if processing_end_time is not None: + result.processing_deadline = int( + rdfvalue.RDFDatetime.FromDatetime(processing_end_time) + ) + + results.append(result) return results @@ -278,6 +1019,37 @@ def ReadHuntFlowErrors( """Returns errors for flows of the given hunt.""" results = {} + query = """ + SELECT f.ClientId, + f.UpdateTime, + f.Flow.error_message, + f.Flow.backtrace, + FROM Flows@{{FORCE_INDEX=FlowsByParentHuntIdFlowIdState}} AS f + WHERE f.ParentHuntId = {hunt_id} + AND f.FlowId = {hunt_id} + AND f.State = 'ERROR' + ORDER BY f.UpdateTime ASC LIMIT {count} + OFFSET {offset} + """ + + params = { + "hunt_id": hunt_id, + "offset": offset, + "count": count, + } + + for row in self.db.ParamQuery(query, params, txn_tag="ReadHuntFlowErrors"): + (client_id, time, message, backtrace) = row + + info = abstract_db.FlowErrorInfo( + message=message, + time=rdfvalue.RDFDatetime.FromDate(time), + ) + if backtrace: + info.backtrace = backtrace + + results[client_id] = info + return results @db_utils.CallLogged @@ -291,7 +1063,27 @@ def CountHuntFlows( # pytype: disable=annotation-type-mismatch ) -> int: """Counts hunt flows matching given conditions.""" - return 0 + params = { + "hunt_id": hunt_id, + } + + query = """ + 
SELECT COUNT(*) + FROM Flows@{{FORCE_INDEX=FlowsByParentHuntIdFlowIdState}} AS f + WHERE f.ParentHuntId = {hunt_id} + AND f.FlowId = {hunt_id} + """ + + if filter_condition != abstract_db.HuntFlowsCondition.UNSET: + states = _HUNT_FLOW_CONDITION_TO_FLOW_STATE_MAPPING[filter_condition] # pytype: disable=unsupported-operands + query += f""" AND({ + " OR ".join(f"f.State = {{state_{i}}}" for i in range(len(states))) + })""" + for i, state in enumerate(states): + params[f"state_{i}"] = int(state) + + (count,) = self.db.ParamQuerySingle(query, params, txn_tag="CountHuntFlows") + return count @db_utils.CallLogged @db_utils.CallAccounted @@ -301,7 +1093,30 @@ def ReadHuntFlowsStatesAndTimestamps( ) -> Sequence[abstract_db.FlowStateAndTimestamps]: """Reads hunt flows states and timestamps.""" + params = { + "hunt_id": hunt_id, + } + + query = """ + SELECT f.State, f.CreationTime, f.UpdateTime + FROM Flows AS f + WHERE f.ParentHuntId = {hunt_id} + AND f.FlowId = {hunt_id} + """ + results = [] + for row in self.db.ParamQuery( + query, params, txn_tag="ReadHuntFlowsStatesAndTimestamps" + ): + int_state, creation_time, update_time = row + + results.append( + abstract_db.FlowStateAndTimestamps( + flow_state=int_state, + create_time=rdfvalue.RDFDatetime.FromDatetime(creation_time), + last_update_time=rdfvalue.RDFDatetime.FromDatetime(update_time), + ) + ) return results @@ -313,7 +1128,90 @@ def ReadHuntsCounters( ) -> Mapping[str, abstract_db.HuntCounters]: """Reads hunt counters for several of hunt ids.""" - return None + params = { + "hunt_ids": hunt_ids, + } + + states_query = """ + SELECT f.ParentHuntID, f.State, COUNT(*) + FROM Flows@{{FORCE_INDEX=FlowsByParentHuntIdFlowIdState}} AS f + WHERE f.ParentHuntID IN UNNEST({hunt_ids}) + AND f.FlowId IN UNNEST({hunt_ids}) + AND f.FlowId = f.ParentHuntID + GROUP BY f.ParentHuntID, f.State + """ + counts_by_state_per_hunt = dict.fromkeys(hunt_ids, {}) + for hunt_id, state, count in self.db.ParamQuery( + states_query, params, txn_tag="ReadHuntCounters_1" + ): + counts_by_state_per_hunt[hunt_id][state] = count + + hunt_counters = dict.fromkeys( + hunt_ids, + abstract_db.HuntCounters( + num_clients=0, + num_successful_clients=0, + num_failed_clients=0, + num_clients_with_results=0, + num_crashed_clients=0, + num_running_clients=0, + num_results=0, + total_cpu_seconds=0, + total_network_bytes_sent=0, + ), + ) + + resources_results_query = """ + SELECT + f.ParentHuntID, + SUM(f.UserCpuTimeUsed + f.SystemCpuTimeUsed), + SUM(f.NetworkBytesSent), + SUM(f.ReplyCount), + COUNT(IF(f.ReplyCount > 0, f.ClientId, NULL)) + FROM Flows@{{FORCE_INDEX=FlowsByParentHuntIdFlowIdState}} AS f + WHERE f.ParentHuntID IN UNNEST({hunt_ids}) + AND f.FlowId IN UNNEST({hunt_ids}) + AND f.ParentHuntID = f.FlowID + GROUP BY f.ParentHuntID + """ + for ( + hunt_id, + total_cpu_seconds, + total_network_bytes_sent, + num_results, + num_clients_with_results, + ) in self.db.ParamQuery( + resources_results_query, params, txn_tag="ReadHuntCounters_2" + ): + counts_by_state = counts_by_state_per_hunt[hunt_id] + num_successful_clients = counts_by_state.get( + flows_pb2.Flow.FlowState.FINISHED, 0 + ) + num_failed_clients = counts_by_state.get( + flows_pb2.Flow.FlowState.ERROR, 0 + ) + num_crashed_clients = counts_by_state.get( + flows_pb2.Flow.FlowState.CRASHED, 0 + ) + num_running_clients = counts_by_state.get( + flows_pb2.Flow.FlowState.RUNNING, 0 + ) + num_clients = sum(counts_by_state.values()) + + hunt_counters[hunt_id] = abstract_db.HuntCounters( + num_clients=num_clients, + 
num_successful_clients=num_successful_clients, + num_failed_clients=num_failed_clients, + num_clients_with_results=num_clients_with_results, + num_crashed_clients=num_crashed_clients, + num_running_clients=num_running_clients, + # Spanner's SUM on no elements returns NULL - accounting for + # this here. + num_results=num_results or 0, + total_cpu_seconds=total_cpu_seconds or 0, + total_network_bytes_sent=total_network_bytes_sent or 0, + ) + return hunt_counters @db_utils.CallLogged @db_utils.CallAccounted @@ -323,4 +1221,133 @@ def ReadHuntClientResourcesStats( ) -> jobs_pb2.ClientResourcesStats: """Read hunt client resources stats.""" - return None + params = { + "hunt_id": hunt_id, + } + + # For some reason Spanner SQL doesn't have STDDEV_POP aggregate + # function. Thus, we have to effectively reimplement it ourselves + # using other aggregate functions. + # For the reference, see: + # http://google3/ops/security/grr/core/grr_response_core/lib/rdfvalues/stats.py;l=121;rcl=379683316 + query = """ + SELECT + COUNT(*), + SUM(f.UserCpuTimeUsed), + SQRT(SUM(POW(f.UserCpuTimeUsed, 2)) / COUNT(*) - POW(AVG(f.UserCpuTimeUsed), 2)), + SUM(f.SystemCpuTimeUsed), + SQRT(SUM(POW(f.SystemCpuTimeUsed, 2)) / COUNT(*) - POW(AVG(f.SystemCpuTimeUsed), 2)), + SUM(f.NetworkBytesSent), + SQRT(SUM(POW(f.NetworkBytesSent, 2)) / COUNT(*) - POW(AVG(f.NetworkBytesSent), 2)), + """ + + query += ", ".join([ + _BinsToQuery(models_hunts.CPU_STATS_BINS, "f.UserCpuTimeUsed"), + _BinsToQuery(models_hunts.CPU_STATS_BINS, "f.SystemCpuTimeUsed"), + _BinsToQuery(models_hunts.NETWORK_STATS_BINS, "f.NetworkBytesSent"), + ]) + + query += """ + FROM Flows@{{FORCE_INDEX=FlowsByParentHuntIdFlowIdState}} AS f + WHERE f.ParentHuntID = {hunt_id} AND f.FlowId = {hunt_id} + """ + + response = self.db.ParamQuerySingle( + query, params, txn_tag="ReadHuntClientResourcesStats_1" + ) + + ( + count, + user_sum, + user_stddev, + system_sum, + system_stddev, + network_sum, + network_stddev, + ) = response[:7] + + stats = jobs_pb2.ClientResourcesStats( + user_cpu_stats=jobs_pb2.RunningStats( + num=count, + sum=user_sum, + stddev=user_stddev, + ), + system_cpu_stats=jobs_pb2.RunningStats( + num=count, + sum=system_sum, + stddev=system_stddev, + ), + network_bytes_sent_stats=jobs_pb2.RunningStats( + num=count, + sum=network_sum, + stddev=network_stddev, + ), + ) + + offset = 7 + user_cpu_stats_histogram = jobs_pb2.StatsHistogram() + for b_num, b_max_value in zip( + response[offset:], models_hunts.CPU_STATS_BINS + ): + user_cpu_stats_histogram.bins.append( + jobs_pb2.StatsHistogramBin(range_max_value=b_max_value, num=b_num) + ) + stats.user_cpu_stats.histogram.CopyFrom(user_cpu_stats_histogram) + + offset += len(models_hunts.CPU_STATS_BINS) + system_cpu_stats_histogram = jobs_pb2.StatsHistogram() + for b_num, b_max_value in zip( + response[offset:], models_hunts.CPU_STATS_BINS + ): + system_cpu_stats_histogram.bins.append( + jobs_pb2.StatsHistogramBin(range_max_value=b_max_value, num=b_num) + ) + stats.system_cpu_stats.histogram.CopyFrom(system_cpu_stats_histogram) + + offset += len(models_hunts.CPU_STATS_BINS) + network_bytes_histogram = jobs_pb2.StatsHistogram() + for b_num, b_max_value in zip( + response[offset:], models_hunts.NETWORK_STATS_BINS + ): + network_bytes_histogram.bins.append( + jobs_pb2.StatsHistogramBin(range_max_value=b_max_value, num=b_num) + ) + stats.network_bytes_sent_stats.histogram.CopyFrom(network_bytes_histogram) + + clients_query = """ + SELECT + f.ClientID, + f.FlowID, + f.UserCPUTimeUsed, + f.SystemCPUTimeUsed, + 
f.NetworkBytesSent + FROM Flows@{{FORCE_INDEX=FlowsByParentHuntIdFlowIdState}} AS f + WHERE + f.ParentHuntID = {hunt_id} AND + f.FlowId = {hunt_id} AND + (f.UserCpuTimeUsed > 0 OR + f.SystemCpuTimeUsed > 0 OR + f.NetworkBytesSent > 0) + ORDER BY (f.UserCPUTimeUsed + f.SystemCPUTimeUsed) DESC + LIMIT 10 + """ + + responses = self.db.ParamQuery( + clients_query, params, txn_tag="ReadHuntClientResourcesStats_2" + ) + for cid, fid, ucpu, scpu, nbs in responses: + client_id = cid + flow_id = fid + stats.worst_performers.append( + jobs_pb2.ClientResources( + client_id=str(rdf_client.ClientURN.FromHumanReadable(client_id)), + session_id=str(rdfvalue.RDFURN(client_id).Add(flow_id)), + cpu_usage=jobs_pb2.CpuSeconds( + user_cpu_time=ucpu, + system_cpu_time=scpu, + ), + network_bytes_sent=nbs, + ) + ) + + return stats diff --git a/grr/server/grr_response_server/databases/spanner_hunts_test.py b/grr/server/grr_response_server/databases/spanner_hunts_test.py index 28f412d5c..607721ae3 100644 --- a/grr/server/grr_response_server/databases/spanner_hunts_test.py +++ b/grr/server/grr_response_server/databases/spanner_hunts_test.py @@ -2,7 +2,7 @@ from grr_response_server.databases import db_hunts_test from grr_response_server.databases import db_test_utils -from grr_response_server.databases.local.spanner import spanner_test_lib +from grr_response_server.databases import spanner_test_lib def setUpModule() -> None: diff --git a/grr/server/grr_response_server/databases/spanner_paths.py b/grr/server/grr_response_server/databases/spanner_paths.py index eca20a458..ad1f0166f 100644 --- a/grr/server/grr_response_server/databases/spanner_paths.py +++ b/grr/server/grr_response_server/databases/spanner_paths.py @@ -12,7 +12,6 @@ from grr_response_server.databases import spanner_utils from grr_response_server.models import paths as models_paths - class PathsMixin: """A Spanner database mixin with implementation of path methods.""" @@ -28,7 +27,88 @@ def WritePathInfos( path_infos: Iterable[objects_pb2.PathInfo], ) -> None: """Writes a collection of path records for a client.""" + int_client_id = client_id + + # Special case for empty list of paths because Spanner does not like empty + # mutations. We still have to validate the client id. 
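+ # (For a non-empty path list the client id is validated implicitly: a
+ # missing Clients parent row surfaces as the "Parent row is missing:
+ # Clients" error handled at the end of this method.)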
+ if not path_infos: + try: + self.db.Read(table="Clients", key=[int_client_id], cols=()) + except spanner_errors.RowNotFoundError as error: + raise abstract_db.UnknownClientError(client_id) from error + + return + + def Mutation(mut: spanner_utils.Mutation) -> None: + ancestors = set() + + for path_info in path_infos: + int_path_type = int(path_info.path_type) + path = EncodePathComponents(path_info.components) + + row = { + "ClientId": int_client_id, + "Type": int_path_type, + "Path": path, + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, + "IsDir": path_info.directory, + "Depth": len(path_info.components), + } + + if path_info.HasField("stat_entry"): + row["LastFileStatTime"] = spanner_lib.COMMIT_TIMESTAMP + + file_stat_row = { + "ClientId": int_client_id, + "Type": int_path_type, + "Path": path, + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, + "Stat": path_info.stat_entry, + } + else: + file_stat_row = None + + if path_info.HasField("hash_entry"): + row["LastFileHashTime"] = spanner_lib.CommitTimestamp() + + file_hash_row = { + "ClientId": int_client_id, + "Type": int_path_type, + "Path": path, + "CreationTime": spanner_lib.CommitTimestamp(), + "Hash": path_info.hash_entry, + } + else: + file_hash_row = None + + mut.InsertOrUpdate(table="Paths", row=row) + if file_stat_row is not None: + mut.Insert(table="PathFileStats", row=file_stat_row) + if file_hash_row is not None: + mut.Insert(table="PathFileHashes", row=file_hash_row) + + for path_info_ancestor in models_paths.GetAncestorPathInfos(path_info): + components = tuple(path_info_ancestor.components) + ancestors.add((path_info.path_type, components)) + for path_type, components in ancestors: + row = { + "ClientId": int_client_id, + "Type": int(path_type), + "Path": EncodePathComponents(components), + "CreationTime": spanner_lib.CommitTimestamp(), + "IsDir": True, + "Depth": len(components), + } + mut.InsertOrUpdate(table="Paths", row=row) + + try: + self.db.Mutate(Mutation, txn_tag="WritePathInfos") + except NotFound as error: + if "Parent row is missing: Clients" in str(error): + raise abstract_db.UnknownClientError(client_id) from error + else: + raise @db_utils.CallLogged @db_utils.CallAccounted @@ -39,8 +119,67 @@ def ReadPathInfos( components_list: Collection[Sequence[str]], ) -> dict[tuple[str, ...], Optional[objects_pb2.PathInfo]]: """Retrieves path info records for given paths.""" + # Early return to avoid issues with unnesting empty path list in the query. 
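+ # The query below LEFT JOINs Paths against PathFileStats / PathFileHashes
+ # on LastFileStatTime / LastFileHashTime, i.e. it only picks up the stat
+ # and hash rows most recently recorded for each requested path.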
+ if not components_list: + return {} + + results = {tuple(components): None for components in components_list} + + query = """ + SELECT p.Path, p.CreationTime, p.IsDir, + ps.CreationTime, ps.Stat, + ph.CreationTime, ph.Hash, + FROM Paths AS p + LEFT JOIN PathFileStats AS ps + ON p.ClientId = ps.ClientId + AND p.Type = ps.Type + AND p.Path = ps.Path + AND p.LastFileStatTime = ps.CreationTime + LEFT JOIN PathFileHashes AS ph + ON p.ClientId = ph.ClientId + AND p.Type = ph.Type + AND p.Path = ph.Path + AND p.LastFileHashTime = ph.CreationTime + WHERE p.ClientId = {client_id} + AND p.Type = {type} + AND p.Path IN UNNEST({paths}) + """ + params = { + "client_id": client_id, + "type": int(path_type), + "paths": list(map(EncodePathComponents, components_list)), + } + + for row in self.db.ParamQuery(query, params, txn_tag="ReadPathInfos"): + path, creation_time, is_dir, *row = row + stat_creation_time, stat_bytes, *row = row + hash_creation_time, hash_bytes, *row = row + () = row + + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.Name(int(path_type)), + components=DecodePathComponents(path), + timestamp=rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch(), + directory=is_dir, + ) + + if stat_bytes is not None: + path_info.stat_entry.ParseFromString(stat_bytes) + path_info.last_stat_entry_timestamp = rdfvalue.RDFDatetime.FromDatetime( + stat_creation_time + ).AsMicrosecondsSinceEpoch() + + if hash_bytes is not None: + path_info.hash_entry.ParseFromString(hash_bytes) + path_info.last_hash_entry_timestamp = rdfvalue.RDFDatetime.FromDatetime( + hash_creation_time + ).AsMicrosecondsSinceEpoch() - return {} + results[tuple(path_info.components)] = path_info + + return results @db_utils.CallLogged @db_utils.CallAccounted @@ -52,8 +191,79 @@ def ReadPathInfo( timestamp: Optional[rdfvalue.RDFDatetime] = None, ) -> objects_pb2.PathInfo: """Retrieves a path info record for a given path.""" + query = """ + SELECT p.CreationTime, + p.IsDir, + -- File stat information. + p.LastFileStatTime, + (SELECT s.Stat + FROM PathFileStats AS s + WHERE s.ClientId = {client_id} + AND s.Type = {type} + AND s.Path = {path} + AND s.CreationTime <= {timestamp} + ORDER BY s.CreationTime DESC + LIMIT 1), + -- File hash information. 
+ p.LastFileHashTime, + (SELECT h.Hash + FROM PathFileHashes AS h + WHERE h.ClientId = {client_id} + AND h.Type = {type} + AND h.Path = {path} + AND h.CreationTime <= {timestamp} + ORDER BY h.CreationTime DESC + LIMIT 1) + FROM Paths AS p + WHERE p.ClientId = {client_id} + AND p.Type = {type} + AND p.Path = {path} + """ + params = { + "client_id": client_id, + "type": int(path_type), + "path": EncodePathComponents(components), + } + + if timestamp is not None: + params["timestamp"] = timestamp.AsDatetime() + else: + params["timestamp"] = rdfvalue.RDFDatetime.Now().AsDatetime() + + try: + row = self.db.ParamQuerySingle(query, params, txn_tag="ReadPathInfo") + except iterator.NoYieldsError: + raise abstract_db.UnknownPathError(client_id, path_type, components) # pylint: disable=raise-missing-from + + creation_time, is_dir, *row = row + last_file_stat_time, stat_bytes, *row = row + last_file_hash_time, hash_bytes, *row = row + () = row + + result = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.Name(int(path_type)), + components=components, + directory=is_dir, + timestamp=rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch(), + ) + + if last_file_stat_time is not None: + result.last_stat_entry_timestamp = rdfvalue.RDFDatetime.FromDatetime( + last_file_stat_time + ).AsMicrosecondsSinceEpoch() + if last_file_hash_time is not None: + result.last_hash_entry_timestamp = rdfvalue.RDFDatetime.FromDatetime( + last_file_hash_time + ).AsMicrosecondsSinceEpoch() + + if stat_bytes is not None: + result.stat_entry.ParseFromString(stat_bytes) + if hash_bytes is not None: + result.hash_entry.ParseFromString(hash_bytes) - return None + return result @db_utils.CallLogged @db_utils.CallAccounted @@ -66,8 +276,144 @@ def ListDescendantPathInfos( max_depth: Optional[int] = None, ) -> Sequence[objects_pb2.PathInfo]: """Lists path info records that correspond to descendants of given path.""" + results = [] + + # The query should include not only descendants of the path but the listed + # path path itself as well. We need to do that to ensure the path actually + # exists and raise if it does not (or if it is not a directory). + query = """ + SELECT p.Path, + p.CreationTime, + p.IsDir, + -- File stat information. + p.LastFileStatTime, + (SELECT s.Stat + FROM PathFileStats AS s + WHERE s.ClientId = p.ClientId + AND s.Type = p.Type + AND s.path = p.Path + AND s.CreationTime <= {timestamp} + ORDER BY s.CreationTime DESC + LIMIT 1), + -- File hash information. + p.LastFileHashTime, + (SELECT h.Hash + FROM PathFileHashes AS h + WHERE h.ClientId = p.ClientId + AND h.Type = p.Type + AND h.Path = p.Path + AND h.CreationTime <= {timestamp} + ORDER BY h.CreationTime DESC + LIMIT 1) + FROM Paths AS p + WHERE p.ClientId = {client_id} + AND p.Type = {type} + """ + params = { + "client_id": client_id, + "type": int(path_type), + } + + # We add a constraint on path only if path components are non-empty. Empty + # path components indicate root path and it means that everything should be + # listed anyway. Treating the root path the same as other paths leads to + # issues with trailing slashes. 
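+ # Illustrative example (assuming '/' is the separator, as the
+ # CONCAT({path}, b'/') clause below implies): for components
+ # ("usr", "bin") the encoded path is b"/usr/bin", so the clause matches
+ # b"/usr/bin" itself and any b"/usr/bin/..." descendant, while a sibling
+ # such as b"/usr/bin2" is excluded by the trailing slash.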
+ if components: + query += """ + AND (p.Path = {path} OR STARTS_WITH(p.Path, CONCAT({path}, b'/'))) + """ + params["path"] = EncodePathComponents(components) + + if timestamp is not None: + params["timestamp"] = timestamp.AsDatetime() + else: + params["timestamp"] = rdfvalue.RDFDatetime.Now().AsDatetime() + + if max_depth is not None: + query += " AND p.Depth <= {depth}" + params["depth"] = len(components) + max_depth + + for row in self.db.ParamQuery( + query, params, txn_tag="ListDescendantPathInfos" + ): + path, creation_time, is_dir, *row = row + last_file_stat_time, stat_bytes, *row = row + last_file_hash_time, hash_bytes, *row = row + () = row + + result = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.Name(int(path_type)), + components=DecodePathComponents(path), + directory=is_dir, + timestamp=rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch(), + ) + + if last_file_stat_time is not None: + result.last_stat_entry_timestamp = ( + rdfvalue.RDFDatetime.FromDatetime(last_file_stat_time) + ).AsMicrosecondsSinceEpoch() + if last_file_hash_time is not None: + result.last_hash_entry_timestamp = ( + rdfvalue.RDFDatetime.FromDatetime(last_file_hash_time) + ).AsMicrosecondsSinceEpoch() - return [] + if stat_bytes is not None: + result.stat_entry.ParseFromString(stat_bytes) + if hash_bytes is not None: + result.hash_entry.ParseFromString(hash_bytes) + + results.append(result) + + results.sort(key=lambda result: tuple(result.components)) + + # Special case: we are being asked to list everything under the root path + # (represented by an empty list of components) but we do not have any path + # information available. We assume that the root path always exists (even if + # we did not collect any data yet) so we have to return something but checks + # that follow would cause us to raise instead. + if not components and not results: + return [] + + # The first element of the results should be the requested path itself. If + # it is not, it means that the requested path does not exist and we should + # raise. We also need to verify that the path is a directory. + if not results or tuple(results[0].components) != tuple(components): + raise abstract_db.UnknownPathError(client_id, path_type, components) + if not results[0].directory: + raise abstract_db.NotDirectoryPathError(client_id, path_type, components) + + # Once we verified that the path exists and is a directory, we should not + # include it in the results (since the method is ought to return only real + # descendants). + del results[0] + + # If timestamp is not specified we return collected paths as they are. But + # if the timestamp is given we are only interested in paths that are expli- + # cit (see below for the definition of "explicitness"). + if timestamp is None: + return results + + # A path is considered to be explicit if it has an associated stat or hash + # information or has an ancestor that is explicit. Thus, we traverse results + # in reverse order so that ancestors are checked for explicitness first. + explicit_results = list() + explicit_ancestors = set() + + for result in reversed(results): + components = tuple(result.components) + if ( + result.HasField("stat_entry") + or result.HasField("hash_entry") + or components in explicit_ancestors + ): + explicit_ancestors.add(components[:-1]) + explicit_results.append(result) + + # Since we have been traversing results in the reverse order, explicit re- + # sults are also reversed. Thus we have to reverse them back. 
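+ # Illustrative example of this timestamp-filtered case: if only
+ # /foo/bar/baz carries a stat or hash entry, it is kept together with
+ # every listed ancestor (/foo/bar, /foo, ...), while an untouched sibling
+ # such as /foo/qux is dropped.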
+ return list(reversed(explicit_results)) @db_utils.CallLogged @db_utils.CallAccounted @@ -79,12 +425,29 @@ def ReadPathInfosHistories( cutoff: Optional[rdfvalue.RDFDatetime] = None, ) -> dict[tuple[str, ...], Sequence[objects_pb2.PathInfo]]: """Reads a collection of hash and stat entries for given paths.""" - + # Early return in case of empty components list to avoid awkward issues with + # unnesting an empty array. + if not components_list: + return {} results = {tuple(components): [] for components in components_list} + stat_query = """ + SELECT s.Path, s.CreationTime, s.Stat + FROM PathFileStats AS s + WHERE s.ClientId = {client_id} + AND s.Type = {type} + AND s.Path IN UNNEST({paths}) + """ + hash_query = """ + SELECT h.Path, h.CreationTime, h.Hash + FROM PathFileHashes AS h + WHERE h.ClientId = {client_id} + AND h.Type = {type} + AND h.Path IN UNNEST({paths}) + """ params = { - "client_id": spanner_clients.IntClientID(client_id), + "client_id": client_id, "type": int(path_type), "paths": list(map(EncodePathComponents, components_list)), } @@ -102,6 +465,40 @@ def ReadPathInfosHistories( FROM s FULL JOIN h ON s.Path = h.Path """ + for row in self.db.ParamQuery( + query, params, txn_tag="ReadPathInfosHistories" + ): + stat_path, stat_creation_time, stat_bytes, *row = row + hash_path, hash_creation_time, hash_bytes, *row = row + () = row + + # At least one of the two paths is going to be not null. In case both are + # not null, they are guaranteed to be the same value because of the way + # full join works. + components = DecodePathComponents(stat_path or hash_path) + + result = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.Name(int(path_type)), + components=components, + ) + + # Either stat or hash or both have to be available, so at least one of the + # branches below is going to trigger and thus set the timestamp. Not that + # if both are available, they are guaranteed to have the same timestamp so + # overriding the value does no harm. + if stat_bytes is not None: + result.timestamp = rdfvalue.RDFDatetime.FromDatetime( + stat_creation_time + ).AsMicrosecondsSinceEpoch() + result.stat_entry.ParseFromString(stat_bytes) + if hash_bytes is not None: + result.timestamp = rdfvalue.RDFDatetime.FromDatetime( + hash_creation_time + ).AsMicrosecondsSinceEpoch() + result.hash_entry.ParseFromString(hash_bytes) + + results[components].append(result) + return results @db_utils.CallLogged @@ -110,9 +507,139 @@ def ReadLatestPathInfosWithHashBlobReferences( self, client_paths: Collection[abstract_db.ClientPath], max_timestamp: Optional[rdfvalue.RDFDatetime] = None, - ) -> Dict[abstract_db.ClientPath, Optional[objects_pb2.PathInfo]]: + ) -> dict[abstract_db.ClientPath, Optional[objects_pb2.PathInfo]]: """Returns path info with corresponding hash blob references.""" # Early return in case of empty client paths to avoid issues with syntax er- # rors due to empty clause list. 
+ if not client_paths: + return {} + + params = {} + + key_clauses = [] + for idx, client_path in enumerate(client_paths): + client_id = client_path.client_id + + key_clauses.append(f"""( + h.ClientId = {{client_id_{idx}}} + AND h.Type = {{type_{idx}}} + AND h.Path = {{path_{idx}}} + )""") + params[f"client_id_{idx}"] = client_id + params[f"type_{idx}"] = int(client_path.path_type) + params[f"path_{idx}"] = EncodePathComponents(client_path.components) + + if max_timestamp is not None: + params["cutoff"] = max_timestamp.AsDatetime() + cutoff_clause = " h.CreationTime <= {cutoff}" + else: + cutoff_clause = " TRUE" + + query = f""" + WITH l AS (SELECT h.ClientId, h.Type, h.Path, + MAX(h.CreationTime) AS LastCreationTime + FROM PathFileHashes AS h + INNER JOIN HashBlobReferences AS b + ON h.Hash.sha256 = b.HashId + WHERE ({" OR ".join(key_clauses)}) + AND {cutoff_clause} + GROUP BY h.ClientId, h.Type, h.Path) + SELECT l.ClientId, l.Type, l.Path, l.LastCreationTime, + s.Stat, h.Hash + FROM l + LEFT JOIN @{{{{JOIN_METHOD=APPLY_JOIN}}}} PathFileStats AS s + ON s.ClientId = l.ClientId + AND s.Type = l.Type + AND s.Path = l.Path + AND s.CreationTime = l.LastCreationTime + LEFT JOIN @{{{{JOIN_METHOD=APPLY_JOIN}}}} PathFileHashes AS h + ON h.ClientId = l.ClientId + AND h.Type = l.Type + AND h.Path = l.Path + AND h.CreationTime = l.LastCreationTime + """ + + results = {client_path: None for client_path in client_paths} + + for row in self.db.ParamQuery( + query, params, txn_tag="ReadLatestPathInfosWithHashBlobReferences" + ): + int_client_id, int_type, path, creation_time, *row = row + stat_bytes, hash_bytes = row + + components = DecodePathComponents(path) + + result = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.Name(int_type), + components=components, + timestamp=rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch(), + ) + + if stat_bytes is not None: + result.stat_entry.ParseFromString(stat_bytes) + + # Hash is guaranteed to be non-null (because of the query construction). + result.hash_entry.ParseFromString(hash_bytes) + + client_path = abstract_db.ClientPath( + client_id=db_utils.IntToClientID(int_client_id), + path_type=int_type, + components=components, + ) + + results[client_path] = result + + return results + + +def EncodePathComponents(components: Sequence[str]) -> bytes: + """Converts path components into canonical database representation of a path. + + Individual components are required to be non-empty and not contain any slash + characters ('/'). + + Args: + components: A sequence of path components. + + Returns: + A canonical database representation of the given path. + """ + for component in components: + if not component: + raise ValueError("Empty path component") + if _PATH_SEP in component: + raise ValueError(f"Path component with a '{_PATH_SEP}' character") + + return f"{_PATH_SEP}{_PATH_SEP.join(components)}".encode("utf-8") + + +def DecodePathComponents(path: bytes) -> tuple[str, ...]: + """Converts a path in canonical database representation into path components. + + Args: + path: A path in its canonical database representation. + + Returns: + A sequence of path components. + """ + # TODO(hanuszczak): Add support for WTF-8 decoding of paths. + path = path.decode("utf-8") + + if not path.startswith(_PATH_SEP): + raise ValueError(f"Non-absolute path {path!r}") + + if path == _PATH_SEP: + # A special case for root path, since `str.split` for it gives use two empty + # strings. 
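# For reference: "/".split("/") yields ['', ''], while "/foo/bar".split("/")
# yields ['', 'foo', 'bar'], which is why the root path needs this special case
# before the generic `[1:]` slice in the branch below.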
+ return () + else: + return tuple(path.split(_PATH_SEP)[1:]) + - return {} +# We use a forward slash as separator as this is the separator accepted by all +# supported platforms (including Windows) and is disallowed to appear in path +# components. Another viable choice would be a null byte character but that is +# very inconvenient to look at when browsing the database. +_PATH_SEP = "/" diff --git a/grr/server/grr_response_server/databases/spanner_paths_test.py b/grr/server/grr_response_server/databases/spanner_paths_test.py index 0b108c698..e662defed 100644 --- a/grr/server/grr_response_server/databases/spanner_paths_test.py +++ b/grr/server/grr_response_server/databases/spanner_paths_test.py @@ -3,8 +3,8 @@ from absl.testing import absltest from grr_response_server.databases import db_paths_test -from grr_response_server.databases.local.spanner import paths as spanner_paths -from grr_response_server.databases.local.spanner import spanner_test_lib +from grr_response_server.databases import spanner_paths +from grr_response_server.databases import spanner_test_lib def setUpModule() -> None: diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 5d881dac9..d4d6b0a08 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -109,17 +109,13 @@ def __init__(self, pyspanner: spanner_lib.database, project_id: str, def Now(self) -> rdfvalue.RDFDatetime: """Retrieves current time as reported by the database.""" - try: - with self._pyspanner.snapshot() as snapshot: - query = "SELECT CURRENT_TIMESTAMP()" - results = snapshot.execute_sql(query) - # Get the first (and only) row - # and the first (and only) column from that row. - timestamp = next(results)[0] - return rdfvalue.RDFDatetime.FromDatetime(timestamp) - except Exception as e: - print(f"Error executing query: {e}") - return None + with self._pyspanner.snapshot() as snapshot: + timestamp = None + query = "SELECT CURRENT_TIMESTAMP() AS now" + results = snapshot.execute_sql(query) + for row in results: + timestamp = row[0] + return rdfvalue.RDFDatetime.FromDatetime(timestamp) def MinTimestamp(self) -> rdfvalue.RDFDatetime: """Returns minimal timestamp allowed by the DB.""" @@ -187,7 +183,6 @@ def Transact( self, func: Callable[["Transaction"], _T], txn_tag: Optional[str] = None, - log_commit_stats: Optional[bool] = False ) -> List[Any]: """Execute the given callback function in a Spanner transaction. 
From 884ada8d9ed1ebda1d92edc947d5737eafd45d46 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Fri, 13 Jun 2025 14:47:47 +0000 Subject: [PATCH 023/168] Adds Paths table --- .../databases/spanner_flows.py | 17 +-- .../databases/spanner_paths.py | 112 ++++++++---------- 2 files changed, 59 insertions(+), 70 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index f2cc50329..ecc332b6e 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -895,10 +895,11 @@ def Txn(txn) -> None: for row in txn.read(table="Flows", keyset=spanner_lib.KeySet(keys=keys), columns=columns): client_id = row[0] flow_id = row[1] + next_request_to_process = int(row[2]) candidate_requests = needs_processing.get((client_id, flow_id), []) for r in candidate_requests: - if int(row[2]) == r.request_id or r.start_time: + if next_request_to_process == r.request_id or r.start_time: req = flows_pb2.FlowProcessingRequest( client_id=client_id, flow_id=flow_id ) @@ -1061,7 +1062,7 @@ def _BuildExpectedUpdates( for r_key, num_responses_expected in updates.items(): rows.append([r_key.client_id, r_key.flow_id, - r_key.request_id, + str(r_key.request_id), num_responses_expected, ]) txn.update(table="FlowRequests", columns=columns, values=rows) @@ -1112,7 +1113,7 @@ def Txn(txn) -> tuple[dict[_RequestKey, int], dict[_RequestKey, str]]: needs_expected_update = {} for r in responses: - req_key = _RequestKey(r.client_id, r.flow_id, r.request_id) + req_key = _RequestKey(r.client_id, r.flow_id, int(r.request_id)) # If the response is not a FlowStatus, we have nothing to do: it will be # simply written to the DB. If it's a FlowStatus, we have to update @@ -1145,9 +1146,9 @@ def Txn(txn) -> tuple[dict[_RequestKey, int], dict[_RequestKey, str]]: responses_to_write = {} for r in responses: - req_key = _RequestKey(r.client_id, r.flow_id, r.request_id) + req_key = _RequestKey(r.client_id, r.flow_id, int(r.request_id)) full_key = _ResponseKey( - r.client_id, r.flow_id, r.request_id, r.response_id + r.client_id, r.flow_id, int(r.request_id), int(r.response_id) ) if req_key not in currently_available_requests: @@ -1268,7 +1269,7 @@ def _ReadFlowRequestsNotYetMarkedForProcessing( unique_flow_keys = set() for req_key in set(requests) | set(callback_states): - req_keys.append([req_key.client_id, req_key.flow_id, req_key.request_id]) + req_keys.append([req_key.client_id, req_key.flow_id, str(req_key.request_id)]) unique_flow_keys.add((req_key.client_id, req_key.flow_id)) for client_id, flow_id in unique_flow_keys: @@ -1801,8 +1802,8 @@ def Mutation(mut) -> None: for response_id, log in logs.items(): rows.append([client_id, flow_id, - request_id, - response_id, + str(request_id), + str(response_id), log.level, log.timestamp.ToDatetime(), log.message, diff --git a/grr/server/grr_response_server/databases/spanner_paths.py b/grr/server/grr_response_server/databases/spanner_paths.py index ad1f0166f..70a83c027 100644 --- a/grr/server/grr_response_server/databases/spanner_paths.py +++ b/grr/server/grr_response_server/databases/spanner_paths.py @@ -1,8 +1,11 @@ #!/usr/bin/env python """A module with path methods of the Spanner database implementation.""" - +import base64 from typing import Collection, Dict, Iterable, Optional, Sequence +from google.api_core.exceptions import NotFound +from google.cloud import spanner as spanner_lib + from grr_response_core.lib import rdfvalue from 
grr_response_core.lib.util import iterator from grr_response_proto import objects_pb2 @@ -17,8 +20,6 @@ class PathsMixin: db: spanner_utils.Database - # TODO(b/196379916): Implement path methods. - @db_utils.CallLogged @db_utils.CallAccounted def WritePathInfos( @@ -27,85 +28,73 @@ def WritePathInfos( path_infos: Iterable[objects_pb2.PathInfo], ) -> None: """Writes a collection of path records for a client.""" - int_client_id = client_id - # Special case for empty list of paths because Spanner does not like empty # mutations. We still have to validate the client id. if not path_infos: try: - self.db.Read(table="Clients", key=[int_client_id], cols=()) - except spanner_errors.RowNotFoundError as error: + self.db.Read(table="Clients", key=[client_id], cols=(["ClientId"])) + except NotFound as error: raise abstract_db.UnknownClientError(client_id) from error - return - def Mutation(mut: spanner_utils.Mutation) -> None: + def Mutation(mut) -> None: ancestors = set() - + file_hash_columns = ["ClientId", "Type", "Path", "CreationTime", "FileHash"] + file_stat_columns = ["ClientId", "Type", "Path", "CreationTime", "Stat"] for path_info in path_infos: + path_columns = ["ClientId", "Type", "Path", "CreationTime", "IsDir", "Depth"] int_path_type = int(path_info.path_type) path = EncodePathComponents(path_info.components) - row = { - "ClientId": int_client_id, - "Type": int_path_type, - "Path": path, - "CreationTime": spanner_lib.COMMIT_TIMESTAMP, - "IsDir": path_info.directory, - "Depth": len(path_info.components), - } + path_row = [client_id, int_path_type, path, + spanner_lib.COMMIT_TIMESTAMP, + path_info.directory, + len(path_info.components) + ] if path_info.HasField("stat_entry"): - row["LastFileStatTime"] = spanner_lib.COMMIT_TIMESTAMP - - file_stat_row = { - "ClientId": int_client_id, - "Type": int_path_type, - "Path": path, - "CreationTime": spanner_lib.COMMIT_TIMESTAMP, - "Stat": path_info.stat_entry, - } + path_columns.append("LastFileStatTime") + path_row.append(spanner_lib.COMMIT_TIMESTAMP) + + file_stat_row = [client_id, int_path_type, path, + spanner_lib.COMMIT_TIMESTAMP, + path_info.stat_entry + ] else: file_stat_row = None if path_info.HasField("hash_entry"): - row["LastFileHashTime"] = spanner_lib.CommitTimestamp() - - file_hash_row = { - "ClientId": int_client_id, - "Type": int_path_type, - "Path": path, - "CreationTime": spanner_lib.CommitTimestamp(), - "Hash": path_info.hash_entry, - } + path_columns.append("LastFileHashTime") + path_row.append(spanner_lib.COMMIT_TIMESTAMP) + + file_hash_row = [client_id, int_path_type, path, + spanner_lib.COMMIT_TIMESTAMP, + path_info.hash_entry, + ] else: file_hash_row = None - mut.InsertOrUpdate(table="Paths", row=row) + mut.insert_or_update(table="Paths", columns=path_columns, values=[path_row]) if file_stat_row is not None: - mut.Insert(table="PathFileStats", row=file_stat_row) + mut.insert(table="PathFileStats", columns=file_stat_columns, values=[file_stat_row]) if file_hash_row is not None: - mut.Insert(table="PathFileHashes", row=file_hash_row) + mut.insert(table="PathFileHashes", columns=file_hash_columns, values=[file_hash_row]) for path_info_ancestor in models_paths.GetAncestorPathInfos(path_info): components = tuple(path_info_ancestor.components) ancestors.add((path_info.path_type, components)) + path_columns = ["ClientId", "Type", "Path", "CreationTime", "IsDir", "Depth"] for path_type, components in ancestors: - row = { - "ClientId": int_client_id, - "Type": int(path_type), - "Path": EncodePathComponents(components), - 
"CreationTime": spanner_lib.CommitTimestamp(), - "IsDir": True, - "Depth": len(components), - } - mut.InsertOrUpdate(table="Paths", row=row) + path_row = [client_id, int(path_type), EncodePathComponents(components), + spanner_lib.COMMIT_TIMESTAMP, True, len(components), + ] + mut.insert_or_update(table="Paths", columns=path_columns, values=[path_row]) try: self.db.Mutate(Mutation, txn_tag="WritePathInfos") except NotFound as error: - if "Parent row is missing: Clients" in str(error): + if "Parent row for row [" in str(error): raise abstract_db.UnknownClientError(client_id) from error else: raise @@ -128,7 +117,7 @@ def ReadPathInfos( query = """ SELECT p.Path, p.CreationTime, p.IsDir, ps.CreationTime, ps.Stat, - ph.CreationTime, ph.Hash, + ph.CreationTime, ph.FileHash, FROM Paths AS p LEFT JOIN PathFileStats AS ps ON p.ClientId = ps.ClientId @@ -206,7 +195,7 @@ def ReadPathInfo( LIMIT 1), -- File hash information. p.LastFileHashTime, - (SELECT h.Hash + (SELECT h.FileHash FROM PathFileHashes AS h WHERE h.ClientId = {client_id} AND h.Type = {type} @@ -232,7 +221,7 @@ def ReadPathInfo( try: row = self.db.ParamQuerySingle(query, params, txn_tag="ReadPathInfo") - except iterator.NoYieldsError: + except NotFound: raise abstract_db.UnknownPathError(client_id, path_type, components) # pylint: disable=raise-missing-from creation_time, is_dir, *row = row @@ -297,7 +286,7 @@ def ListDescendantPathInfos( LIMIT 1), -- File hash information. p.LastFileHashTime, - (SELECT h.Hash + (SELECT h.FileHash FROM PathFileHashes AS h WHERE h.ClientId = p.ClientId AND h.Type = p.Type @@ -440,7 +429,7 @@ def ReadPathInfosHistories( AND s.Path IN UNNEST({paths}) """ hash_query = """ - SELECT h.Path, h.CreationTime, h.Hash + SELECT h.Path, h.CreationTime, h.FileHash FROM PathFileHashes AS h WHERE h.ClientId = {client_id} AND h.Type = {type} @@ -461,7 +450,7 @@ def ReadPathInfosHistories( WITH s AS ({stat_query}), h AS ({hash_query}) SELECT s.Path, s.CreationTime, s.Stat, - h.Path, h.CreationTime, h.Hash + h.Path, h.CreationTime, h.FileHash FROM s FULL JOIN h ON s.Path = h.Path """ @@ -540,12 +529,12 @@ def ReadLatestPathInfosWithHashBlobReferences( MAX(h.CreationTime) AS LastCreationTime FROM PathFileHashes AS h INNER JOIN HashBlobReferences AS b - ON h.Hash.sha256 = b.HashId + ON h.FileHash.sha256 = b.HashId WHERE ({" OR ".join(key_clauses)}) AND {cutoff_clause} GROUP BY h.ClientId, h.Type, h.Path) SELECT l.ClientId, l.Type, l.Path, l.LastCreationTime, - s.Stat, h.Hash + s.Stat, h.FileHash FROM l LEFT JOIN @{{{{JOIN_METHOD=APPLY_JOIN}}}} PathFileStats AS s ON s.ClientId = l.ClientId @@ -564,7 +553,7 @@ def ReadLatestPathInfosWithHashBlobReferences( for row in self.db.ParamQuery( query, params, txn_tag="ReadLatestPathInfosWithHashBlobReferences" ): - int_client_id, int_type, path, creation_time, *row = row + client_id, int_type, path, creation_time, *row = row stat_bytes, hash_bytes = row components = DecodePathComponents(path) @@ -584,7 +573,7 @@ def ReadLatestPathInfosWithHashBlobReferences( result.hash_entry.ParseFromString(hash_bytes) client_path = abstract_db.ClientPath( - client_id=db_utils.IntToClientID(int_client_id), + client_id=client_id, path_type=int_type, components=components, ) @@ -612,7 +601,7 @@ def EncodePathComponents(components: Sequence[str]) -> bytes: if _PATH_SEP in component: raise ValueError(f"Path component with a '{_PATH_SEP}' character") - return f"{_PATH_SEP}{_PATH_SEP.join(components)}".encode("utf-8") + return base64.b64encode(f"{_PATH_SEP}{_PATH_SEP.join(components)}".encode("utf-8")) 
def DecodePathComponents(path: bytes) -> tuple[str, ...]: @@ -624,8 +613,7 @@ def DecodePathComponents(path: bytes) -> tuple[str, ...]: Returns: A sequence of path components. """ - # TODO(hanuszczak): Add support for WTF-8 decoding of paths. - path = path.decode("utf-8") + path = base64.b64decode(path).decode("utf-8") if not path.startswith(_PATH_SEP): raise ValueError(f"Non-absolute path {path!r}") From 5e3b14b0462a0e967c689699a3a9b2f101532ef6 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Fri, 13 Jun 2025 15:56:57 +0000 Subject: [PATCH 024/168] EoD 20250613 --- grr/server/grr_response_server/databases/spanner_flows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index ecc332b6e..e9b4826ed 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -1305,7 +1305,7 @@ def _ReadFlowRequestsNotYetMarkedForProcessing( columns=req_cols): client_id: str = row[0] flow_id: str = row[1] - request_id: int = row[2] + request_id: int = int(row[2]) np: bool = row[3] start_time: Optional[rdfvalue.RDFDatetime] = None if row[4] is not None: From f647151d1c09dada2c5772bff1cb17cdfe89db72 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Sat, 14 Jun 2025 18:22:27 +0000 Subject: [PATCH 025/168] EoD 20250614 --- .../grr_response_server/databases/spanner.sdl | 8 +- .../databases/spanner_flows.py | 10 +-- .../databases/spanner_hunts.py | 82 ++++++++++++------- .../databases/spanner_users.py | 14 ---- .../databases/spanner_utils.py | 15 ++-- 5 files changed, 71 insertions(+), 58 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner.sdl b/grr/server/grr_response_server/databases/spanner.sdl index e5cbd04dd..7692c4043 100644 --- a/grr/server/grr_response_server/databases/spanner.sdl +++ b/grr/server/grr_response_server/databases/spanner.sdl @@ -275,7 +275,7 @@ CREATE TABLE ApprovalGrants( CONSTRAINT fk_approval_grant_grantor_username FOREIGN KEY (Grantor) - REFERENCES Users(Username), + REFERENCES Users(Username) ON DELETE CASCADE, ) PRIMARY KEY (Requestor, ApprovalId, Grantor, GrantId), INTERLEAVE IN PARENT ApprovalRequests ON DELETE CASCADE; @@ -439,7 +439,7 @@ CREATE TABLE FlowRRGLogs( CREATE TABLE FlowOutputPluginLogEntries( ClientId STRING(18) NOT NULL, FlowId STRING(16) NOT NULL, - OutputPluginId STRING(16) NOT NULL, + OutputPluginId INT64 NOT NULL, HuntId STRING(16), CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), Type `grr.FlowOutputPluginLogEntry.LogEntryType` NOT NULL, @@ -462,7 +462,7 @@ CREATE TABLE ScheduledFlows( CONSTRAINT fk_creator_users_username FOREIGN KEY (Creator) - REFERENCES Users(Username), + REFERENCES Users(Username) ON DELETE CASCADE, ) PRIMARY KEY (ClientId, Creator, ScheduledFlowId), INTERLEAVE IN PARENT Clients ON DELETE CASCADE; @@ -580,7 +580,7 @@ CREATE INDEX HuntsByCreator ON Hunts(Creator); CREATE TABLE HuntOutputPlugins( HuntId STRING(16) NOT NULL, - OutputPluginId STRING(16) NOT NULL, + OutputPluginId INT64 NOT NULL, Name STRING(256) NOT NULL, Args `google.protobuf.Any`, State `google.protobuf.Any` NOT NULL, diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index e9b4826ed..caa372ba2 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -1286,7 
+1286,7 @@ def _ReadFlowRequestsNotYetMarkedForProcessing( columns=flow_cols): client_id: int = row[0] flow_id: int = row[1] - next_request_id: int = row[2] + next_request_id: int = int(row[2]) next_request_to_process_by_flow[(client_id, flow_id)] = ( next_request_id ) @@ -1533,12 +1533,12 @@ def DeleteFlowRequests( def Mutation(mut) -> None: for request in requests: - key = [ + keyset = spanner_lib.KeySet([[ request.client_id, request.flow_id, - str(request.request_id), - ] - mut.Delete(table="FlowRequests", key=key) + str(request.request_id) + ]]) + mut.delete(table="FlowRequests", keyset=keyset) try: self.db.Mutate(Mutation, txn_tag="DeleteFlowRequests") diff --git a/grr/server/grr_response_server/databases/spanner_hunts.py b/grr/server/grr_response_server/databases/spanner_hunts.py index 3de610334..d8cc8486b 100644 --- a/grr/server/grr_response_server/databases/spanner_hunts.py +++ b/grr/server/grr_response_server/databases/spanner_hunts.py @@ -1,11 +1,13 @@ #!/usr/bin/env python """A module with hunt methods of the Spanner database implementation.""" +import base64 from typing import AbstractSet, Callable, Collection, Iterable, List, Mapping, Optional, Sequence, Set from google.api_core.exceptions import AlreadyExists, NotFound from google.cloud import spanner as spanner_lib +from google.cloud.spanner_v1 import param_types from google.protobuf import any_pb2 from grr_response_core.lib import rdfvalue @@ -141,31 +143,40 @@ def _UpdateHuntObject( params = { "hunt_id": hunt_id, } + params_types = { + "hunt_id": param_types.STRING + } assignments = ["h.LastUpdateTime = PENDING_COMMIT_TIMESTAMP()"] if duration is not None: assignments.append("h.DurationMicros = @duration_micros") params["duration_micros"] = int(duration.microseconds) + params_types["duration_micros"] = param_types.INT64 if client_rate is not None: assignments.append("h.ClientRate = @client_rate") params["client_rate"] = float(client_rate) + params_types["client_rate"] = param_types.FLOAT64 if client_limit is not None: assignments.append("h.ClientLimit = @client_limit") params["client_limit"] = int(client_limit) + params_types["client_limit"] = param_types.INT64 if hunt_state is not None: assignments.append("h.State = @hunt_state") params["hunt_state"] = int(hunt_state) + params_types["hunt_state"] = param_types.INT64 if hunt_state_reason is not None: assignments.append("h.StateReason = @hunt_state_reason") params["hunt_state_reason"] = int(hunt_state_reason) + params_types["hunt_state_reason"] = param_types.INT64 if hunt_state_comment is not None: assignments.append("h.StateComment = @hunt_state_comment") params["hunt_state_comment"] = hunt_state_comment + params_types["hunt_state_comment"] = param_types.STRING if start_time is not None: assignments.append( @@ -173,14 +184,16 @@ def _UpdateHuntObject( ) assignments.append("h.LastStartTime = @start_time") params["start_time"] = start_time.AsDatetime() + params_types["start_time"] = param_types.TIMESTAMP if num_clients_at_start_time is not None: assignments.append( "h.ClientCountAtStartTime = @client_count_at_start_time" ) - params["client_count_at_start_time"] = spanner_lib.UInt64( + params["client_count_at_start_time"] = int( num_clients_at_start_time ) + params_types["client_count_at_start_time"] = param_types.INT64 query = f""" UPDATE Hunts AS h @@ -188,7 +201,7 @@ def _UpdateHuntObject( WHERE h.HuntId = @hunt_id """ - txn.execute_sql(sql=query, params=params) + txn.execute_update(query, params=params, param_types=params_types) @db_utils.CallLogged 
@db_utils.CallAccounted @@ -212,7 +225,7 @@ def Txn(txn) -> None: # Make sure the hunt is there. try: keyset = spanner_lib.KeySet(keys=[[hunt_id]]) - txn.read(table="Hunts", keyset=keyset, columns=["HuntId",]) + txn.read(table="Hunts", keyset=keyset, columns=["HuntId",]).one() except NotFound as e: raise abstract_db.UnknownHuntError(hunt_id) from e @@ -236,9 +249,8 @@ def Txn(txn) -> None: @db_utils.CallAccounted def DeleteHuntObject(self, hunt_id: str) -> None: """Deletes a hunt object with a given id.""" - keyset = spanner_lib.KeySet(keys=[[hunt_id]]) self.db.Delete( - table="Hunts", keyset=keyset, txn_tag="DeleteHuntObject" + table="Hunts", key=[hunt_id], txn_tag="DeleteHuntObject" ) @db_utils.CallLogged @@ -309,7 +321,9 @@ def ReadHuntObjects( with_description_match: Optional[str] = None, created_by: Optional[Set[str]] = None, not_created_by: Optional[Set[str]] = None, - #with_states: Optional[Collection[hunts_pb2.Hunt.HuntState]] = None, # TODO + with_states: Optional[ + Iterable[hunts_pb2.Hunt.HuntState.ValueType] + ] = None ) -> list[hunts_pb2.Hunt]: """Reads hunt objects from the database.""" @@ -318,6 +332,7 @@ def ReadHuntObjects( "limit": count, "offset": offset, } + param_type = {} if with_creator is not None: conditions.append("h.Creator = {creator}") @@ -326,10 +341,12 @@ def ReadHuntObjects( if created_by is not None: conditions.append("h.Creator IN UNNEST({created_by})") params["created_by"] = list(created_by) + param_type["created_by"] = param_types.Array(param_types.STRING) if not_created_by is not None: conditions.append("h.Creator NOT IN UNNEST({not_created_by})") params["not_created_by"] = list(not_created_by) + param_type["not_created_by"] = param_types.Array(param_types.STRING) if created_after is not None: conditions.append("h.CreationTime > {creation_time}") @@ -377,7 +394,7 @@ def ReadHuntObjects( """ result = [] - for row in self.db.ParamQuery(query, params, txn_tag="ReadHuntObjects"): + for row in self.db.ParamQuery(query, params, param_type=param_type, txn_tag="ReadHuntObjects"): hunt_obj = hunts_pb2.Hunt() hunt_obj.ParseFromString(row[14]) @@ -433,6 +450,7 @@ def ListHuntObjects( "limit": count, "offset": offset, } + param_type = {} if with_creator is not None: conditions.append("h.Creator = {creator}") @@ -441,10 +459,12 @@ def ListHuntObjects( if created_by is not None: conditions.append("h.Creator IN UNNEST({created_by})") params["created_by"] = list(created_by) + param_type["created_by"] = param_types.Array(param_types.STRING) if not_created_by is not None: conditions.append("h.Creator NOT IN UNNEST({not_created_by})") params["not_created_by"] = list(not_created_by) + param_type["not_created_by"] = param_types.Array(param_types.STRING) if created_after is not None: conditions.append("h.CreationTime > {creation_time}") @@ -489,7 +509,8 @@ def ListHuntObjects( """ result = [] - for row in self.db.ParamQuery(query, params, txn_tag="ListHuntObjects"): + for row in self.db.ParamQuery(query, params, + param_type=param_type, txn_tag="ListHuntObjects"): hunt_mdata = hunts_pb2.HuntMetadata() hunt_mdata.hunt_id = row[0] @@ -800,7 +821,7 @@ def CountHuntOutputPluginLogEntries( """ params = { "hunt_id": hunt_id, - "output_plugin_id": output_plugin_id, + "output_plugin_id": int(output_plugin_id), } if with_type is not None: @@ -858,17 +879,20 @@ def Mutation(mut) -> None: for index, state in enumerate(states): state_any = any_pb2.Any() state_any.Pack(state.plugin_state) - row = { - "HuntId": hunt_id, - "OutputPluginId": index, - "Name": 
state.plugin_descriptor.plugin_name, - "State": state_any.SerializeToString(), - } + columns = ["HuntId", "OutputPluginId", + "Name", + "State" + ] + row = [hunt_id, index, + state.plugin_descriptor.plugin_name, + base64.b64encode(state_any.SerializeToString()), + ] if state.plugin_descriptor.HasField("args"): - row["Args"] = state.plugin_descriptor.args.SerializeToString() + columns.append("Args") + row.append(base64.b64encode(state.plugin_descriptor.args.SerializeToString())) - mut.InsertOrUpdate(table="HuntOutputPlugins", row=row) + mut.insert_or_update(table="HuntOutputPlugins", columns=columns, values=[row]) try: self.db.Mutate(Mutation, txn_tag="WriteHuntOutputPluginsStates") @@ -888,22 +912,20 @@ def UpdateHuntOutputPluginState( def Txn(txn) -> None: row = txn.read( table="HuntOutputPlugins", - keyset=spanner_lib.KeySet(keys=[hunt_id, state_index]), + keyset=spanner_lib.KeySet(keys=[[hunt_id, state_index]]), columns=["Name", "Args", "State"], - ) + ).one() state = _HuntOutputPluginStateFromRow( - row["Name"], row["Args"], row["State"] + row[0], row[1], row[2] ) modified_plugin_state = update_fn(state.plugin_state) modified_plugin_state_any = any_pb2.Any() modified_plugin_state_any.Pack(modified_plugin_state) - row_update = { - "HuntId": int_hunt_id, - "OutputPluginId": state_index, - "State": modified_plugin_state_any.SerializeToString(), - } - txn.Update("HuntOutputPlugins", row_update) + columns = ["HuntId", "OutputPluginId", "State"] + row = [hunt_id,state_index, + base64.b64encode(modified_plugin_state_any.SerializeToString())] + txn.update("HuntOutputPlugins", columns=columns, values=[row]) try: self.db.Transact(Txn, txn_tag="UpdateHuntOutputPluginState") @@ -1131,6 +1153,9 @@ def ReadHuntsCounters( params = { "hunt_ids": hunt_ids, } + param_type = { + "hunt_ids": param_types.Array(param_types.STRING) + } states_query = """ SELECT f.ParentHuntID, f.State, COUNT(*) @@ -1142,7 +1167,7 @@ def ReadHuntsCounters( """ counts_by_state_per_hunt = dict.fromkeys(hunt_ids, {}) for hunt_id, state, count in self.db.ParamQuery( - states_query, params, txn_tag="ReadHuntCounters_1" + states_query, params, param_type=param_type, txn_tag="ReadHuntCounters_1" ): counts_by_state_per_hunt[hunt_id][state] = count @@ -1181,7 +1206,8 @@ def ReadHuntsCounters( num_results, num_clients_with_results, ) in self.db.ParamQuery( - resources_results_query, params, txn_tag="ReadHuntCounters_2" + resources_results_query, params, + param_type=param_type, txn_tag="ReadHuntCounters_2" ): counts_by_state = counts_by_state_per_hunt[hunt_id] num_successful_clients = counts_by_state.get( diff --git a/grr/server/grr_response_server/databases/spanner_users.py b/grr/server/grr_response_server/databases/spanner_users.py index a786ae038..3d47788a6 100644 --- a/grr/server/grr_response_server/databases/spanner_users.py +++ b/grr/server/grr_response_server/databases/spanner_users.py @@ -71,20 +71,6 @@ def Transaction(txn) -> None: except NotFound: raise abstract_db.UnknownGRRUserError(username) - query = f""" - DELETE - FROM ApprovalGrants@{{FORCE_INDEX=ApprovalGrantsByGrantor}} AS g - WHERE g.Grantor = '{username}' - """ - txn.execute_sql(query) - - query = f""" - DELETE - FROM ScheduledFlows@{{FORCE_INDEX=ScheduledFlowsByCreator}} AS f - WHERE f.Creator = '{username}' - """ - txn.execute_sql(query) - username_range = spanner_lib.KeyRange(start_closed=[username], end_closed=[username]) txn.delete(table="ApprovalRequests", keyset=spanner_lib.KeySet(ranges=[username_range])) txn.delete(table="Users", keyset=keyset) diff --git 
a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index d4d6b0a08..f1f5f7ea7 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -240,7 +240,8 @@ def QuerySingle(self, query: str, txn_tag: Optional[str] = None) -> Row: return self.Query(query, txn_tag=txn_tag).one() def ParamQuery( - self, query: str, params: Mapping[str, Any], txn_tag: Optional[str] = None + self, query: str, params: Mapping[str, Any], + param_type: Optional[dict] = {}, txn_tag: Optional[str] = None ) -> Cursor: """Queries PySpanner database using the given query string with params. @@ -270,13 +271,13 @@ def ParamQuery( names, values = collection.Unzip(params.items()) query = self._parametrize(query, names) - param_type = {} for key, value in params.items(): - try: - param_type[key] = self._get_param_type(value) - except TypeError as e: - print(f"Warning for key '{key}': {e}. Setting type to None.") - param_type[key] = None # Or re-raise, or handle differently + if key not in param_type: + try: + param_type[key] = self._get_param_type(value) + except TypeError as e: + print(f"Warning for key '{key}': {e}. Setting type to None.") + param_type[key] = None # Or re-raise, or handle differently print("query: {}".format(query)) print("params: {}".format(params)) From e1b33aeb726c3e643538587a92e62be8ec8d5d5a Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Sun, 15 Jun 2025 11:07:08 +0000 Subject: [PATCH 026/168] Adds Hunts table --- .../grr_response_server/databases/spanner.sdl | 2 +- .../databases/spanner_hunts.py | 31 ++++++++++++++----- .../databases/spanner_utils.py | 5 +-- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner.sdl b/grr/server/grr_response_server/databases/spanner.sdl index 7692c4043..06f2f3cb9 100644 --- a/grr/server/grr_response_server/databases/spanner.sdl +++ b/grr/server/grr_response_server/databases/spanner.sdl @@ -439,7 +439,7 @@ CREATE TABLE FlowRRGLogs( CREATE TABLE FlowOutputPluginLogEntries( ClientId STRING(18) NOT NULL, FlowId STRING(16) NOT NULL, - OutputPluginId INT64 NOT NULL, + OutputPluginId STRING(16) NOT NULL, HuntId STRING(16), CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), Type `grr.FlowOutputPluginLogEntry.LogEntryType` NOT NULL, diff --git a/grr/server/grr_response_server/databases/spanner_hunts.py b/grr/server/grr_response_server/databases/spanner_hunts.py index d8cc8486b..fb93c9493 100644 --- a/grr/server/grr_response_server/databases/spanner_hunts.py +++ b/grr/server/grr_response_server/databases/spanner_hunts.py @@ -39,7 +39,7 @@ def _HuntOutputPluginStateFromRow( plugin_state_any = any_pb2.Any() plugin_state_any.ParseFromString(plugin_state) - # TODO(b/120464908): currently AttributedDict is used to store output + # currently AttributedDict is used to store output # plugins' states. This is suboptimal and unsafe and should be refactored # towards using per-plugin protos. 
attributed_dict = jobs_pb2.AttributedDict() @@ -564,6 +564,7 @@ def ReadHuntResults( "offset": offset, "count": count, } + param_type = {} query = """ SELECT t.Payload, t.CreationTime, t.Tag, t.ClientId, t.FlowId @@ -575,10 +576,12 @@ def ReadHuntResults( if with_tag is not None: query += " AND t.Tag = {tag} " params["tag"] = with_tag + param_type["tag"] = param_types.STRING if with_type is not None: query += " AND t.RdfType = {type}" params["type"] = with_type + param_type["type"] = param_types.STRING if with_substring is not None: query += ( @@ -586,10 +589,12 @@ def ReadHuntResults( "{substring}) != 0" ) params["substring"] = with_substring + param_type["substring"] = param_types.STRING if with_timestamp is not None: query += " AND t.CreationTime = {creation_time}" params["creation_time"] = with_timestamp.AsDatetime() + param_type["creation_time"] = param_types.TIMESTAMP query += """ ORDER BY t.CreationTime ASC LIMIT {count} OFFSET {offset} @@ -602,7 +607,8 @@ def ReadHuntResults( tag, client_id, flow_id, - ) in self.db.ParamQuery(query, params, txn_tag="ReadHuntResults"): + ) in self.db.ParamQuery(query, params, param_type=param_type, + txn_tag="ReadHuntResults"): result = flows_pb2.FlowResult() result.hunt_id = hunt_id result.client_id = client_id @@ -629,6 +635,7 @@ def CountHuntResults( params = { "hunt_id": hunt_id, } + param_type = {} query = """ SELECT COUNT(*) @@ -639,13 +646,17 @@ def CountHuntResults( if with_tag is not None: query += " AND t.Tag = {tag} " params["tag"] = with_tag + param_type["tag"] = param_types.STRING if with_type is not None: query += " AND t.RdfType = {type}" params["type"] = with_type + param_type["type"] = param_types.STRING (count,) = self.db.ParamQuerySingle( - query, params, txn_tag="CountHuntResults" + query, params, + param_type=param_type, + txn_tag="CountHuntResults" ) return count @@ -762,19 +773,22 @@ def ReadHuntOutputPluginLogEntries( SELECT l.ClientId, l.FlowId, l.CreationTime, - l.Type, l.Message - FROM FlowOutputPluginLogEntries AS l - WHERE l.HuntId = {hunt_id} + l.Type, + l.Message + FROM FlowOutputPluginLogEntries AS l + WHERE l.HuntId = {hunt_id} AND l.OutputPluginId = {output_plugin_id} """ params = { "hunt_id": hunt_id, "output_plugin_id": output_plugin_id, } + param_type = {} if with_type is not None: query += " AND l.Type = {type}" params["type"] = int(with_type) + param_type["type"] = param_types.INT64 query += """ LIMIT {count} @@ -785,7 +799,8 @@ def ReadHuntOutputPluginLogEntries( results = [] for row in self.db.ParamQuery( - query, params, txn_tag="ReadHuntOutputPluginLogEntries" + query, params, param_type=param_type, + txn_tag="ReadHuntOutputPluginLogEntries" ): client_id, flow_id, creation_time, int_type, message = row @@ -821,7 +836,7 @@ def CountHuntOutputPluginLogEntries( """ params = { "hunt_id": hunt_id, - "output_plugin_id": int(output_plugin_id), + "output_plugin_id": output_plugin_id, } if with_type is not None: diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index f1f5f7ea7..048a3259b 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -293,7 +293,8 @@ def ParamQuery( return results def ParamQuerySingle( - self, query: str, params: Mapping[str, Any], txn_tag: Optional[str] = None + self, query: str, params: Mapping[str, Any], + param_type: Optional[dict] = {}, txn_tag: Optional[str] = None ) -> Row: """Queries the database for a single row using 
with a query with params. @@ -314,7 +315,7 @@ def ParamQuerySingle( ValueError: If the query contains disallowed sequences. KeyError: If some parameter is not specified. """ - return self.ParamQuery(query, params, txn_tag=txn_tag).one() + return self.ParamQuery(query, params, param_type=param_type, txn_tag=txn_tag).one() def ParamExecute( self, query: str, params: Mapping[str, Any], txn_tag: Optional[str] = None From 103dc3584816f8699fca17c19e8d6c740fdb7b89 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Sun, 15 Jun 2025 17:25:00 +0000 Subject: [PATCH 027/168] EoD 20250615 --- .../grr_response_server/databases/spanner.sdl | 25 ------------------- .../databases/spanner_flows.py | 18 ++++++++++--- .../databases/spanner_utils.py | 7 ++---- 3 files changed, 17 insertions(+), 33 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner.sdl b/grr/server/grr_response_server/databases/spanner.sdl index 06f2f3cb9..7b2c6a6da 100644 --- a/grr/server/grr_response_server/databases/spanner.sdl +++ b/grr/server/grr_response_server/databases/spanner.sdl @@ -16,7 +16,6 @@ CREATE PROTO BUNDLE ( `grr.UserNotification`, `grr.UserNotification.State`, `grr.UserNotification.Type`, - -- RDF magic types. `grr.AttributedDict`, `grr.BlobArray`, `grr.Dict`, @@ -24,7 +23,6 @@ CREATE PROTO BUNDLE ( `grr.DataBlob.CompressionType`, `grr.EmbeddedRDFValue`, `grr.KeyValue`, - -- Client snapshot types. `grr.AmazonCloudInstance`, `grr.ClientInformation`, `grr.ClientCrash`, @@ -48,7 +46,6 @@ CREATE PROTO BUNDLE ( `grr.WindowsVolume`, `grr.WindowsVolume.WindowsDriveTypeEnum`, `grr.WindowsVolume.WindowsVolumeAttributeEnum`, - -- Notification reference types. `grr.ApprovalRequestReference`, `grr.ClientReference`, `grr.CronJobReference`, @@ -59,7 +56,6 @@ CREATE PROTO BUNDLE ( `grr.ObjectReference`, `grr.ObjectReference.Type`, `grr.VfsFileReference`, - -- Flow types. `grr.CpuSeconds`, `grr.Flow`, `grr.Flow.FlowState`, @@ -82,10 +78,8 @@ CREATE PROTO BUNDLE ( `grr.OutputPluginDescriptor`, `grr.OutputPluginState`, `grr.RequestState`, - -- Audit events types. `grr.APIAuditEntry`, `grr.APIAuditEntry.Code`, - -- File-related types. `grr.AuthenticodeSignedData`, `grr.BlobImageChunkDescriptor`, `grr.BlobImageDescriptor`, @@ -100,7 +94,6 @@ CREATE PROTO BUNDLE ( `grr.StatEntry`, `grr.StatEntry.ExtAttr`, `grr.StatEntry.RegistryType`, - -- Foreman rules types. `grr.ForemanClientRule`, `grr.ForemanClientRule.Type`, `grr.ForemanClientRuleSet`, @@ -116,13 +109,11 @@ CREATE PROTO BUNDLE ( `grr.ForemanRegexClientRule.ForemanStringField`, `grr.ForemanRule`, `grr.ForemanRuleAction`, - -- Artifact types. `grr.Artifact`, `grr.ArtifactSource`, `grr.ArtifactSource.SourceType`, `grr.ArtifactDescriptor`, `grr.ClientActionResult`, - -- Hunt types. `grr.Hunt`, `grr.Hunt.HuntState`, `grr.Hunt.HuntStateReason`, @@ -131,12 +122,10 @@ CREATE PROTO BUNDLE ( `grr.HuntArgumentsStandard`, `grr.HuntArgumentsVariable`, `grr.VariableHuntFlowGroup`, - -- SignedBinary types. `grr.BlobReference`, `grr.BlobReferences`, `grr.SignedBinaryID`, `grr.SignedBinaryID.BinaryType`, - -- CronJobs types. `grr.CronJob`, `grr.CronJobAction`, `grr.CronJobAction.ActionType`, @@ -145,14 +134,11 @@ CREATE PROTO BUNDLE ( `grr.SystemCronAction`, `grr.HuntCronAction`, `grr.HuntRunnerArgs`, - -- Message handlers. `grr.MessageHandlerRequest`, - -- Signed Command types. `grr.Command`, `grr.Command.EnvVar`, `grr.SignedCommand`, `grr.SignedCommand.OS`, - -- RRG types. 
`rrg.Log`, `rrg.Log.Level`, `rrg.fs.Path`, @@ -243,11 +229,6 @@ CREATE TABLE ApprovalRequests( FOREIGN KEY (Requestor) REFERENCES Users(Username), - -- TODO: Foreign keys on ARRAY columns are not supported, can we put an alternative in place? - -- CONSTRAINT fk_approval_request_notified_users_usernames - -- FOREIGN KEY UNNEST(NotifiedUsers) (NotifiedUsers) - -- REFERENCES Users(Username), - CONSTRAINT ck_subject_id_valid CHECK ((IF(SubjectClientId IS NOT NULL, 1, 0) + IF(SubjectHuntId IS NOT NULL, 1, 0) + @@ -503,10 +484,8 @@ CREATE TABLE PathFileHashes( INTERLEAVE IN PARENT Paths ON DELETE CASCADE; CREATE TABLE HashBlobReferences( - -- 32 bytes is enough for SHA256 used for hash ids. HashId BYTES(32) NOT NULL, Offset INT64 NOT NULL, - -- 32 bytes is enough for SHA256 used for blob ids. BlobId BYTES(32) NOT NULL, Size INT64 NOT NULL, @@ -538,7 +517,6 @@ CREATE TABLE YaraSignatureReferences( CONSTRAINT fk_yara_signature_reference_creator_username FOREIGN KEY (Creator) REFERENCES Users(Username) - ENFORCED ) PRIMARY KEY (BlobId); CREATE TABLE Hunts( @@ -655,11 +633,8 @@ ALTER TABLE ApprovalRequests ON DELETE CASCADE; CREATE TABLE BlobEncryptionKeys( - -- A unique identifier of the blob. BlobId BYTES(32) NOT NULL, - -- A timestamp at which the association was created. CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), - -- A name of the key (to retrieve the key from Keystore). KeyName STRING(256) NOT NULL, ) PRIMARY KEY (BlobId, CreationTime); diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index caa372ba2..2a2f2dd43 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -2297,10 +2297,22 @@ def Txn(txn) -> flows_pb2.Flow: return flow + def Txn2(txn) -> flows_pb2.Flow: + try: + row = txn.read( + table="Flows", + keyset=spanner_lib.KeySet(keys=[[client_id, flow_id]]), + columns=_READ_FLOW_OBJECT_COLS, + ).one() + flow = _ParseReadFlowObjectRow(client_id, flow_id, row) + print(flow) + except NotFound as error: + raise db.UnknownFlowError(client_id, flow_id, cause=error) + return flow + leased_flow = self.db.Transact(Txn) - #leased_flow.processing_since = int( - # rdfvalue.RDFDatetime.FromDatetime(commit_stats.commit_timestamp) - #) + flow = self.db.Transact(Txn2) + leased_flow.processing_since = flow.processing_since return leased_flow @db_utils.CallLogged diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 048a3259b..c5c0b28fd 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -452,15 +452,12 @@ def InsertOrUpdate( columns = list(columns) values = list(values) - with self._pyspanner.mutation_groups() as groups: - groups.group().insert_or_update( + with self._pyspanner.batch() as batch: + batch.insert_or_update( table=table, columns=columns, values=[values] ) - for response in groups.batch_write(): - if response.status.code != OK: - raise Exception(response.status.message) def Delete( self, table: str, key: Sequence[Any], txn_tag: Optional[str] = None From e4164229f9c0427d90bd86a075a320daab509806 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Sun, 15 Jun 2025 19:25:47 +0000 Subject: [PATCH 028/168] EoD 20250615 --- .../databases/spanner_flows.py | 36 ++++++++++++------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git 
a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 2a2f2dd43..83874efb5 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -8,6 +8,7 @@ from google.api_core.exceptions import AlreadyExists, NotFound from google.cloud import spanner as spanner_lib +from google.cloud.spanner_v1 import param_types from grr_response_core.lib import rdfvalue from grr_response_core.lib import utils @@ -74,9 +75,10 @@ def _BuildReadFlowResultsErrorsConditions( with_tag: Optional[str] = None, with_type: Optional[str] = None, with_substring: Optional[str] = None, -) -> tuple[str, Mapping[str, Any]]: +) -> tuple[str, Mapping[str, Any], Mapping[str, Any]]: """Builds query string and params for results/errors reading queries.""" params = {} + param_type = {} query = f""" SELECT t.Payload, t.RdfType, t.CreationTime, t.Tag, t.HuntId @@ -93,16 +95,19 @@ def _BuildReadFlowResultsErrorsConditions( if with_tag is not None: query += " AND t.Tag = {tag} " params["tag"] = with_tag + param_type["tag"] = param_types.STRING if with_type is not None: query += " AND t.RdfType = {type}" params["type"] = with_type + param_type["type"] = param_types.STRING if with_substring is not None: query += """ AND STRPOS(SAFE_CONVERT_BYTES_TO_STRING(t.Payload.value), {substring}) != 0 """ params["substring"] = with_substring + param_type["substring"] = param_types.STRING query += """ ORDER BY t.CreationTime ASC LIMIT {count} OFFSET {offset} @@ -110,7 +115,7 @@ def _BuildReadFlowResultsErrorsConditions( params["offset"] = offset params["count"] = count - return query, params + return query, params, param_type def _BuildCountFlowResultsErrorsConditions( @@ -119,9 +124,10 @@ def _BuildCountFlowResultsErrorsConditions( flow_id: str, with_tag: Optional[str] = None, with_type: Optional[str] = None, -) -> tuple[str, Mapping[str, Any]]: +) -> tuple[str, Mapping[str, Any], Mapping[str, Any]]: """Builds query string and params for count flow results/errors queries.""" params = {} + param_type = {} query = f""" SELECT COUNT(*) @@ -138,12 +144,14 @@ def _BuildCountFlowResultsErrorsConditions( if with_tag is not None: query += " AND t.Tag = {tag} " params["tag"] = with_tag + param_type["tag"] = param_types.STRING if with_type is not None: query += " AND t.RdfType = {type}" params["type"] = with_type + param_type["type"] = param_types.STRING - return query, params + return query, params, param_type _READ_FLOW_OBJECT_COLS = ( @@ -539,7 +547,7 @@ def ReadFlowResults( with_substring: Optional[str] = None, ) -> Sequence[flows_pb2.FlowResult]: """Reads flow results of a given flow using given query options.""" - query, params = _BuildReadFlowResultsErrorsConditions( + query, params, param_type = _BuildReadFlowResultsErrorsConditions( "FlowResults", client_id, flow_id, @@ -557,7 +565,8 @@ def ReadFlowResults( creation_time, tag, hunt_id, - ) in self.db.ParamQuery(query, params, txn_tag="ReadFlowResults"): + ) in self.db.ParamQuery(query, params, param_type=param_type, + txn_tag="ReadFlowResults"): result = flows_pb2.FlowResult( client_id=client_id, flow_id=flow_id, @@ -589,7 +598,7 @@ def ReadFlowErrors( with_type: Optional[str] = None, ) -> Sequence[flows_pb2.FlowError]: """Reads flow errors of a given flow using given query options.""" - query, params = _BuildReadFlowResultsErrorsConditions( + query, params, param_type = _BuildReadFlowResultsErrorsConditions( "FlowErrors", client_id, flow_id, @@ -607,7 
+616,8 @@ def ReadFlowErrors( creation_time, tag, hunt_id, - ) in self.db.ParamQuery(query, params, txn_tag="ReadFlowErrors"): + ) in self.db.ParamQuery(query, params, param_type=param_type, + txn_tag="ReadFlowErrors"): error = flows_pb2.FlowError( client_id=client_id, flow_id=flow_id, @@ -616,7 +626,7 @@ def ReadFlowErrors( ).AsMicrosecondsSinceEpoch(), ) - # TODO(b/309429206): for separation of concerns reasons, + # for separation of concerns reasons, # ReadFlowResults/ReadFlowErrors shouldn't do the payload type validation, # they should be completely agnostic to what payloads get written/read # to/from the database. Keeping this logic here temporarily @@ -650,11 +660,11 @@ def CountFlowResults( ) -> int: """Counts flow results of a given flow using given query options.""" - query, params = _BuildCountFlowResultsErrorsConditions( + query, params, param_type = _BuildCountFlowResultsErrorsConditions( "FlowResults", client_id, flow_id, with_tag, with_type ) (count,) = self.db.ParamQuerySingle( - query, params, txn_tag="CountFlowResults" + query, params, param_type=param_type, txn_tag="CountFlowResults" ) return count @@ -669,11 +679,11 @@ def CountFlowErrors( ) -> int: """Counts flow errors of a given flow using given query options.""" - query, params = _BuildCountFlowResultsErrorsConditions( + query, params, param_type = _BuildCountFlowResultsErrorsConditions( "FlowErrors", client_id, flow_id, with_tag, with_type ) (count,) = self.db.ParamQuerySingle( - query, params, txn_tag="CountFlowErrors" + query, params, param_type=param_type, txn_tag="CountFlowErrors" ) return count From 4a8ef9840ee9fd78da98b581e55ace0e8ad501a5 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Sun, 15 Jun 2025 20:01:09 +0000 Subject: [PATCH 029/168] EoD 20250615 --- grr/server/grr_response_server/databases/spanner_flows.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 83874efb5..a750f1e3d 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -1986,12 +1986,14 @@ def CountFlowOutputPluginLogEntries( "flow_id": flow_id, "output_plugin_id": output_plugin_id, } + param_type = {} if with_type is not None: query += " AND l.Type = {type}" params["type"] = int(with_type) + param_type["type"] = param_types.INT64 - (count,) = self.db.ParamQuerySingle(query, params) + (count,) = self.db.ParamQuerySingle(query, params, param_type=param_type) return count @db_utils.CallLogged From 304064fbe92199e0e4bff2c3958420d8369bf004 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Mon, 16 Jun 2025 13:38:01 +0000 Subject: [PATCH 030/168] Updates CronJobs table tests --- .../databases/db_cronjob_test.py | 12 ++++++------ .../databases/spanner_cron_jobs.py | 18 +++++++++--------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/grr/server/grr_response_server/databases/db_cronjob_test.py b/grr/server/grr_response_server/databases/db_cronjob_test.py index 7203a67c6..2ea502736 100644 --- a/grr/server/grr_response_server/databases/db_cronjob_test.py +++ b/grr/server/grr_response_server/databases/db_cronjob_test.py @@ -191,7 +191,7 @@ def testCronJobLeasing(self): job = self._CreateCronJob() self.db.WriteCronJob(job) - current_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10000) + current_time = rdfvalue.RDFDatetime.Now() lease_time = rdfvalue.Duration.From(5, rdfvalue.MINUTES) with 
test_lib.FakeTime(current_time): leased = self.db.LeaseCronJobs(lease_time=lease_time) @@ -232,7 +232,7 @@ def testCronJobLeasingByID(self): for j in jobs: self.db.WriteCronJob(j) - current_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10000) + current_time = rdfvalue.RDFDatetime.Now() lease_time = rdfvalue.Duration.From(5, rdfvalue.MINUTES) with test_lib.FakeTime(current_time): leased = self.db.LeaseCronJobs( @@ -254,7 +254,7 @@ def testCronJobReturning(self): with self.assertRaises(ValueError): self.db.ReturnLeasedCronJobs([job]) - current_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10000) + current_time = rdfvalue.RDFDatetime.Now() with test_lib.FakeTime(current_time): leased = self.db.LeaseCronJobs( cronjob_ids=[leased_job.cron_job_id], @@ -276,14 +276,14 @@ def testCronJobReturningMultiple(self): for j in jobs: self.db.WriteCronJob(j) - current_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10000) + current_time = rdfvalue.RDFDatetime.Now() with test_lib.FakeTime(current_time): leased = self.db.LeaseCronJobs( lease_time=rdfvalue.Duration.From(5, rdfvalue.MINUTES) ) self.assertLen(leased, 3) - current_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10001) + current_time = rdfvalue.RDFDatetime.Now() with test_lib.FakeTime(current_time): unleased_jobs = self.db.LeaseCronJobs( lease_time=rdfvalue.Duration.From(5, rdfvalue.MINUTES) @@ -292,7 +292,7 @@ def testCronJobReturningMultiple(self): self.db.ReturnLeasedCronJobs(leased) - current_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10002) + current_time = rdfvalue.RDFDatetime.Now() with test_lib.FakeTime(current_time): leased = self.db.LeaseCronJobs( lease_time=rdfvalue.Duration.From(5, rdfvalue.MINUTES) diff --git a/grr/server/grr_response_server/databases/spanner_cron_jobs.py b/grr/server/grr_response_server/databases/spanner_cron_jobs.py index c3f833777..c3e92ca4d 100644 --- a/grr/server/grr_response_server/databases/spanner_cron_jobs.py +++ b/grr/server/grr_response_server/databases/spanner_cron_jobs.py @@ -260,7 +260,7 @@ def Transaction(txn) -> Sequence[flows_pb2.CronJob]: "ids_to_update": ids_to_update, } - txn.execute_sql(sql=update_query, params=update_params) + txn.execute_update(update_query, update_params) # --------------------------------------------------------------------- # Query (and return) jobs that were updated @@ -304,9 +304,9 @@ def ReturnLeasedCronJobs(self, jobs: Sequence[flows_pb2.CronJob]) -> None: continue conditions.append( - "(cj.JobId={job_%d} AND " - "cj.LeaseEndTime={ld_%d} AND " - "cj.LeaseOwner={lo_%d})" % (i, i, i) + "(cj.JobId=@job_%d AND " + "cj.LeaseEndTime=@ld_%d AND " + "cj.LeaseOwner=@lo_%d)" % (i, i, i) ) dt_leased_until = ( rdfvalue.RDFDatetime() @@ -343,7 +343,7 @@ def Transaction(txn) -> Sequence[flows_pb2.CronJob]: params_job_ids_to_return["ld_%d" % i] = ld params_job_ids_to_return["lo_%d" % i] = lo - response = txn.ParamQuery( + response = txn.execute_sql( query_job_ids_to_return, params_job_ids_to_return ) @@ -360,19 +360,19 @@ def Transaction(txn) -> Sequence[flows_pb2.CronJob]: update_query = """ UPDATE CronJobs as cj SET cj.LeaseEndTime = NULL, cj.LeaseOwner = NULL - WHERE cj.JobId IN UNNEST({ids_to_return}) + WHERE cj.JobId IN UNNEST(@ids_to_return) """ update_params = { "ids_to_return": ids_to_return, } - txn.ParamExecute(update_query, update_params) + txn.execute_update(update_query, update_params) # --------------------------------------------------------------------- # Query (and return) jobs that were updated # 
--------------------------------------------------------------------- where_returned = """ - WHERE cj.JobId IN UNNEST({updated_ids}) + WHERE cj.JobId IN UNNEST(@updated_ids) """ returned_params = { "updated_ids": ids_to_return, @@ -386,7 +386,7 @@ def Transaction(txn) -> Sequence[flows_pb2.CronJob]: returned_jobs = self.db.Transact( Transaction, txn_tag="ReturnLeasedCronJobs" - ).value + ) if unleased_jobs: raise ValueError("CronJobs to return are not leased: %s" % unleased_jobs) if len(returned_jobs) != len(jobs): From e438006225af9ccad139911d26791ac4aa99c95a Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Mon, 16 Jun 2025 13:42:06 +0000 Subject: [PATCH 031/168] Adds time tests --- .../databases/spanner_time_test.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 grr/server/grr_response_server/databases/spanner_time_test.py diff --git a/grr/server/grr_response_server/databases/spanner_time_test.py b/grr/server/grr_response_server/databases/spanner_time_test.py new file mode 100644 index 000000000..989058cc4 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_time_test.py @@ -0,0 +1,22 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_time_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseTimeTest( + db_time_test.DatabaseTimeTestMixin, spanner_test_lib.TestCase +): + pass # Test methods are defined in the base mixin class. + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file From 1f81835f2a7fe98f38d495aa872f6fb3e6f5b433 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Mon, 16 Jun 2025 17:31:17 +0000 Subject: [PATCH 032/168] Updates Flows table tests --- .../databases/db_flows_test.py | 17 +++++++++-------- .../databases/spanner_utils.py | 15 +++++++++------ 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/grr/server/grr_response_server/databases/db_flows_test.py b/grr/server/grr_response_server/databases/db_flows_test.py index 8993b5e87..f723b3858 100644 --- a/grr/server/grr_response_server/databases/db_flows_test.py +++ b/grr/server/grr_response_server/databases/db_flows_test.py @@ -2270,7 +2270,8 @@ def testWritesAndReadsSingleFlowResultOfSingleType(self): sample_result = flows_pb2.FlowResult(client_id=client_id, flow_id=flow_id) sample_result.payload.Pack(jobs_pb2.ClientSummary(client_id=client_id)) - with test_lib.FakeTime(42): + current_time = rdfvalue.RDFDatetime.Now() + with test_lib.FakeTime(current_time): self.db.WriteFlowResults([sample_result]) results = self.db.ReadFlowResults(client_id, flow_id, 0, 100) @@ -2278,7 +2279,7 @@ def testWritesAndReadsSingleFlowResultOfSingleType(self): self.assertEqual(results[0].payload, sample_result.payload) self.assertEqual( rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch(results[0].timestamp), - rdfvalue.RDFDatetime.FromSecondsSinceEpoch(42), + current_time, ) def testWritesAndReadsRDFStringFlowResult(self): @@ -2335,7 +2336,8 @@ def testReadResultsRestoresAllFlowResultsFields(self): ) sample_result.payload.Pack(jobs_pb2.ClientSummary(client_id=client_id)) - with test_lib.FakeTime(42): + current_time = rdfvalue.RDFDatetime.Now() + with test_lib.FakeTime(current_time): self.db.WriteFlowResults([sample_result]) results = self.db.ReadFlowResults(client_id, flow_id, 0, 100) @@ -2346,7 +2348,7 @@ def 
testReadResultsRestoresAllFlowResultsFields(self): self.assertEqual(results[0].payload, sample_result.payload) self.assertEqual( rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch(results[0].timestamp), - rdfvalue.RDFDatetime.FromSecondsSinceEpoch(42), + current_time, ) def testWritesAndReadsMultipleFlowResultsOfSingleType(self): @@ -2743,7 +2745,8 @@ def testWritesAndReadsSingleFlowErrorOfSingleType(self): sample_error = flows_pb2.FlowError(client_id=client_id, flow_id=flow_id) sample_error.payload.Pack(jobs_pb2.ClientSummary(client_id=client_id)) - with test_lib.FakeTime(42): + current_time = rdfvalue.RDFDatetime.Now() + with test_lib.FakeTime(current_time): self.db.WriteFlowErrors([sample_error]) errors = self.db.ReadFlowErrors(client_id, flow_id, 0, 100) @@ -2751,9 +2754,7 @@ def testWritesAndReadsSingleFlowErrorOfSingleType(self): self.assertEqual(errors[0].payload, sample_error.payload) self.assertEqual( errors[0].timestamp, - rdfvalue.RDFDatetime.FromSecondsSinceEpoch( - 42 - ).AsMicrosecondsSinceEpoch(), + current_time.AsMicrosecondsSinceEpoch(), ) def testWritesAndReadsMultipleFlowErrorsOfSingleType(self): diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index c5c0b28fd..e7590e96b 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -550,11 +550,11 @@ def PublishMessageHandlerRequests(self, requests: [str]) -> None: def PublishFlowProcessingRequests(self, requests: [str]) -> None: self.PublishRequests(requests, self.flow_processing_top_path) - def ReadMessageHandlerRequests(self): - return self.ReadRequests(self.message_handler_sub_path) + def ReadMessageHandlerRequests(self, min_req: Optional[int] = None): + return self.ReadRequests(self.message_handler_sub_path, min_req) - def ReadFlowProcessingRequests(self): - return self.ReadRequests(self.flow_processing_sub_path) + def ReadFlowProcessingRequests(self, min_req: Optional[int] = None): + return self.ReadRequests(self.flow_processing_sub_path, min_req) def AckMessageHandlerRequests(self, ack_ids: [str]) -> None: self.AckRequests(ack_ids, self.message_handler_sub_path) @@ -605,12 +605,13 @@ def DeleteAllRequests(self, sub_path: str) -> None: # Make the request response = client.seek(request=request) - def ReadRequests(self, sub_path: str): + def ReadRequests(self, sub_path: str, min_req: Optional[int] = None): # Make the request start_time = time.time() results = {} - while time.time() - start_time < 2: + want_more = True + while want_more or (time.time() - start_time < 2): time.sleep(0.1) response = self.subscriber.pull( @@ -626,6 +627,8 @@ def ReadRequests(self, sub_path: str): "ack_id": resp.ack_id, "publish_time": resp.message.publish_time} }) + if min_req and len(results) >= min_req: + want_more = False return results.values() From 7112de3c854fedefeba5da6d7c4016fa14b652b9 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Wed, 18 Jun 2025 20:58:18 +0000 Subject: [PATCH 033/168] EoD 20250618 --- .../databases/db_message_handler_test.py | 1 + .../grr_response_server/databases/spanner.py | 13 + .../grr_response_server/databases/spanner.sdl | 27 + .../databases/spanner_flows.py | 490 +++++++++++------- .../databases/spanner_message_handler_test.py | 114 +--- .../databases/spanner_utils_test.py | 48 -- 6 files changed, 359 insertions(+), 334 deletions(-) diff --git a/grr/server/grr_response_server/databases/db_message_handler_test.py 
b/grr/server/grr_response_server/databases/db_message_handler_test.py index 68d321051..944a3d350 100644 --- a/grr/server/grr_response_server/databases/db_message_handler_test.py +++ b/grr/server/grr_response_server/databases/db_message_handler_test.py @@ -93,6 +93,7 @@ def testMessageHandlerRequestLeasing(self): m.ClearField("timestamp") got += l self.db.DeleteMessageHandlerRequests(got) + self.db.UnregisterMessageHandler() got.sort(key=lambda req: req.request_id) self.assertEqual(requests, got) diff --git a/grr/server/grr_response_server/databases/spanner.py b/grr/server/grr_response_server/databases/spanner.py index b25610dd8..1e2c21dfa 100644 --- a/grr/server/grr_response_server/databases/spanner.py +++ b/grr/server/grr_response_server/databases/spanner.py @@ -1,7 +1,9 @@ from google.cloud.spanner import Client +from grr_response_core import config from grr_response_core.lib import rdfvalue +from grr_response_server import threadpool from grr_response_server.databases import db as db_module from grr_response_server.databases import spanner_artifacts from grr_response_server.databases import spanner_blob_keys @@ -45,6 +47,17 @@ def __init__(self, db: spanner_utils.Database) -> None: self.db = db self._write_rows_batch_size = 10000 + self.handler_thread = None + self.handler_stop = True + + self.flow_processing_request_handler_thread = None + self.flow_processing_request_handler_stop = None + self.flow_processing_request_handler_pool = threadpool.ThreadPool.Factory( + "spanner_flow_processing_pool", + min_threads=config.CONFIG["Mysql.flow_processing_threads_min"], + max_threads=config.CONFIG["Mysql.flow_processing_threads_max"], + ) + @classmethod def FromConfig(cls) -> "Database": """Creates a GRR database instance for Spanner path specified in the config. 
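Note on the pattern introduced in this commit: the Pub/Sub-backed queues are replaced with plain Spanner tables (MessageHandlerRequests, FlowProcessingRequests) that worker threads poll and lease inside a single read-write transaction, i.e. a SELECT of unleased rows followed by an UPDATE of LeasedUntil/LeasedBy. The following is a minimal sketch of that lease step only, not the actual GRR code; it assumes a MessageHandlerRequests table with RequestId, LeasedUntil and LeasedBy columns, and the names database, worker_id, lease_seconds and lease_requests are illustrative.

import datetime

from google.cloud.spanner_v1 import param_types


def lease_requests(database, worker_id, lease_seconds=300, limit=10):
  """Selects unleased rows and marks them leased, atomically."""

  def txn_fn(txn):
    now = datetime.datetime.now(datetime.timezone.utc)
    # Find rows whose lease is missing or has expired.
    rows = list(txn.execute_sql(
        "SELECT RequestId FROM MessageHandlerRequests "
        "WHERE LeasedUntil IS NULL OR LeasedUntil < @now "
        "LIMIT @limit",
        params={"now": now, "limit": limit},
        param_types={"now": param_types.TIMESTAMP,
                     "limit": param_types.INT64},
    ))
    request_ids = [row[0] for row in rows]
    if request_ids:
      # Take the lease; readers filter on LeasedUntil, so these rows stay
      # invisible to other workers until the lease expires.
      txn.execute_update(
          "UPDATE MessageHandlerRequests "
          "SET LeasedUntil = @until, LeasedBy = @by "
          "WHERE RequestId IN UNNEST(@ids)",
          params={
              "until": now + datetime.timedelta(seconds=lease_seconds),
              "by": worker_id,
              "ids": request_ids,
          },
          param_types={
              "until": param_types.TIMESTAMP,
              "by": param_types.STRING,
              "ids": param_types.Array(param_types.STRING),
          },
      )
    return request_ids

  # run_in_transaction retries the callback on aborts, which keeps the
  # SELECT and the UPDATE consistent with each other.
  return database.run_in_transaction(txn_fn)

The FlowProcessingRequests leasing added in this commit follows the same shape, with an additional DeliveryTime check in the SELECT.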
diff --git a/grr/server/grr_response_server/databases/spanner.sdl b/grr/server/grr_response_server/databases/spanner.sdl index 7b2c6a6da..a9eaaabd8 100644 --- a/grr/server/grr_response_server/databases/spanner.sdl +++ b/grr/server/grr_response_server/databases/spanner.sdl @@ -644,3 +644,30 @@ CREATE TABLE SignedCommands( Ed25519Signature BYTES(64) NOT NULL, Command BYTES(MAX) NOT NULL, ) PRIMARY KEY (Id, OperatingSystem); + +CREATE TABLE MessageHandlerRequests( + RequestId STRING(16) NOT NULL, + HandlerName STRING(256) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + LeasedUntil TIMESTAMP, + LeasedBy STRING(128), + Payload `grr.MessageHandlerRequest` NOT NULL, +) PRIMARY KEY (RequestId); + +CREATE INDEX MessageHandlerRequestsByLease + ON MessageHandlerRequests(LeasedUntil, LeasedBy); + +CREATE TABLE FlowProcessingRequests( + ClientId STRING(18) NOT NULL, + FlowId STRING(16) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + RequestId STRING(36) NOT NULL, + DeliveryTime TIMESTAMP, + LeasedUntil TIMESTAMP, + LeasedBy STRING(128), + Payload `grr.FlowProcessingRequest` NOT NULL, +) PRIMARY KEY (ClientId, FlowId, CreationTime), + INTERLEAVE IN PARENT Flows ON DELETE CASCADE; + +CREATE INDEX FlowProcessingRequestsIds + ON FlowProcessingRequests(RequestId); diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index a750f1e3d..504e48513 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -4,6 +4,10 @@ import dataclasses import datetime import logging +import threading +import time +import uuid + from typing import Any, Callable, Collection, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union from google.api_core.exceptions import AlreadyExists, NotFound @@ -235,6 +239,8 @@ class FlowsMixin: db: spanner_utils.Database _write_rows_batch_size: int + handler_thread: threading.Thread + @property def _flow_processing_request_receiver( self, @@ -741,19 +747,36 @@ def CountFlowErrorsByType( return result - def _BuildFlowProcessingRequestWrites( + def _WriteFlowProcessingRequests( self, requests: Iterable[flows_pb2.FlowProcessingRequest], + txn ) -> None: - """Writes a list of FlowProcessingRequests to the queue.""" - flowProcessingRequests = [] + """Writes a list of FlowProcessingRequests.""" + + columns = [ + "RequestId", + "ClientId", + "FlowId", + "CreationTime", + "Payload", + "DeliveryTime" + ] + rows = [] for request in requests: - request.creation_time=self.db.Now().AsMicrosecondsSinceEpoch() - flowProcessingRequests.append(request.SerializeToString()) + row = [ + str(uuid.uuid4()), + request.client_id, + request.flow_id, + spanner_lib.COMMIT_TIMESTAMP, + request, + request.delivery_time + ] + rows.append(row) + txn.insert(table="FlowProcessingRequests", columns=columns, values=rows) self.db.PublishFlowProcessingRequests(flowProcessingRequests) - @db_utils.CallLogged @db_utils.CallAccounted def WriteFlowProcessingRequests( @@ -762,7 +785,10 @@ def WriteFlowProcessingRequests( ) -> None: """Writes a list of flow processing requests to the database.""" - self._BuildFlowProcessingRequestWrites(requests) + def Txn(txn) -> None: + self._WriteFlowProcessingRequests(requests, txn) + + self.db.Transact(Txn, txn_tag="WriteFlowProcessingRequests") @db_utils.CallLogged @db_utils.CallAccounted @@ -770,14 +796,16 @@ def ReadFlowProcessingRequests( self, ) -> 
Sequence[flows_pb2.FlowProcessingRequest]: """Reads all flow processing requests from the database.""" + query = """ + SELECT fpr.Payload, fpr.CreationTime FROM FlowProcessingRequests AS fpr + """ results = [] - for result in self.db.ReadFlowProcessingRequests(): + for payload, creation_time in self.db.ParamQuery(query, {}): req = flows_pb2.FlowProcessingRequest() - req.ParseFromString(result["payload"]) - req.creation_time = int( - rdfvalue.RDFDatetime.FromDatetime(result["publish_time"]) - ) - req.ack_id = result["ack_id"] + req.ParseFromString(payload) + req.creation_time = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() results.append(req) return results @@ -787,18 +815,117 @@ def ReadFlowProcessingRequests( def AckFlowProcessingRequests( self, requests: Iterable[flows_pb2.FlowProcessingRequest] ) -> None: - """Acknowledges and deletes flow processing requests.""" - ack_ids = [] - for r in requests: - ack_ids.append(r.ack_id) - - self.db.AckFlowProcessingRequests(ack_ids) + """Deletes a list of flow processing requests from the database.""" + def Txn(txn) -> None: + keys = [] + for request in requests: + keys.append([request.client_id, request.flow_id, request.creation_time]) + keyset = spanner_lib.KeySet(keys=keys) + txn.delete(table="FlowProcessingRequests", keyset=keyset) + + self.db.Transact(Txn) @db_utils.CallLogged @db_utils.CallAccounted def DeleteAllFlowProcessingRequests(self) -> None: """Deletes all flow processing requests from the database.""" - self.db.DeleteAllFlowProcessingRequests() + + def Txn(txn) -> None: + keyset = spanner_lib.KeySet(all_=True) + txn.delete(table="FlowProcessingRequests", keyset=keyset) + + self.db.Transact(Txn) + + @db_utils.CallLogged + @db_utils.CallAccounted + def _LeaseFlowProcessingRequests( + self, limit: int + ) -> Sequence[flows_pb2.FlowProcessingRequest]: + """Leases a number of flow processing requests.""" + now = rdfvalue.RDFDatetime.Now() + expiry = now + rdfvalue.Duration.From(10, rdfvalue.MINUTES) + + def Txn(txn) -> None: + keyset = spanner_lib.KeySet(all_=True) + params = { + "limit": limit, + "now": now.AsDatetime() + } + param_type = { + "limit": param_types.INT64, + "now": param_types.TIMESTAMP + } + requests = txn.execute_sql( + "SELECT RequestId, CreationTime, Payload " + "FROM FlowProcessingRequests " + "WHERE " + " (DeliveryTime IS NULL OR DeliveryTime <= @now) AND " + " (LeasedUntil IS NULL OR LeasedUntil < @now) " + "LIMIT @limit", + params=params, + param_types=param_type) + + res = [] + request_ids = [] + for request_id, creation_time, request in cursor.fetchall(): + req = flows_pb2.FlowProcessingRequest() + req.ParseFromString(request) + req.creation_time = mysql_utils.TimestampToMicrosecondsSinceEpoch( + creation_time + ) + res.append(req) + request_ids.append(request_id) + + query = ( + "UPDATE FlowProcessingRequests " + "SET LeasedUntil=@leased_until, LeasedBy=@leased_by " + "WHERE RequestId IN UNNEST(@request_ids)" + ) + params = { + "request_ids": request_ids, + "leased_by": leased_by, + "leased_until": expiry.AsDatetime() + } + param_type = { + "request_ids": param_types.Array(param_types.STRING), + "leased_by": param_types.STRING, + "leased_until": param_types.TIMESTAMP + } + txn.execute_update(query, params, param_type) + + return res + + return self.db.Transact(Txn) + + _FLOW_REQUEST_POLL_TIME_SECS = 3 + + def _FlowProcessingRequestHandlerLoop( + self, handler: Callable[[flows_pb2.FlowProcessingRequest], None] + ) -> None: + """The main loop for the flow processing request 
queue.""" + self.flow_processing_request_handler_pool.Start() + + while not self.flow_processing_request_handler_stop: + thread_pool = self.flow_processing_request_handler_pool + free_threads = thread_pool.max_threads - thread_pool.busy_threads + if free_threads == 0: + time.sleep(self._FLOW_REQUEST_POLL_TIME_SECS) + continue + try: + msgs = self._LeaseFlowProcessingRequests(free_threads) + if msgs: + for m in msgs: + self.flow_processing_request_handler_pool.AddTask( + target=handler, args=(m,) + ) + else: + time.sleep(self._FLOW_REQUEST_POLL_TIME_SECS) + + except Exception as e: # pylint: disable=broad-except + logging.exception("_FlowProcessingRequestHandlerLoop raised %s.", e) + time.sleep(self._FLOW_REQUEST_POLL_TIME_SECS) + + self.flow_processing_request_handler_pool.Stop() def RegisterFlowProcessingHandler( self, handler: Callable[[flows_pb2.FlowProcessingRequest], None] @@ -806,51 +933,26 @@ def RegisterFlowProcessingHandler( """Registers a handler to receive flow processing messages.""" self.UnregisterFlowProcessingHandler() - def Callback(payload: bytes, msg_id: str, ack_id: str, publish_time): - try: - req = flows_pb2.FlowProcessingRequest() - req.ParseFromString(payload) - date_time_now = rdfvalue.RDFDatetime.Now() - epoch_now = date_time_now.AsMicrosecondsSinceEpoch() - epoch_in_ten = epoch_now + 10 * 1000000 - if req.delivery_time > epoch_now: - ack_ids = [] - ack_ids.append(ack_id) - # figure out when we reach the delivery time, and push it out (max 10 mins allowed by PubSub) - ack_deadline = req.delivery_time if req.delivery_time <= epoch_in_ten else epoch_in_ten - # PubSub wants the deadline in seconds from now - ack_deadline = int((ack_deadline - epoch_now)/1000000) - self.db.LeaseFlowProcessingRequests(ack_ids, ack_deadline) - else: - #req.creation_time = int( - # rdfvalue.RDFDatetime.FromDatetime(publish_time) - #) - req.ack_id = ack_id - handler(req) - except Exception as e: # pylint: disable=broad-except - logging.exception( - "Exception raised during FlowProcessingRequest processing: %s", e - ) - - receiver = self.db.NewRequestQueue( - "FlowProcessing", - Callback, - receiver_max_keepalive_seconds=3000, - receiver_max_active_callbacks=50, - receiver_max_messages_per_callback=1, - ) - self._flow_processing_request_receiver = receiver + if handler: + self.flow_processing_request_handler_stop = False + self.flow_processing_request_handler_thread = threading.Thread( + name="flow_processing_request_handler", + target=self._FlowProcessingRequestHandlerLoop, + args=(handler,), + ) + self.flow_processing_request_handler_thread.daemon = True + self.flow_processing_request_handler_thread.start() def UnregisterFlowProcessingHandler( self, timeout: Optional[rdfvalue.Duration] = None ) -> None: """Unregisters any registered flow processing handler.""" - del timeout # Unused. - if self._flow_processing_request_receiver is not None: - # Pytype doesn't understand that the if-check above ensures that - # _flow_processing_request_receiver is not None. 
- self._flow_processing_request_receiver.Stop() # pytype: disable=attribute-error - self._flow_processing_request_receiver = None + if self.flow_processing_request_handler_thread: + self.flow_processing_request_handler_stop = True + self.flow_processing_request_handler_thread.join(timeout) + if self.flow_processing_request_handler_thread.is_alive(): + raise RuntimeError("Flow processing handler did not join in time.") + self.flow_processing_request_handler_thread = None @db_utils.CallLogged @db_utils.CallAccounted @@ -918,7 +1020,7 @@ def Txn(txn) -> None: flow_processing_requests.append(req) if flow_processing_requests: - self._BuildFlowProcessingRequestWrites(flow_processing_requests) + self._WriteFlowProcessingRequests(flow_processing_requests, txn) try: self.db.Transact(Txn, txn_tag="WriteFlowRequests") @@ -1393,7 +1495,7 @@ def _UpdateNeedsProcessingAndWriteFlowProcessingRequests( fpr.delivery_time = int(start_time) flow_processing_requests.append(fpr) - self._BuildFlowProcessingRequestWrites(flow_processing_requests) + self._WriteFlowProcessingRequests(flow_processing_requests, txn) @db_utils.CallLogged @db_utils.CallAccounted @@ -2091,18 +2193,139 @@ def ListScheduledFlows( return results + def RegisterMessageHandler( + self, + handler: Callable[[Sequence[objects_pb2.MessageHandlerRequest]], None], + lease_time: rdfvalue.Duration, + limit: int = 1000, + ) -> None: + """Leases a number of message handler requests up to the indicated limit.""" + self.UnregisterMessageHandler() + + if handler: + self.handler_stop = False + self.handler_thread = threading.Thread( + name="message_handler", + target=self._MessageHandlerLoop, + args=(handler, lease_time, limit), + ) + self.handler_thread.daemon = True + self.handler_thread.start() + + def UnregisterMessageHandler( + self, timeout: Optional[rdfvalue.Duration] = None + ) -> None: + """Unregisters any registered message handler.""" + if self.handler_thread: + self.handler_stop = True + self.handler_thread.join(timeout) + if self.handler_thread.is_alive(): + raise RuntimeError("Message handler thread did not join in time.") + self.handler_thread = None + + _MESSAGE_HANDLER_POLL_TIME_SECS = 5 + + def _MessageHandlerLoop( + self, + handler: Callable[[Iterable[objects_pb2.MessageHandlerRequest]], None], + lease_time: rdfvalue.Duration, + limit: int = 1000, + ) -> None: + """Loop to handle outstanding requests.""" + while not self.handler_stop: + try: + msgs = self._LeaseMessageHandlerRequests(lease_time, limit) + if msgs: + handler(msgs) + else: + time.sleep(self._MESSAGE_HANDLER_POLL_TIME_SECS) + except Exception as e: # pylint: disable=broad-except + logging.exception("_LeaseMessageHandlerRequests raised %s.", e) + + def _LeaseMessageHandlerRequests( + self, + lease_time: rdfvalue.Duration, + limit: int = 1000, + ) -> Iterable[objects_pb2.MessageHandlerRequest]: + """Leases a number of message handler requests up to the indicated limit.""" + now = rdfvalue.RDFDatetime.Now() + delivery_time = now + lease_time + + leased_until = delivery_time.AsMicrosecondsSinceEpoch() + leased_by = utils.ProcessIdString() + + def Txn(txn) -> None: + # Read the message handler requests waiting for leases + keyset = spanner_lib.KeySet(all_=True) + params = { + "limit": limit, + "now": now.AsDatetime() + } + param_type = { + "limit": param_types.INT64, + "now": param_types.TIMESTAMP + } + requests = txn.execute_sql( + "SELECT RequestId, CreationTime, Payload " + "FROM MessageHandlerRequests " + "WHERE LeasedUntil IS NULL OR LeasedUntil < @now " + "LIMIT @limit", + 
params=params, + param_types=param_type) + res = [] + request_ids = [] + for request_id, creation_time, request in requests: + req = objects_pb2.MessageHandlerRequest() + req.ParseFromString(request) + req.timestamp = req.leased_until = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + req.leased_until = leased_until + req.leased_by = leased_by + res.append(req) + request_ids.append(request_id) + + query = ( + "UPDATE MessageHandlerRequests " + "SET LeasedUntil=@leased_until, LeasedBy=@leased_by " + "WHERE RequestId IN UNNEST(@request_ids)" + ) + params = { + "request_ids": request_ids, + "leased_by": leased_by, + "leased_until": delivery_time.AsDatetime() + } + param_type = { + "request_ids": param_types.Array(param_types.STRING), + "leased_by": param_types.STRING, + "leased_until": param_types.TIMESTAMP + } + txn.execute_update(query, params, param_type) + + return res + + return self.db.Transact(Txn) + @db_utils.CallLogged @db_utils.CallAccounted def WriteMessageHandlerRequests( self, requests: Iterable[objects_pb2.MessageHandlerRequest] ) -> None: """Writes a list of message handler requests to the queue.""" + def Mutation(mut) -> None: + creation_timestamp = spanner_lib.COMMIT_TIMESTAMP + rows = [] + columns = ["RequestId", "HandlerName", "CreationTime", "Payload"] + for request in requests: + rows.append([ + str(request.request_id), + request.handler_name, + creation_timestamp, + request, + ]) + mut.insert(table="MessageHandlerRequests", columns=columns, values=rows) - msgRequests = [] - for request in requests: - msgRequests.append(request.SerializeToString()) - - self.db.PublishMessageHandlerRequests(msgRequests) + self.db.Transact(Mutation) @db_utils.CallLogged @db_utils.CallAccounted @@ -2110,129 +2333,42 @@ def ReadMessageHandlerRequests( self, ) -> Sequence[objects_pb2.MessageHandlerRequest]: """Reads all message handler requests from the queue.""" - + query = """ + SELECT t.Payload, t.CreationTime, t.LeasedBy, t.LeasedUntil FROM MessageHandlerRequests AS t + """ results = [] - for result in self.db.ReadMessageHandlerRequests(): + for payload, creation_time, leased_by, leased_until in self.db.ParamQuery(query, {}): req = objects_pb2.MessageHandlerRequest() - req.ParseFromString(result["payload"]) - req.timestamp = int( - rdfvalue.RDFDatetime.FromDatetime(result["publish_time"]) - ) - req.ack_id = result["ack_id"] + req.ParseFromString(payload) + req.timestamp = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + if leased_by is not None: + req.leased_by = leased_by + if leased_until is not None: + req.leased_until = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() results.append(req) return results - def _BuildDeleteMessageHandlerRequestWrites( - self, - txn, - requests: Iterable[objects_pb2.MessageHandlerRequest], - ) -> None: - """Deletes given requests within a given transaction.""" - ack_ids = [] - for r in requests: - ack_ids.append(r.ack_id) - - self.db.AckMessageHandlerRequests(ack_ids) - @db_utils.CallLogged @db_utils.CallAccounted def DeleteMessageHandlerRequests( - self, requests: Iterable[objects_pb2.MessageHandlerRequest] - ) -> None: - """Deletes a list of message handler requests from the database.""" - ack_ids = [] - for request in requests: - ack_ids.append(request.ack_id) - - self.db.AckMessageHandlerRequests(ack_ids) - - def _LeaseMessageHandlerRequest( - self, - req: objects_pb2.MessageHandlerRequest, - lease_time: rdfvalue.Duration, - ) -> bool: - """Leases the 
given message handler request. - - Leasing of the message amounts to the following: - 1. The message gets deleted from the queue. - 2. It gets rescheduled in the future (at now + lease_time) with - "leased_until" and "leased_by" attributes set. - - Args: - req: MessageHandlerRequest to lease. - lease_time: Lease duration. - - Returns: - Copy of the original request object with "leased_until" and "leased_by" - attributes set. - """ - date_time_now = rdfvalue.RDFDatetime.Now() - epoch_now = date_time_now.AsMicrosecondsSinceEpoch() - delivery_time = date_time_now + lease_time - - leased = False - if not req.leased_by or req.leased_until <= epoch_now: - # If the message has not been leased yet or the lease has expired - # then take and write back the clone back to the queue - # and delete the original message - clone = objects_pb2.MessageHandlerRequest() - clone.CopyFrom(req) - clone.leased_until = delivery_time.AsMicrosecondsSinceEpoch() - clone.leased_by = utils.ProcessIdString() - clone.ack_id = "" - self.WriteMessageHandlerRequests([clone]) - self.DeleteMessageHandlerRequests([req]) - elif req.leased_until > epoch_now: - # if we have leased the message (leased_until set and in the future) - # then we modify ack deadline to match the leased_until time - leased = True - ack_ids = [] - ack_ids.append(req.ack_id) - self.db.LeaseMessageHandlerRequests(ack_ids, int((req.leased_until - epoch_now)/1000000)) - - return leased - - def RegisterMessageHandler( self, - handler: Callable[[Sequence[objects_pb2.MessageHandlerRequest]], None], - lease_time: rdfvalue.Duration, - limit: int = 1000, + requests: Iterable[objects_pb2.MessageHandlerRequest], ) -> None: - """Registers a message handler to receive batches of messages.""" - self.UnregisterMessageHandler() - - def Callback(payload: bytes, msg_id: str, ack_id: str, publish_time): - try: - req = objects_pb2.MessageHandlerRequest() - req.ParseFromString(payload) - req.ack_id = ack_id - leased = self._LeaseMessageHandlerRequest(req, lease_time) - if leased: - logging.info("Leased message handler request: %s", req.request_id) - handler([req]) - except Exception as e: # pylint: disable=broad-except - logging.exception( - "Exception raised during MessageHandlerRequest processing: %s", e - ) + """Deletes a list of message handler requests from the database.""" - receiver = self.db.NewRequestQueue( - "MessageHandler", - Callback, - receiver_max_keepalive_seconds=_MESSAGE_HANDLER_MAX_KEEPALIVE_SECONDS, - receiver_max_active_callbacks=_MESSAGE_HANDLER_MAX_ACTIVE_CALLBACKS, - receiver_max_messages_per_callback=limit, - ) - self._message_handler_receiver = receiver + query = "DELETE FROM MessageHandlerRequests WHERE RequestId IN UNNEST(@request_ids)" + request_ids = [] + for r in requests: + request_ids.append(str(r.request_id)) + params={"request_ids": request_ids} + param_type={"request_ids": param_types.Array(param_types.STRING)} - def UnregisterMessageHandler( - self, timeout: Optional[rdfvalue.Duration] = None - ) -> None: - """Unregisters any registered message handler.""" - del timeout # Unused. 
- if self._message_handler_receiver: - self._message_handler_receiver.Stop() # pytype: disable=attribute-error # always-use-return-annotations - self._message_handler_receiver = None + self.db.ParamExecute(query, params, param_type) def _ReadHuntState( self, txn, hunt_id: str diff --git a/grr/server/grr_response_server/databases/spanner_message_handler_test.py b/grr/server/grr_response_server/databases/spanner_message_handler_test.py index 1702ff178..4b930b57c 100644 --- a/grr/server/grr_response_server/databases/spanner_message_handler_test.py +++ b/grr/server/grr_response_server/databases/spanner_message_handler_test.py @@ -1,13 +1,5 @@ -import queue - from absl.testing import absltest -from grr_response_core.lib import rdfvalue -from grr_response_core.lib import utils -from grr_response_core.lib.rdfvalues import mig_protodict -from grr_response_core.lib.rdfvalues import protodict as rdf_protodict -from grr_response_proto import objects_pb2 - from grr_response_server.databases import db_message_handler_test from grr_response_server.databases import spanner_test_lib @@ -20,108 +12,12 @@ def tearDownModule() -> None: spanner_test_lib.TearDown() -class SpannerDatabaseHandlerTest(spanner_test_lib.TestCase): - def setUp(self): - super().setUp() - - def testMessageHandlerRequests(self): - - ######################## - # Read / Write tests - ######################## - requests = [] - for i in range(5): - emb = mig_protodict.ToProtoEmbeddedRDFValue( - rdf_protodict.EmbeddedRDFValue(rdfvalue.RDFInteger(i)) - ) - requests.append( - objects_pb2.MessageHandlerRequest( - client_id="C.1000000000000000", - handler_name="Testhandler 0", - request_id=i * 100, - request=emb, - ) - ) - - self.db.WriteMessageHandlerRequests(requests) - - read = self.db.ReadMessageHandlerRequests() - self.assertLen(read, 5) - - self.db.DeleteMessageHandlerRequests(read[:2]) - self.db.DeleteMessageHandlerRequests(read[4:5]) - - for r in read: - self.assertTrue(r.timestamp) - r.ClearField("timestamp") - self.assertTrue(r.ack_id) - r.ClearField("ack_id") - - self.assertCountEqual(read, requests) - - read = self.db.ReadMessageHandlerRequests() - self.assertLen(read, 2) - self.db.DeleteMessageHandlerRequests(read) - - for r in read: - self.assertTrue(r.timestamp) - r.ClearField("timestamp") - self.assertTrue(r.ack_id) - r.ClearField("ack_id") - - self.assertCountEqual(requests[2:4], read) - - def testMessageHandlerLeaseManagement(self): - ######################## - # Lease Management tests - ######################## - - requests = [] - for i in range(10): - emb = mig_protodict.ToProtoEmbeddedRDFValue( - rdf_protodict.EmbeddedRDFValue(rdfvalue.RDFInteger(i)) - ) - requests.append( - objects_pb2.MessageHandlerRequest( - client_id="C.1000000000000001", - handler_name="Testhandler 1", - request_id=i * 100, - request=emb, - ) - ) - - lease_time = rdfvalue.Duration.From(5, rdfvalue.MINUTES) - - leased = queue.Queue() - self.db.RegisterMessageHandler(leased.put, lease_time, limit=10) - - self.db.WriteMessageHandlerRequests(requests) - - got = [] - while len(got) < 10: - try: - l = leased.get(True, timeout=6) - except queue.Empty: - self.fail( - "Timed out waiting for messages, expected 10, got %d" % len(got) - ) - self.assertLessEqual(len(l), 10) - for m in l: - self.assertEqual(m.leased_by, utils.ProcessIdString()) - self.assertGreater(m.leased_until, rdfvalue.RDFDatetime.Now()) - self.assertLess(m.timestamp, rdfvalue.RDFDatetime.Now()) - got += l - self.db.DeleteMessageHandlerRequests(got) - - got.sort(key=lambda req: req.request_id) - 
for m in got: - m.ClearField("leased_by") - m.ClearField("leased_until") - m.ClearField("timestamp") - m.ClearField("ack_id") - self.assertEqual(requests, got) +class SpannerDatabaseHandlerTest( + db_message_handler_test.DatabaseTestHandlerMixin, + spanner_test_lib.TestCase +): + pass # Test methods are defined in the base mixin class. - self.db.UnregisterMessageHandler() if __name__ == "__main__": absltest.main() \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_utils_test.py b/grr/server/grr_response_server/databases/spanner_utils_test.py index a1166d874..8975d502e 100644 --- a/grr/server/grr_response_server/databases/spanner_utils_test.py +++ b/grr/server/grr_response_server/databases/spanner_utils_test.py @@ -316,53 +316,5 @@ def Mutation(txn) -> None: with self.assertRaises(RuntimeError): self.raw_db.Mutate(Mutation) - ####################################### - # Queue Tests - ####################################### - def testNewRequestQueueCallbackGetsCalled(self): - callback_func = mock.Mock() - - requestQueue = self.raw_db.NewRequestQueue( - "MessageHandler", - callback_func, - receiver_max_keepalive_seconds=10, - receiver_max_active_callbacks=1, - receiver_max_messages_per_callback=1, - ) - - start_time = time.time() - requests = [] - requests.append("foo".encode("utf-8")) - self.raw_db.PublishMessageHandlerRequests(requests) - - while callback_func.call_count == 0: - time.sleep(0.1) - if time.time() - start_time > 10: - self.fail("Request was not processed in time.") - - callback_func.assert_called_once() - result = callback_func.call_args.kwargs - ack_ids = [] - ack_ids.append(result["ack_id"]) - self.raw_db.AckMessageHandlerRequests(ack_ids) - requestQueue.Stop() - - def testNewRequestCount(self): - - start_time = time.time() - requests = [] - requests.append("foo".encode("utf-8")) - requests.append("bar".encode("utf-8")) - self.raw_db.PublishMessageHandlerRequests(requests) - - results = self.raw_db.ReadMessageHandlerRequests() - ack_ids = [] - - for result in results: - ack_ids.append(result["ack_id"]) - - self.raw_db.AckMessageHandlerRequests(ack_ids) - self.assertLen(results, 2) - if __name__ == "__main__": absltest.main() From 25dd61838e23ca0e0698c183f6cdf4f541ed7739 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Thu, 19 Jun 2025 09:13:28 +0000 Subject: [PATCH 034/168] Updates FlowRequestProcessing and MessageHandler --- .../databases/spanner_flows.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 504e48513..e97a73334 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -770,12 +770,13 @@ def _WriteFlowProcessingRequests( request.flow_id, spanner_lib.COMMIT_TIMESTAMP, request, - request.delivery_time + rdfvalue.RDFDatetime( + request.delivery_time + ).AsDatetime() ] rows.append(row) txn.insert(table="FlowProcessingRequests", columns=columns, values=rows) - self.db.PublishFlowProcessingRequests(flowProcessingRequests) @db_utils.CallLogged @db_utils.CallAccounted @@ -819,7 +820,10 @@ def AckFlowProcessingRequests( def Txn(txn) -> None: keys = [] for request in requests: - keys.append([request.client_id, request.flow_id, request.creation_time]) + creation_time = rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + request.creation_time + ).AsDatetime() + keys.append([request.client_id, 
request.flow_id, creation_time]) keyset = spanner_lib.KeySet(keys=keys) txn.delete(table="FlowProcessingRequests", keyset=keyset) @@ -867,12 +871,12 @@ def Txn(txn) -> None: res = [] request_ids = [] - for request_id, creation_time, request in cursor.fetchall(): + for request_id, creation_time, request in requests: req = flows_pb2.FlowProcessingRequest() req.ParseFromString(request) - req.creation_time = mysql_utils.TimestampToMicrosecondsSinceEpoch( + req.creation_time = rdfvalue.RDFDatetime.FromDatetime( creation_time - ) + ).AsMicrosecondsSinceEpoch() res.append(req) request_ids.append(request_id) @@ -883,7 +887,7 @@ def Txn(txn) -> None: ) params = { "request_ids": request_ids, - "leased_by": leased_by, + "leased_by": utils.ProcessIdString(), "leased_until": expiry.AsDatetime() } param_type = { From 25d8ee1846fd5db31fdeb0b3d88bbafedf1f19b3 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Thu, 19 Jun 2025 20:32:32 +0000 Subject: [PATCH 035/168] Clean up handlers --- .../databases/spanner_flows.py | 22 --- .../databases/spanner_test_lib.py | 56 +----- .../databases/spanner_utils.py | 184 +----------------- 3 files changed, 2 insertions(+), 260 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index e97a73334..a4c29e323 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -241,28 +241,6 @@ class FlowsMixin: handler_thread: threading.Thread - @property - def _flow_processing_request_receiver( - self, - ) -> Optional[spanner_utils.RequestQueue]: - return getattr(self, "__flow_processing_request_receiver", None) - - @_flow_processing_request_receiver.setter - def _flow_processing_request_receiver( - self, value: Optional[spanner_utils.RequestQueue] - ) -> None: - setattr(self, "__flow_processing_request_receiver", value) - - @property - def _message_handler_receiver(self) -> Optional[spanner_utils.RequestQueue]: - return getattr(self, "__message_handler_receiver", None) - - @_message_handler_receiver.setter - def _message_handler_receiver( - self, value: Optional[spanner_utils.RequestQueue] - ) -> None: - setattr(self, "__message_handler_receiver", value) - @db_utils.CallLogged @db_utils.CallAccounted def WriteFlowObject( diff --git a/grr/server/grr_response_server/databases/spanner_test_lib.py b/grr/server/grr_response_server/databases/spanner_test_lib.py index f1fa76c5f..2fba5f12d 100644 --- a/grr/server/grr_response_server/databases/spanner_test_lib.py +++ b/grr/server/grr_response_server/databases/spanner_test_lib.py @@ -1,13 +1,11 @@ """A library with utilities for testing the Spanner database implementation.""" import os import unittest -import uuid from typing import Optional from absl.testing import absltest -from google.cloud import pubsub_v1 from google.cloud import spanner_v1 as spanner_lib from google.cloud import spanner_admin_database_v1 from google.cloud.spanner import Client, KeySet @@ -110,10 +108,6 @@ class TestCase(absltest.TestCase): This class takes care of setting up a clean database for every test method. It is intended to be used with database test suite mixins. 
""" - msg_handler_top = None - msg_handler_sub = None - flow_processing_top = None - flow_processing_sub = None project_id = None @@ -122,61 +116,13 @@ def setUp(self): self.project_id = _GetEnvironOrSkip("PROJECT_ID") - msg_uuid=str(uuid.uuid4()) - flow_uuid=str(uuid.uuid4()) - self.msg_handler_top_id = "msg-top"+msg_uuid - self.msg_handler_sub_id = "msg-sub"+msg_uuid - - self.flow_processing_top_id = "flow-top"+flow_uuid - self.flow_processing_sub_id = "flow-sub"+flow_uuid - - publisher = pubsub_v1.PublisherClient() - msg_handler_top_path = publisher.topic_path(self.project_id, self.msg_handler_top_id) - flow_processing_top_path = publisher.topic_path(self.project_id, self.flow_processing_top_id) - message_handler_top = publisher.create_topic(request={"name": msg_handler_top_path}) - flow_processing_top = publisher.create_topic(request={"name": flow_processing_top_path}) - - subscriber = pubsub_v1.SubscriberClient() - msg_handler_sub_path = subscriber.subscription_path(self.project_id, self.msg_handler_sub_id) - flow_processing_sub_path = subscriber.subscription_path(self.project_id, self.flow_processing_sub_id) - message_handler_sub = subscriber.create_subscription(request={"name": msg_handler_sub_path, - "topic": msg_handler_top_path} - ) - flow_processing_sub = subscriber.create_subscription(request={"name": flow_processing_sub_path, - "topic": flow_processing_top_path} - ) - - _clean_database() - self.raw_db = spanner_utils.Database(_TEST_DB, self.project_id, - self.msg_handler_top_id, self.msg_handler_sub_id, - self.flow_processing_top_id, self.flow_processing_sub_id) + self.raw_db = spanner_utils.Database(_TEST_DB, self.project_id) spannerDB = spanner_db.SpannerDB(self.raw_db) self.db = abstract_db.DatabaseValidationWrapper(spannerDB) - def tearDown(self): - subscriber = pubsub_v1.SubscriberClient() - msg_handler_sub_path = subscriber.subscription_path(self.project_id, self.msg_handler_sub_id) - flow_processing_sub_path = subscriber.subscription_path(self.project_id, self.flow_processing_sub_id) - - # Wrap the subscriber in a 'with' block to automatically call close() to - # close the underlying gRPC channel when done. - with subscriber: - subscriber.delete_subscription(request={"subscription": msg_handler_sub_path}) - subscriber.delete_subscription(request={"subscription": flow_processing_sub_path}) - - publisher = pubsub_v1.PublisherClient() - msg_handler_top_path = publisher.topic_path(self.project_id, self.msg_handler_top_id) - flow_processing_top_path = publisher.topic_path(self.project_id, self.flow_processing_top_id) - - publisher.delete_topic(request={"topic": msg_handler_top_path}) - publisher.delete_topic(request={"topic": flow_processing_top_path}) - - super().tearDown() - - def _get_table_names(db): with db.snapshot() as snapshot: query_result = snapshot.execute_sql( diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index e7590e96b..7fb4728f0 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -22,7 +22,6 @@ from concurrent import futures -from google.cloud import pubsub_v1 from google.cloud import spanner_v1 as spanner_lib from google.cloud.spanner import KeyRange, KeySet @@ -43,46 +42,6 @@ _T = TypeVar("_T") -class RequestQueue: - """ - This stores the callback internally, and will continue to deliver messages to - the callback as long as it is referenced in python code and Stop is not - called. 
- """ - - def __init__( - self, - subscriber, - subscription_path: str, - callback, # : Callable - receiver_max_keepalive_seconds: int, - receiver_max_active_callbacks: int, - receiver_max_messages_per_callback: int, - ): - - # An optional executor to use. If not specified, a default one with maximum 10 - # threads will be created. - executor = futures.ThreadPoolExecutor(max_workers=receiver_max_active_callbacks) - # A thread pool-based scheduler. It must not be shared across SubscriberClients. - scheduler = pubsub_v1.subscriber.scheduler.ThreadScheduler(executor) - - #flow_control = pubsub_v1.types.FlowControl(max_messages=receiver_max_messages_per_callback) - - self.streaming_pull_future = subscriber.subscribe( - subscription_path, callback=callback, scheduler=scheduler - #subscription_path, callback=callback, scheduler=scheduler, flow_control=flow_control - ) - - def Stop(self): - if self.streaming_pull_future: - try: - self.streaming_pull_future.cancel() - except asyncio.CancelledError: - pass # Expected when cancelling - except Exception as e: - print(f"Warning: Exception while cancelling future: {e}") - - time.sleep(0.1) # Give a short buffer for threads to clean up class Database: """A wrapper around the PySpanner class. @@ -94,18 +53,10 @@ class Database: _PYSPANNER_PARAM_REGEX = re.compile(r"@p\d+") - def __init__(self, pyspanner: spanner_lib.database, project_id: str, - msg_handler_top_id: str, msg_handler_sub_id: str, - flow_processing_top_id: str, flow_processing_sub_id: str) -> None: + def __init__(self, pyspanner: spanner_lib.database, project_id: str) -> None: super().__init__() self._pyspanner = pyspanner self.project_id = project_id - self.publisher = pubsub_v1.PublisherClient() - self.subscriber = pubsub_v1.SubscriberClient() - self.flow_processing_sub_path = self.subscriber.subscription_path(project_id, flow_processing_sub_id) - self.flow_processing_top_path = self.publisher.topic_path(project_id, flow_processing_top_id) - self.message_handler_sub_path = self.subscriber.subscription_path(project_id, msg_handler_sub_id) - self.message_handler_top_path = self.publisher.topic_path(project_id, msg_handler_top_id) def Now(self) -> rdfvalue.RDFDatetime: """Retrieves current time as reported by the database.""" @@ -543,136 +494,3 @@ def ReadSet( ) return results - - def PublishMessageHandlerRequests(self, requests: [str]) -> None: - self.PublishRequests(requests, self.message_handler_top_path) - - def PublishFlowProcessingRequests(self, requests: [str]) -> None: - self.PublishRequests(requests, self.flow_processing_top_path) - - def ReadMessageHandlerRequests(self, min_req: Optional[int] = None): - return self.ReadRequests(self.message_handler_sub_path, min_req) - - def ReadFlowProcessingRequests(self, min_req: Optional[int] = None): - return self.ReadRequests(self.flow_processing_sub_path, min_req) - - def AckMessageHandlerRequests(self, ack_ids: [str]) -> None: - self.AckRequests(ack_ids, self.message_handler_sub_path) - - def AckFlowProcessingRequests(self, ack_ids: [str]) -> None: - self.AckRequests(ack_ids, self.flow_processing_sub_path) - - def DeleteAllMessageHandlerRequests(self) -> None: - self.DeleteAllRequests(self.message_handler_sub_path) - - def DeleteAllFlowProcessingRequests(self) -> None: - self.DeleteAllRequests(self.flow_processing_sub_path) - - def LeaseMessageHandlerRequests(self, ack_ids: [str], ack_deadline: int) -> None: - self.subscriber.modify_ack_deadline( - request={ - "subscription": self.message_handler_sub_path, - "ack_ids": ack_ids, - 
"ack_deadline_seconds": ack_deadline, - } - ) - - def LeaseFlowProcessingRequests(self, ack_ids: [str], ack_deadline: int) -> None: - self.subscriber.modify_ack_deadline( - request={ - "subscription": self.flow_processing_sub_path, - "ack_ids": ack_ids, - "ack_deadline_seconds": ack_deadline, - } - ) - - def PublishRequests(self, requests: [str], top_path: str) -> None: - for req in requests: - self.publisher.publish(top_path, req) - - def AckRequests(self, ack_ids: [str], sub_path: str) -> None: - self.subscriber.acknowledge( - request={"subscription": sub_path, "ack_ids": ack_ids} - ) - - def DeleteAllRequests(self, sub_path: str) -> None: - client = pubsub_v1.SubscriberClient() - # Initialize request argument(s) - request = { - "subscription": sub_path, - "time": datetime.datetime.now(pytz.utc) + datetime.timedelta(days=30) - } - # Make the request - response = client.seek(request=request) - - def ReadRequests(self, sub_path: str, min_req: Optional[int] = None): - # Make the request - - start_time = time.time() - results = {} - want_more = True - while want_more or (time.time() - start_time < 2): - time.sleep(0.1) - - response = self.subscriber.pull( - request = { - "subscription": sub_path, - "max_messages": 10000, - }, - ) - for resp in response.received_messages: - results.update({resp.message.message_id: { - "payload": resp.message.data, - "msg_id": resp.message.message_id, - "ack_id": resp.ack_id, - "publish_time": resp.message.publish_time} - }) - if min_req and len(results) >= min_req: - want_more = False - - return results.values() - - def NewRequestQueue( - self, - queue: str, - callback: Callable[[Sequence[Any], bytes], None], - receiver_max_keepalive_seconds: Optional[int] = None, - receiver_max_active_callbacks: Optional[int] = None, - receiver_max_messages_per_callback: Optional[int] = None, - ) -> RequestQueue: - """Registers a queue callback in a given queue. - - Args: - queue: Name of the queue. - callback: Callback with 2 args (expanded_key, payload). expanded_key is a - sequence where each item corresponds to an item of the message's key. - Payload is the message itself, serialized as bytes. - receiver_max_keepalive_seconds: Num seconds before the lease on the - message expires (if the message is not acked before the lease expires, - it will be delivered again). - receiver_max_active_callbacks: Max number of callback to be called in - parallel. - receiver_max_messages_per_callback: Max messages to receive per callback. - - Returns: - New queue receiver objects. 
- """ - - def _Callback(message: pubsub_v1.subscriber.message.Message): - payload = message.data - callback(payload=payload, msg_id=message.message_id, ack_id=message.ack_id, - publish_time=message.publish_time) - - if queue == "MessageHandler" or queue == "": - subscription_path = self.message_handler_sub_path - elif queue == "FlowProcessing": - subscription_path = self.flow_processing_sub_path - - return RequestQueue( - self.subscriber, - subscription_path, - _Callback, - receiver_max_keepalive_seconds=receiver_max_keepalive_seconds, - receiver_max_active_callbacks=receiver_max_active_callbacks, - receiver_max_messages_per_callback=receiver_max_messages_per_callback, - ) From 8971141c14f4b4404b9b4c363f91d272a8730dba Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Fri, 20 Jun 2025 07:50:50 +0000 Subject: [PATCH 036/168] Adds Spanner config variables --- .../grr_response_core/config/data_store.py | 30 +++++++++++++++++++ .../grr_response_server/databases/spanner.py | 14 +++------ .../databases/spanner_test_lib.py | 2 +- .../databases/spanner_utils.py | 26 ++++++++-------- 4 files changed, 48 insertions(+), 24 deletions(-) diff --git a/grr/core/grr_response_core/config/data_store.py b/grr/core/grr_response_core/config/data_store.py index 7c96e8d10..9fe1619b5 100644 --- a/grr/core/grr_response_core/config/data_store.py +++ b/grr/core/grr_response_core/config/data_store.py @@ -121,3 +121,33 @@ "Only used when Blobstore.implementation is GCSBlobStore." ), ) + +# Spanner configuration. +config_lib.DEFINE_string( + "Spanner.project", + default=None, + help=( + "GCP Project where the Spanner Instance is located." + ), +) +config_lib.DEFINE_string( + "Spanner.instance", + default="grr-instance", + help="The Spanner Instance for GRR.") + +config_lib.DEFINE_string( + "Spanner.database", + default="grr-database", + help="The Spanner Database for GRR.") + +config_lib.DEFINE_integer( + "Spanner.flow_processing_threads_min", + default=1, + help="The minimum number of flow-processing worker threads.", +) + +config_lib.DEFINE_integer( + "Spanner.flow_processing_threads_max", + default=20, + help="The maximum number of flow-processing worker threads.", +) diff --git a/grr/server/grr_response_server/databases/spanner.py b/grr/server/grr_response_server/databases/spanner.py index 1e2c21dfa..d9c42f321 100644 --- a/grr/server/grr_response_server/databases/spanner.py +++ b/grr/server/grr_response_server/databases/spanner.py @@ -54,8 +54,8 @@ def __init__(self, db: spanner_utils.Database) -> None: self.flow_processing_request_handler_stop = None self.flow_processing_request_handler_pool = threadpool.ThreadPool.Factory( "spanner_flow_processing_pool", - min_threads=config.CONFIG["Mysql.flow_processing_threads_min"], - max_threads=config.CONFIG["Mysql.flow_processing_threads_max"], + min_threads=config.CONFIG["Spanner.flow_processing_threads_min"], + max_threads=config.CONFIG["Spanner.flow_processing_threads_max"], ) @classmethod @@ -65,18 +65,12 @@ def FromConfig(cls) -> "Database": Returns: A GRR database instance. 
""" - project_id = config.CONFIG["ProjectID"] + project_id = config.CONFIG["Spanner.project_id"] spanner_client = Client(project_id) spanner_instance = spanner_client.instance(config.CONFIG["Spanner.instance"]) spanner_database = spanner_instance.database(config.CONFIG["Spanner.database"]) - msg_handler_top_id = config.CONFIG["MessageHandler.topic_id"] - msg_handler_sub_id = config.CONFIG["MessageHandler.subscription_id"] - flow_processing_top_id = config.CONFIG["FlowProcessing.topic_id"] - flow_processing_sub_id = config.CONFIG["FlowProcessing.subscription_id"] - return cls(spanner_utils.Database(spanner_database, project_id, - msg_handler_top_id, msg_handler_sub_id, - flow_processing_top_id, flow_processing_sub_id)) + return cls(spanner_utils.Database(spanner_database, project_id)) def Now(self) -> rdfvalue.RDFDatetime: """Retrieves current time as reported by the database.""" diff --git a/grr/server/grr_response_server/databases/spanner_test_lib.py b/grr/server/grr_response_server/databases/spanner_test_lib.py index 2fba5f12d..299e91671 100644 --- a/grr/server/grr_response_server/databases/spanner_test_lib.py +++ b/grr/server/grr_response_server/databases/spanner_test_lib.py @@ -54,7 +54,7 @@ def Init(sdl_path: str, proto_bundle: bool) -> None: if _TEST_DB is not None: raise AssertionError("Spanner test library already initialized") - project_id = _GetEnvironOrSkip("PROJECT_ID") + project_id = _GetEnvironOrSkip("SPANNER_PROJECT_ID") instance_id = _GetEnvironOrSkip("SPANNER_GRR_INSTANCE") database_id = _GetEnvironOrSkip("SPANNER_GRR_DATABASE") diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 7fb4728f0..078ab8695 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -58,19 +58,19 @@ def __init__(self, pyspanner: spanner_lib.database, project_id: str) -> None: self._pyspanner = pyspanner self.project_id = project_id - def Now(self) -> rdfvalue.RDFDatetime: - """Retrieves current time as reported by the database.""" - with self._pyspanner.snapshot() as snapshot: - timestamp = None - query = "SELECT CURRENT_TIMESTAMP() AS now" - results = snapshot.execute_sql(query) - for row in results: - timestamp = row[0] - return rdfvalue.RDFDatetime.FromDatetime(timestamp) - - def MinTimestamp(self) -> rdfvalue.RDFDatetime: - """Returns minimal timestamp allowed by the DB.""" - return rdfvalue.RDFDatetime.FromSecondsSinceEpoch(0) +# def Now(self) -> rdfvalue.RDFDatetime: +# """Retrieves current time as reported by the database.""" +# with self._pyspanner.snapshot() as snapshot: +# timestamp = None +# query = "SELECT CURRENT_TIMESTAMP() AS now" +# results = snapshot.execute_sql(query) +# for row in results: +# timestamp = row[0] +# return rdfvalue.RDFDatetime.FromDatetime(timestamp) + +# def MinTimestamp(self) -> rdfvalue.RDFDatetime: +# """Returns minimal timestamp allowed by the DB.""" +# return rdfvalue.RDFDatetime.FromSecondsSinceEpoch(0) def _parametrize(self, query: str, names: Iterable[str]) -> str: match = self._PYSPANNER_PARAM_REGEX.search(query) From 06632a310e7d34529f8727e6e5e331c73fc6eb9a Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Fri, 20 Jun 2025 10:56:18 +0000 Subject: [PATCH 037/168] Clean up Spanner config --- .../databases/registry_init.py | 2 ++ .../grr_response_server/databases/spanner.py | 2 +- .../databases/spanner_utils.py | 23 ------------------- 3 files changed, 3 insertions(+), 24 deletions(-) diff 
--git a/grr/server/grr_response_server/databases/registry_init.py b/grr/server/grr_response_server/databases/registry_init.py index 7d576bd48..c2affb156 100644 --- a/grr/server/grr_response_server/databases/registry_init.py +++ b/grr/server/grr_response_server/databases/registry_init.py @@ -3,9 +3,11 @@ from grr_response_server.databases import mem from grr_response_server.databases import mysql +from grr_response_server.databases import spanner # All available databases go into this registry. REGISTRY = { "InMemoryDB": mem.InMemoryDB, "MysqlDB": mysql.MysqlDB, + "SpannerDB": spanner.SpannerDB.FromConfig, } diff --git a/grr/server/grr_response_server/databases/spanner.py b/grr/server/grr_response_server/databases/spanner.py index d9c42f321..9a67e957d 100644 --- a/grr/server/grr_response_server/databases/spanner.py +++ b/grr/server/grr_response_server/databases/spanner.py @@ -65,7 +65,7 @@ def FromConfig(cls) -> "Database": Returns: A GRR database instance. """ - project_id = config.CONFIG["Spanner.project_id"] + project_id = config.CONFIG["Spanner.project"] spanner_client = Client(project_id) spanner_instance = spanner_client.instance(config.CONFIG["Spanner.instance"]) spanner_database = spanner_instance.database(config.CONFIG["Spanner.database"]) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 078ab8695..ca5269016 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -57,20 +57,6 @@ def __init__(self, pyspanner: spanner_lib.database, project_id: str) -> None: super().__init__() self._pyspanner = pyspanner self.project_id = project_id - -# def Now(self) -> rdfvalue.RDFDatetime: -# """Retrieves current time as reported by the database.""" -# with self._pyspanner.snapshot() as snapshot: -# timestamp = None -# query = "SELECT CURRENT_TIMESTAMP() AS now" -# results = snapshot.execute_sql(query) -# for row in results: -# timestamp = row[0] -# return rdfvalue.RDFDatetime.FromDatetime(timestamp) - -# def MinTimestamp(self) -> rdfvalue.RDFDatetime: -# """Returns minimal timestamp allowed by the DB.""" -# return rdfvalue.RDFDatetime.FromSecondsSinceEpoch(0) def _parametrize(self, query: str, names: Iterable[str]) -> str: match = self._PYSPANNER_PARAM_REGEX.search(query) @@ -230,10 +216,6 @@ def ParamQuery( print(f"Warning for key '{key}': {e}. Setting type to None.") param_type[key] = None # Or re-raise, or handle differently - print("query: {}".format(query)) - print("params: {}".format(params)) - print("param_type: {}".format(param_type)) - with self._pyspanner.snapshot() as snapshot: results = snapshot.execute_sql( query, @@ -296,10 +278,6 @@ def ParamExecute( print(f"Warning for key '{key}': {e}. 
Setting type to None.") param_type[key] = None # Or re-raise, or handle differently - print("query: {}".format(query)) - print("params: {}".format(params)) - print("param_type: {}".format(param_type)) - def param_execute(transaction): row_ct = transaction.execute_update( query, @@ -307,7 +285,6 @@ def param_execute(transaction): param_types=param_type, ) - print("{} record(s) updated.".format(row_ct)) self._pyspanner.run_in_transaction(param_execute) def ExecutePartitioned( From 6ac79c7cb9a2dd118c5473a9dc6500dca6818ea0 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Mon, 23 Jun 2025 06:45:28 +0000 Subject: [PATCH 038/168] Fix user notification encoding --- grr/server/grr_response_server/databases/spanner_users.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_users.py b/grr/server/grr_response_server/databases/spanner_users.py index 3d47788a6..50cf0e850 100644 --- a/grr/server/grr_response_server/databases/spanner_users.py +++ b/grr/server/grr_response_server/databases/spanner_users.py @@ -1,6 +1,6 @@ #!/usr/bin/env python """A library with user methods of Spanner database implementation.""" - +import base64 import datetime import logging import sys @@ -415,7 +415,7 @@ def WriteUserNotification( "Message": notification.message, } if notification.reference: - row["Reference"] = notification.reference.SerializeToString() + row["Reference"] = base64.b64encode(notification.reference.SerializeToString()) try: self.db.Insert( From dd5e077899bfb5af56041caef72d325fa23ca5f0 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Wed, 25 Jun 2025 09:20:37 +0000 Subject: [PATCH 039/168] Make repo owner configurable for container image upload --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 18b46c74a..7b84357cf 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -3,7 +3,7 @@ on: [push, pull_request] env: GCS_BUCKET: autobuilds.grr-response.com GCS_BUCKET_OPENAPI: autobuilds-grr-openapi - DOCKER_REPOSITORY: ghcr.io/google/grr + DOCKER_REPOSITORY: ghcr.io/${{ github.repository_owner }}/grr jobs: test-devenv: runs-on: ubuntu-22.04 From f8f12bd3fb72ae917fabed9336d661c37573e448 Mon Sep 17 00:00:00 2001 From: Dan Date: Mon, 7 Jul 2025 10:21:35 +0200 Subject: [PATCH 040/168] Fix testing compose manifest images --- compose.testing.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/compose.testing.yaml b/compose.testing.yaml index 818901d9b..4bdcca486 100644 --- a/compose.testing.yaml +++ b/compose.testing.yaml @@ -1,14 +1,14 @@ services: grr-admin-ui: - image: ghcr.io/google/grr:testing + image: ghcr.io/google/grr:test grr-fleetspeak-frontend: - image: ghcr.io/google/grr:testing + image: ghcr.io/google/grr:test grr-worker: - image: ghcr.io/google/grr:testing + image: ghcr.io/google/grr:test grr-client: - image: ghcr.io/google/grr:testing + image: ghcr.io/google/grr:test privileged: true From 9e608ae169a1ea423a162296182ba06adf4d1013 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Tue, 8 Jul 2025 10:52:18 +0000 Subject: [PATCH 041/168] Propagate txn_tag --- .../databases/spanner_artifacts.py | 5 ++- .../databases/spanner_blob_references.py | 5 ++- .../databases/spanner_clients.py | 5 ++- .../databases/spanner_cron_jobs.py | 10 ++++-- .../databases/spanner_flows.py | 27 ++++++++++---- .../databases/spanner_hunts.py | 10 ++++-- .../databases/spanner_paths.py | 5 ++- 
.../databases/spanner_signed_binaries.py | 1 + .../databases/spanner_users.py | 5 ++- .../databases/spanner_utils.py | 36 ++++++++++--------- .../databases/spanner_yara.py | 5 ++- 11 files changed, 81 insertions(+), 33 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_artifacts.py b/grr/server/grr_response_server/databases/spanner_artifacts.py index fd3906319..8dda15bc4 100644 --- a/grr/server/grr_response_server/databases/spanner_artifacts.py +++ b/grr/server/grr_response_server/databases/spanner_artifacts.py @@ -55,7 +55,10 @@ def ReadArtifact(self, name: str) -> Optional[artifact_pb2.Artifact]: UnknownArtifactError: when the artifact does not exist. """ try: - row = self.db.Read("Artifacts", key=[name], cols=("Platforms", "Payload")) + row = self.db.Read("Artifacts", + key=[name], + cols=("Platforms", "Payload"), + txn_tag="ReadArtifacts") except NotFound as error: raise db.UnknownArtifactError(name) from error diff --git a/grr/server/grr_response_server/databases/spanner_blob_references.py b/grr/server/grr_response_server/databases/spanner_blob_references.py index 6fb1aa422..c7d6ee127 100644 --- a/grr/server/grr_response_server/databases/spanner_blob_references.py +++ b/grr/server/grr_response_server/databases/spanner_blob_references.py @@ -64,7 +64,10 @@ def ReadHashBlobReferences( hashes_left = set(hashes) for row in self.db.ReadSet( - table="HashBlobReferences", rows=rows, cols=("HashId", "BlobId", "Offset", "Size") + table="HashBlobReferences", + rows=rows, + cols=("HashId", "BlobId", "Offset", "Size"), + txn_tag="ReadHashBlobReferences" ): hash_id = rdf_objects.SHA256HashID(base64.b64decode(row[0])) diff --git a/grr/server/grr_response_server/databases/spanner_clients.py b/grr/server/grr_response_server/databases/spanner_clients.py index ab742fa96..04de4e1e1 100644 --- a/grr/server/grr_response_server/databases/spanner_clients.py +++ b/grr/server/grr_response_server/databases/spanner_clients.py @@ -95,7 +95,10 @@ def MultiReadClientMetadata( "FleetspeakValidationInfo", ) - for row in self.db.ReadSet(table="Clients", rows=keyset, cols=cols): + for row in self.db.ReadSet(table="Clients", + rows=keyset, + cols=cols, + txn_tag="MultiReadClientMetadata"): client_id = row[0] metadata = objects_pb2.ClientMetadata() diff --git a/grr/server/grr_response_server/databases/spanner_cron_jobs.py b/grr/server/grr_response_server/databases/spanner_cron_jobs.py index c3e92ca4d..6b58f31bb 100644 --- a/grr/server/grr_response_server/databases/spanner_cron_jobs.py +++ b/grr/server/grr_response_server/databases/spanner_cron_jobs.py @@ -470,7 +470,10 @@ def ReadCronJobRuns(self, job_id: str) -> Sequence[flows_pb2.CronJobRun]: rows = spanner_lib.KeySet(ranges=[rowrange]) res = [] - for row in self.db.ReadSet(table="CronJobRuns", rows=rows, cols=cols): + for row in self.db.ReadSet(table="CronJobRuns", + rows=rows, + cols=cols, + txn_tag="ReadCronJobRuns"): res.append( _CronJobRunFromRow( job_run=row[0], @@ -509,7 +512,10 @@ def ReadCronJobRun(self, job_id: str, run_id: str) -> flows_pb2.CronJobRun: "Backtrace", ] try: - row = self.db.Read(table="CronJobRuns", key=(job_id, run_id), cols=cols) + row = self.db.Read(table="CronJobRuns", + key=(job_id, run_id), + cols=cols, + txn_tag="ReadCronJobRun") except NotFound as error: raise db.UnknownCronJobRunError( "Run with job id %s and run id %s not found." 
% (job_id, run_id) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index a4c29e323..db04fef63 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -325,6 +325,7 @@ def ReadFlowObject( table="Flows", key=[client_id, flow_id], cols=_READ_FLOW_OBJECT_COLS, + txn_tag="ReadFlowObject" ) except NotFound as error: raise db.UnknownFlowError(client_id, flow_id, cause=error) @@ -1538,6 +1539,7 @@ def DeleteAllFlowRequestsAndResponses( self.db.DeleteWithPrefix( "FlowRequests", (client_id, flow_id), + txn_tag="DeleteAllFlowRequestsAndResponses" ) @db_utils.CallLogged @@ -1571,7 +1573,10 @@ def ReadAllFlowRequestsAndResponses( "CreationTime", ] requests = [] - for row in self.db.ReadSet(table="FlowRequests", rows=rows, cols=req_cols): + for row in self.db.ReadSet(table="FlowRequests", + rows=rows, + cols=req_cols, + txn_tag="ReadAllFlowRequestsAndResponses:FlowRequests"): request = flows_pb2.FlowRequest() request.ParseFromString(row[0]) request.needs_processing = row[1] @@ -1592,7 +1597,10 @@ def ReadAllFlowRequestsAndResponses( ] responses = {} for row in self.db.ReadSet( - table="FlowResponses", rows=rows, cols=resp_cols + table="FlowResponses", + rows=rows, + cols=resp_cols, + txn_tag="ReadAllFlowRequestsAndResponses:FlowResponses" ): if row[1] is not None: response = flows_pb2.FlowStatus() @@ -1706,7 +1714,10 @@ def ReadFlowRequests( "Iterator", "CreationTime", ] - for row in self.db.ReadSet(table="FlowResponses", rows=rows, cols=resp_cols): + for row in self.db.ReadSet(table="FlowResponses", + rows=rows, + cols=resp_cols, + txn_tag="ReadFlowRequests:FlowResponses"): if row[1]: response = flows_pb2.FlowStatus() response.ParseFromString(row[1]) @@ -1742,7 +1753,10 @@ def ReadFlowRequests( "CallbackState", "CreationTime", ] - for row in self.db.ReadSet(table="FlowRequests", rows=rows, cols=req_cols): + for row in self.db.ReadSet(table="FlowRequests", + rows=rows, + cols=req_cols, + txn_tag="ReadFlowRequests:FlowRequests"): request = flows_pb2.FlowRequest() request.ParseFromString(row[0]) request.needs_processing = row[1] @@ -2101,7 +2115,7 @@ def WriteScheduledFlow( } try: - self.db.InsertOrUpdate(table="ScheduledFlows", row=row) + self.db.InsertOrUpdate(table="ScheduledFlows", row=row, txn_tag="WriteScheduledFlow") except Exception as error: if "Parent row for row [" in str(error): raise db.UnknownClientError(scheduled_flow.client_id) from error @@ -2158,7 +2172,8 @@ def ListScheduledFlows( ] results = [] - for row in self.db.ReadSet("ScheduledFlows", rows, cols): + for row in self.db.ReadSet("ScheduledFlows", rows, cols, + txn_tag="ListScheduledFlows"): sf = flows_pb2.ScheduledFlow() sf.client_id = row[0] sf.creator = row[1] diff --git a/grr/server/grr_response_server/databases/spanner_hunts.py b/grr/server/grr_response_server/databases/spanner_hunts.py index fb93c9493..494a4ab2f 100644 --- a/grr/server/grr_response_server/databases/spanner_hunts.py +++ b/grr/server/grr_response_server/databases/spanner_hunts.py @@ -276,7 +276,10 @@ def ReadHuntObject(self, hunt_id: str) -> hunts_pb2.Hunt: ] try: - row = self.db.Read(table="Hunts", key=[hunt_id], cols=cols) + row = self.db.Read(table="Hunts", + key=[hunt_id], + cols=cols, + txn_tag="ReadHuntObject") except NotFound as e: raise abstract_db.UnknownHuntError(hunt_id) from e @@ -856,7 +859,10 @@ def ReadHuntOutputPluginsStates( """Reads all hunt output plugins states of a given hunt.""" 
# Make sure the hunt is there. try: - self.db.Read(table="Hunts", key=[hunt_id,], cols=("HuntId",)) + self.db.Read(table="Hunts", + key=[hunt_id,], + cols=("HuntId",), + txn_tag="ReadHuntOutputPluginsStates") except NotFound as e: raise abstract_db.UnknownHuntError(hunt_id) from e diff --git a/grr/server/grr_response_server/databases/spanner_paths.py b/grr/server/grr_response_server/databases/spanner_paths.py index 70a83c027..22e78ac66 100644 --- a/grr/server/grr_response_server/databases/spanner_paths.py +++ b/grr/server/grr_response_server/databases/spanner_paths.py @@ -32,7 +32,10 @@ def WritePathInfos( # mutations. We still have to validate the client id. if not path_infos: try: - self.db.Read(table="Clients", key=[client_id], cols=(["ClientId"])) + self.db.Read(table="Clients", + key=[client_id], + cols=(["ClientId"]), + txn_tag="WritePathInfos") except NotFound as error: raise abstract_db.UnknownClientError(client_id) from error return diff --git a/grr/server/grr_response_server/databases/spanner_signed_binaries.py b/grr/server/grr_response_server/databases/spanner_signed_binaries.py index 1deb02075..0fffca22a 100644 --- a/grr/server/grr_response_server/databases/spanner_signed_binaries.py +++ b/grr/server/grr_response_server/databases/spanner_signed_binaries.py @@ -65,6 +65,7 @@ def ReadSignedBinaryReferences( table="SignedBinaries", key=(binary_type, binary_id.path), cols=("BlobReferences", "CreationTime"), + txn_tag="ReadSignedBinaryReferences" ) except NotFound as error: raise db.UnknownSignedBinaryError(binary_id) from error diff --git a/grr/server/grr_response_server/databases/spanner_users.py b/grr/server/grr_response_server/databases/spanner_users.py index 50cf0e850..8ae7cd570 100644 --- a/grr/server/grr_response_server/databases/spanner_users.py +++ b/grr/server/grr_response_server/databases/spanner_users.py @@ -84,7 +84,10 @@ def ReadGRRUser(self, username: str) -> objects_pb2.GRRUser: """Reads a user object corresponding to a given name.""" cols = ("Email", "Password", "Type", "CanaryMode", "UiMode") try: - row = self.db.Read(table="Users", key=[username], cols=cols) + row = self.db.Read(table="Users", + key=[username], + cols=cols, + txn_tag="ReadGRRUser") except NotFound as error: raise abstract_db.UnknownGRRUserError(username) from error diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index ca5269016..21e50ef31 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -131,7 +131,7 @@ def Transact( Returns: The result of the transaction function executed. """ - return self._pyspanner.run_in_transaction(func) + return self._pyspanner.run_in_transaction(func, transaction_tag=txn_tag) def Mutate( self, func: Callable[["Mutation"], None], txn_tag: Optional[str] = None @@ -156,7 +156,7 @@ def Query(self, query: str, txn_tag: Optional[str] = None) -> Cursor: A cursor over the query results. 
""" with self._pyspanner.snapshot() as snapshot: - results = snapshot.execute_sql(query) + results = snapshot.execute_sql(query, request_options={"request_tag": txn_tag}) return results @@ -221,6 +221,7 @@ def ParamQuery( query, params=params, param_types=param_type, + request_options={"request_tag": txn_tag} ) return results @@ -285,7 +286,7 @@ def param_execute(transaction): param_types=param_type, ) - self._pyspanner.run_in_transaction(param_execute) + self._pyspanner.run_in_transaction(param_execute, transaction_tag=txn_tag) def ExecutePartitioned( self, query: str, txn_tag: Optional[str] = None @@ -305,12 +306,8 @@ def ExecutePartitioned( Returns: Nothing. """ - query_options = None - if txn_tag is not None: - query_options = spanner_lib.QueryOptions() - query_options.SetTag(txn_tag) - - return self._pyspanner.execute_partitioned_dml(query) + return self._pyspanner.execute_partitioned_dml(query, + request_options={"request_tag": txn_tag}) def Insert( self, table: str, row: Mapping[str, Any], txn_tag: Optional[str] = None @@ -330,7 +327,7 @@ def Insert( columns = list(columns) values = list(values) - with self._pyspanner.batch() as batch: + with self._pyspanner.batch(request_options={"request_tag": txn_tag}) as batch: batch.insert( table=table, columns=columns, @@ -355,7 +352,7 @@ def Update( columns = list(columns) values = list(values) - with self._pyspanner.batch() as batch: + with self._pyspanner.batch(request_options={"request_tag": txn_tag}) as batch: batch.update( table=table, columns=columns, @@ -380,7 +377,7 @@ def InsertOrUpdate( columns = list(columns) values = list(values) - with self._pyspanner.batch() as batch: + with self._pyspanner.batch(request_options={"request_tag": txn_tag}) as batch: batch.insert_or_update( table=table, columns=columns, @@ -403,10 +400,11 @@ def Delete( keyset = KeySet(all_=True) if key: keyset = KeySet(keys=[key]) - with self._pyspanner.batch() as batch: + with self._pyspanner.batch(request_options={"request_tag": txn_tag}) as batch: batch.delete(table, keyset) - def DeleteWithPrefix(self, table: str, key_prefix: Sequence[Any]) -> None: + def DeleteWithPrefix(self, table: str, key_prefix: Sequence[Any], + txn_tag: Optional[str] = None) -> None: """Deletes a range of rows with common key prefix from the given table. Args: @@ -419,7 +417,7 @@ def DeleteWithPrefix(self, table: str, key_prefix: Sequence[Any]) -> None: range = KeyRange(start_closed=key_prefix, end_closed=key_prefix) keyset = KeySet(ranges=[range]) - with self._pyspanner.batch() as batch: + with self._pyspanner.batch(request_options={"request_tag": txn_tag}) as batch: batch.delete(table, keyset) def Read( @@ -427,6 +425,7 @@ def Read( table: str, key: Sequence[Any], cols: Sequence[str], + txn_tag: Optional[str] = None ) -> Mapping[str, Any]: """Read a single row with the given key from the specified table. @@ -443,7 +442,8 @@ def Read( results = snapshot.read( table=table, columns=cols, - keyset=keyset + keyset=keyset, + request_options={"request_tag": txn_tag} ) return results.one() @@ -452,6 +452,7 @@ def ReadSet( table: str, rows: KeySet, cols: Sequence[str], + txn_tag: Optional[str] = None ) -> Iterator[Mapping[str, Any]]: """Read a set of rows from the specified table. 
@@ -467,7 +468,8 @@ def ReadSet( results = snapshot.read( table=table, columns=cols, - keyset=rows + keyset=rows, + request_options={"request_tag": txn_tag} ) return results diff --git a/grr/server/grr_response_server/databases/spanner_yara.py b/grr/server/grr_response_server/databases/spanner_yara.py index 167667595..882dc2954 100644 --- a/grr/server/grr_response_server/databases/spanner_yara.py +++ b/grr/server/grr_response_server/databases/spanner_yara.py @@ -51,7 +51,10 @@ def VerifyYaraSignatureReference( key = (base64.b64encode(bytes(blob_id)),) try: - self.db.Read(table="YaraSignatureReferences", key=key, cols=("BlobId",)) + self.db.Read(table="YaraSignatureReferences", + key=key, + cols=("BlobId",), + txn_tag="VerifyYaraSignatureReference") except NotFound: return False From 53a36c00338072a2626ff4410ac5a191a35e4b2a Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Tue, 8 Jul 2025 15:02:58 +0000 Subject: [PATCH 042/168] Propagate txn_tag --- .../databases/spanner_blob_references.py | 2 +- .../databases/spanner_clients.py | 4 +-- .../databases/spanner_cron_jobs.py | 12 ++++--- .../databases/spanner_flows.py | 36 +++++++++---------- .../databases/spanner_hunts.py | 4 ++- .../databases/spanner_utils.py | 4 +-- 6 files changed, 34 insertions(+), 28 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_blob_references.py b/grr/server/grr_response_server/databases/spanner_blob_references.py index c7d6ee127..769bbb1b9 100644 --- a/grr/server/grr_response_server/databases/spanner_blob_references.py +++ b/grr/server/grr_response_server/databases/spanner_blob_references.py @@ -38,7 +38,7 @@ def Mutation(mut) -> None: mut.insert( table="HashBlobReferences", columns=("HashId", "BlobId", "Offset", "Size",), - values=[(hash_id_b64, base64.b64encode(bytes(ref.blob_id)), ref.offset, ref.size,)], + values=[(hash_id_b64, base64.b64encode(bytes(ref.blob_id)), ref.offset, ref.size,)] ) self.db.Mutate(Mutation, txn_tag="WriteHashBlobReferences") diff --git a/grr/server/grr_response_server/databases/spanner_clients.py b/grr/server/grr_response_server/databases/spanner_clients.py index 04de4e1e1..e09fbfbc9 100644 --- a/grr/server/grr_response_server/databases/spanner_clients.py +++ b/grr/server/grr_response_server/databases/spanner_clients.py @@ -384,13 +384,13 @@ def Mutation(mut: spanner_utils.Mutation) -> None: mut.update( table="Clients", columns=("ClientId", "LastRRGStartuptime"), - values=[(client_id, spanner_lib.COMMIT_TIMESTAMP)], + values=[(client_id, spanner_lib.COMMIT_TIMESTAMP)] ) mut.insert( table="ClientRRGStartups", columns=("ClientId", "CreationTime", "Startup"), - values=[(client_id, spanner_lib.COMMIT_TIMESTAMP, startup)], + values=[(client_id, spanner_lib.COMMIT_TIMESTAMP, startup)] ) try: diff --git a/grr/server/grr_response_server/databases/spanner_cron_jobs.py b/grr/server/grr_response_server/databases/spanner_cron_jobs.py index 6b58f31bb..5a6f61697 100644 --- a/grr/server/grr_response_server/databases/spanner_cron_jobs.py +++ b/grr/server/grr_response_server/databases/spanner_cron_jobs.py @@ -237,7 +237,8 @@ def Transaction(txn) -> Sequence[flows_pb2.CronJob]: query_ids_to_update += "AND cj.JobId IN UNNEST(@cronjob_ids)" params_ids_to_update["cronjob_ids"] = cronjob_ids - response = txn.execute_sql(sql=query_ids_to_update, params=params_ids_to_update) + response = txn.execute_sql(sql=query_ids_to_update, params=params_ids_to_update, + request_options={"request_tag": "LeaseCronJobs:CronJobs:execute_sql"}) ids_to_update = [] for (job_id,) in response: @@ -260,7 
+261,8 @@ def Transaction(txn) -> Sequence[flows_pb2.CronJob]: "ids_to_update": ids_to_update, } - txn.execute_update(update_query, update_params) + txn.execute_update(update_query, update_params, + request_options={"request_tag": "LeaseCronJobs:CronJobs:execute_update"}) # --------------------------------------------------------------------- # Query (and return) jobs that were updated @@ -366,7 +368,8 @@ def Transaction(txn) -> Sequence[flows_pb2.CronJob]: "ids_to_return": ids_to_return, } - txn.execute_update(update_query, update_params) + txn.execute_update(update_query, update_params, + request_options={"request_tag": "ReturnLeasedCronJobs:CronJobs:execute_update"}) # --------------------------------------------------------------------- # Query (and return) jobs that were updated @@ -557,7 +560,8 @@ def Transaction(txn) -> int: for job_id, run_id in rows: keyset = spanner_lib.KeySet(keys=[[job_id, run_id]]) - txn.delete(table="CronJobRuns", keyset=keyset) + txn.delete(table="CronJobRuns", keyset=keyset, + request_options={"request_tag": "DeleteOldCronJobRuns:CronJobRuns:delete"}) return len(rows) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index db04fef63..41ecf43f5 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -806,7 +806,7 @@ def Txn(txn) -> None: keyset = spanner_lib.KeySet(keys=keys) txn.delete(table="FlowProcessingRequests", keyset=keyset) - self.db.Transact(Txn) + self.db.Transact(Txn, txn_tag="AckFlowProcessingRequests") @db_utils.CallLogged @db_utils.CallAccounted @@ -817,7 +817,7 @@ def Txn(txn) -> None: keyset = spanner_lib.KeySet(all_=True) txn.delete(table="FlowProcessingRequests", keyset=keyset) - self.db.Transact(Txn) + self.db.Transact(Txn, txn_tag="DeleteAllFlowProcessingRequests") @db_utils.CallLogged @db_utils.CallAccounted @@ -878,7 +878,7 @@ def Txn(txn) -> None: return res - return self.db.Transact(Txn) + return self.db.Transact(Txn, txn_tag="_LeaseFlowProcessingRequests") _FLOW_REQUEST_POLL_TIME_SECS = 3 @@ -1071,7 +1071,7 @@ def _ReadRequestsInfo( "RequestID", "CallbackState", "ExpectedResponseCount", - ], + ] ): request_key = _RequestKey( @@ -1793,10 +1793,10 @@ def Txn(txn) -> None: txn.update( table="FlowRequests", columns=columns, - values=rows, + values=rows ) - self.db.Transact(Txn) + self.db.Transact(Txn, txn_tag="UpdateIncrementalFlowRequests") @db_utils.CallLogged @db_utils.CallAccounted @@ -1813,7 +1813,7 @@ def WriteFlowLogEntry(self, entry: flows_pb2.FlowLogEntry) -> None: row["HuntId"] = entry.hunt_id try: - self.db.Insert(table="FlowLogEntries", row=row) + self.db.Insert(table="FlowLogEntries", row=row, txn_tag="WriteFlowLogEntry") except NotFound as error: raise db.UnknownFlowError(entry.client_id, entry.flow_id) from error @@ -1993,7 +1993,7 @@ def WriteFlowOutputPluginLogEntry( row["HuntId"] = entry.hunt_id try: - self.db.Insert(table="FlowOutputPluginLogEntries", row=row) + self.db.Insert(table="FlowOutputPluginLogEntries", row=row, txn_tag="WriteFlowOutputPluginLogEntry") except NotFound as error: raise db.UnknownFlowError(entry.client_id, entry.flow_id) from error @@ -2147,7 +2147,7 @@ def Transaction(txn) -> None: txn.delete(table="ScheduledFlows", keyset=keyset) - self.db.Transact(Transaction) + self.db.Transact(Transaction, txn_tag="DeleteScheduledFlow") @db_utils.CallLogged @db_utils.CallAccounted @@ -2301,7 +2301,7 @@ def Txn(txn) -> None: return res - return 
self.db.Transact(Txn) + return self.db.Transact(Txn, txn_tag="_LeaseMessageHandlerRequests") @db_utils.CallLogged @db_utils.CallAccounted @@ -2322,7 +2322,7 @@ def Mutation(mut) -> None: ]) mut.insert(table="MessageHandlerRequests", columns=columns, values=rows) - self.db.Transact(Mutation) + self.db.Transact(Mutation, txn_tag="WriteMessageHandlerRequests") @db_utils.CallLogged @db_utils.CallAccounted @@ -2391,7 +2391,7 @@ def Txn(txn) -> flows_pb2.Flow: row = txn.read( table="Flows", keyset=spanner_lib.KeySet(keys=[[client_id, flow_id]]), - columns=_READ_FLOW_OBJECT_COLS, + columns=_READ_FLOW_OBJECT_COLS ).one() except NotFound as error: raise db.UnknownFlowError(client_id, flow_id, cause=error) @@ -2447,7 +2447,7 @@ def Txn2(txn) -> flows_pb2.Flow: row = txn.read( table="Flows", keyset=spanner_lib.KeySet(keys=[[client_id, flow_id]]), - columns=_READ_FLOW_OBJECT_COLS, + columns=_READ_FLOW_OBJECT_COLS ).one() flow = _ParseReadFlowObjectRow(client_id, flow_id, row) print(flow) @@ -2455,8 +2455,8 @@ def Txn2(txn) -> flows_pb2.Flow: raise db.UnknownFlowError(client_id, flow_id, cause=error) return flow - leased_flow = self.db.Transact(Txn) - flow = self.db.Transact(Txn2) + leased_flow = self.db.Transact(Txn, txn_tag="LeaseFlowForProcessing") + flow = self.db.Transact(Txn2, txn_tag="LeaseFlowForProcessing2") leased_flow.processing_since = flow.processing_since return leased_flow @@ -2470,7 +2470,7 @@ def Txn(txn) -> bool: row = txn.read( table="FlowRequests", keyset=spanner_lib.KeySet(keys=[[flow_obj.client_id, flow_obj.flow_id, flow_obj.next_request_to_process]]), - columns=["NeedsProcessing", "StartTime"], + columns=["NeedsProcessing", "StartTime"] ).one() if row[0]: start_time = row[1] @@ -2496,8 +2496,8 @@ def Txn(txn) -> bool: flow_obj.next_request_to_process, spanner_lib.COMMIT_TIMESTAMP, flow_obj.num_replies_sent, - ]], + ]] ) return True - return self.db.Transact(Txn) + return self.db.Transact(Txn, txn_tag="ReleaseProcessedFlow") diff --git a/grr/server/grr_response_server/databases/spanner_hunts.py b/grr/server/grr_response_server/databases/spanner_hunts.py index 494a4ab2f..0dee90e85 100644 --- a/grr/server/grr_response_server/databases/spanner_hunts.py +++ b/grr/server/grr_response_server/databases/spanner_hunts.py @@ -201,7 +201,8 @@ def _UpdateHuntObject( WHERE h.HuntId = @hunt_id """ - txn.execute_update(query, params=params, param_types=params_types) + txn.execute_update(query, params=params, param_types=params_types, + request_options={"request_tag": "_UpdateHuntObject:Hunts:execute_update"}) @db_utils.CallLogged @db_utils.CallAccounted @@ -935,6 +936,7 @@ def Txn(txn) -> None: table="HuntOutputPlugins", keyset=spanner_lib.KeySet(keys=[[hunt_id, state_index]]), columns=["Name", "Args", "State"], + request_options={"request_tag": "UpdateHuntOutputPluginState:HuntOutputPlugins:read"} ).one() state = _HuntOutputPluginStateFromRow( row[0], row[1], row[2] diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 21e50ef31..2889d43bb 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -283,10 +283,10 @@ def param_execute(transaction): row_ct = transaction.execute_update( query, params=params, - param_types=param_type, + param_types=param_type ) - self._pyspanner.run_in_transaction(param_execute, transaction_tag=txn_tag) + self._pyspanner.run_in_transaction(param_execute) def ExecutePartitioned( self, query: str, txn_tag: 
Optional[str] = None From 934ba71a8241c85ae7d628b8d8ec1cfe7d229149 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Sat, 12 Jul 2025 13:45:28 +0000 Subject: [PATCH 043/168] Adds spanner emulator to test env --- .github/workflows/build.yml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7b84357cf..af8ba67d4 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -4,6 +4,10 @@ env: GCS_BUCKET: autobuilds.grr-response.com GCS_BUCKET_OPENAPI: autobuilds-grr-openapi DOCKER_REPOSITORY: ghcr.io/${{ github.repository_owner }}/grr + SPANNER_EMULATOR_HOST: localhost:9010 + SPANNER_DATABASE: grr-database + SPANNER_INSTANCE: grr-instance + SPANNER_PROJECT: spanner-emulator-project jobs: test-devenv: runs-on: ubuntu-22.04 @@ -44,6 +48,19 @@ jobs: fi python3 -m venv --system-site-packages "${HOME}/INSTALL" travis/install.sh + - name: 'Install Cloud SDK' + uses: google-github-actions/setup-gcloud@v2 + with: + install_components: 'beta,pubsub-emulator,cloud-spanner-emulator' + - name: 'Start Spanner emulator' + run: | + gcloud config configurations create emulator + gcloud config set auth/disable_credentials true + gcloud config set project ${{ env.SPANNER_PROJECT }} + gcloud config set api_endpoint_overrides/spanner http://localhost:9020/ + gcloud emulators spanner start & + gcloud spanner instances create ${{ env.SPANNER_INSTANCE }} \ + --config=emulator-config --description="Spanner Test Instance" --nodes=1 - name: Test run: | source "${HOME}/INSTALL/bin/activate" From af901a0ca5229b552474782723dd73c8e9991b69 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Sun, 13 Jul 2025 12:51:45 +0200 Subject: [PATCH 044/168] Update build.yml Changing env var SPANNER_PROJECT to SPANNER_PROJECT_ID --- .github/workflows/build.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index af8ba67d4..15d673200 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -7,7 +7,7 @@ env: SPANNER_EMULATOR_HOST: localhost:9010 SPANNER_DATABASE: grr-database SPANNER_INSTANCE: grr-instance - SPANNER_PROJECT: spanner-emulator-project + SPANNER_PROJECT_ID: spanner-emulator-project jobs: test-devenv: runs-on: ubuntu-22.04 @@ -56,7 +56,7 @@ jobs: run: | gcloud config configurations create emulator gcloud config set auth/disable_credentials true - gcloud config set project ${{ env.SPANNER_PROJECT }} + gcloud config set project ${{ env.SPANNER_PROJECT_ID }} gcloud config set api_endpoint_overrides/spanner http://localhost:9020/ gcloud emulators spanner start & gcloud spanner instances create ${{ env.SPANNER_INSTANCE }} \ From cc159badd5a9265c1cca198219c27492167aadb6 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Mon, 14 Jul 2025 04:59:37 +0000 Subject: [PATCH 045/168] Update build steps with Spanner Instance creation --- .github/workflows/build.yml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 15d673200..39707059f 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -59,8 +59,13 @@ jobs: gcloud config set project ${{ env.SPANNER_PROJECT_ID }} gcloud config set api_endpoint_overrides/spanner http://localhost:9020/ gcloud emulators spanner start & - gcloud spanner instances create ${{ env.SPANNER_INSTANCE }} \ - --config=emulator-config --description="Spanner Test Instance" --nodes=1 + sudo apt-get update + sudo 
apt install -y protobuf-compiler libprotobuf-dev + # Verify that we can create the Spanner Instance and Database + cd grr/server/grr_response_server/databases + ./spanner_setup.sh + # Remove the Database for a clean test env + gcloud spanner databases delete ${{ env.SPANNER_DATABASE }} --instance=${{ env.SPANNER_INSTANCE }} - name: Test run: | source "${HOME}/INSTALL/bin/activate" From 09df46d2a382d511d5d283c6a3acd55fae8f16e3 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Mon, 14 Jul 2025 05:52:01 +0000 Subject: [PATCH 046/168] Fix Spanner Env variable names --- .github/workflows/build.yml | 3 +++ grr/server/grr_response_server/databases/spanner_setup.sh | 4 ++-- grr/server/grr_response_server/databases/spanner_test_lib.py | 4 ++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 39707059f..7942e5993 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -59,11 +59,14 @@ jobs: gcloud config set project ${{ env.SPANNER_PROJECT_ID }} gcloud config set api_endpoint_overrides/spanner http://localhost:9020/ gcloud emulators spanner start & + gcloud spanner instances create ${{ env.SPANNER_INSTANCE }} \ + --config=emulator-config --description="Spanner Test Instance" --nodes=1 sudo apt-get update sudo apt install -y protobuf-compiler libprotobuf-dev # Verify that we can create the Spanner Instance and Database cd grr/server/grr_response_server/databases ./spanner_setup.sh + cd - # Remove the Database for a clean test env gcloud spanner databases delete ${{ env.SPANNER_DATABASE }} --instance=${{ env.SPANNER_INSTANCE }} - name: Test diff --git a/grr/server/grr_response_server/databases/spanner_setup.sh b/grr/server/grr_response_server/databases/spanner_setup.sh index 22a3a6406..cbd27d9fd 100755 --- a/grr/server/grr_response_server/databases/spanner_setup.sh +++ b/grr/server/grr_response_server/databases/spanner_setup.sh @@ -22,7 +22,7 @@ if [ ! -f ./spanner_grr.pb ]; then fi echo "2/3 : Creating GRR database on Spanner..." -gcloud spanner databases create ${SPANNER_GRR_DATABASE} --instance ${SPANNER_GRR_INSTANCE} +gcloud spanner databases create ${SPANNER_DATABASE} --instance ${SPANNER_INSTANCE} echo "3/3 : Creating tables ..." 
-gcloud spanner databases ddl update ${SPANNER_GRR_DATABASE} --instance=${SPANNER_GRR_INSTANCE} --ddl-file=spanner.sdl --proto-descriptors-file=spanner_grr.pb \ No newline at end of file +gcloud spanner databases ddl update ${SPANNER_DATABASE} --instance=${SPANNER_INSTANCE} --ddl-file=spanner.sdl --proto-descriptors-file=spanner_grr.pb \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_test_lib.py b/grr/server/grr_response_server/databases/spanner_test_lib.py index 299e91671..0a1463296 100644 --- a/grr/server/grr_response_server/databases/spanner_test_lib.py +++ b/grr/server/grr_response_server/databases/spanner_test_lib.py @@ -55,8 +55,8 @@ def Init(sdl_path: str, proto_bundle: bool) -> None: raise AssertionError("Spanner test library already initialized") project_id = _GetEnvironOrSkip("SPANNER_PROJECT_ID") - instance_id = _GetEnvironOrSkip("SPANNER_GRR_INSTANCE") - database_id = _GetEnvironOrSkip("SPANNER_GRR_DATABASE") + instance_id = _GetEnvironOrSkip("SPANNER_INSTANCE") + database_id = _GetEnvironOrSkip("SPANNER_DATABASE") spanner_client = Client(project_id) database_admin_api = spanner_client.database_admin_api From bc72cbd536f2dc02e6899d5e34e4346fcc11f6d9 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Mon, 14 Jul 2025 06:02:58 +0000 Subject: [PATCH 047/168] Run database removal with quiet flag --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7942e5993..baa8bebaa 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -68,7 +68,7 @@ jobs: ./spanner_setup.sh cd - # Remove the Database for a clean test env - gcloud spanner databases delete ${{ env.SPANNER_DATABASE }} --instance=${{ env.SPANNER_INSTANCE }} + gcloud spanner databases delete ${{ env.SPANNER_DATABASE }} --instance=${{ env.SPANNER_INSTANCE }} --quiet - name: Test run: | source "${HOME}/INSTALL/bin/activate" From 18506feb2b73934ec15b28afbb198849a1ff14b6 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Mon, 14 Jul 2025 11:30:29 +0000 Subject: [PATCH 048/168] Adds PROJECT_ID env var --- .github/workflows/build.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index baa8bebaa..ff6a926ff 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -8,6 +8,7 @@ env: SPANNER_DATABASE: grr-database SPANNER_INSTANCE: grr-instance SPANNER_PROJECT_ID: spanner-emulator-project + PROJECT_ID: spanner-emulator-project jobs: test-devenv: runs-on: ubuntu-22.04 From 8233c2bea29946084857abc8862fc3debc1f7e96 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Mon, 14 Jul 2025 12:59:37 +0000 Subject: [PATCH 049/168] Creates unique database name --- grr/server/grr_response_server/databases/spanner_test_lib.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_test_lib.py b/grr/server/grr_response_server/databases/spanner_test_lib.py index 0a1463296..61a514ad3 100644 --- a/grr/server/grr_response_server/databases/spanner_test_lib.py +++ b/grr/server/grr_response_server/databases/spanner_test_lib.py @@ -1,6 +1,7 @@ """A library with utilities for testing the Spanner database implementation.""" import os import unittest +import uuid from typing import Optional @@ -56,7 +57,7 @@ def Init(sdl_path: str, proto_bundle: bool) -> None: project_id = _GetEnvironOrSkip("SPANNER_PROJECT_ID") instance_id = 
_GetEnvironOrSkip("SPANNER_INSTANCE") - database_id = _GetEnvironOrSkip("SPANNER_DATABASE") + database_id = _GetEnvironOrSkip("SPANNER_DATABASE") + str(uuid.uuid4()) spanner_client = Client(project_id) database_admin_api = spanner_client.database_admin_api From 0bc717f4f03358ff433c06fc0015f8b6cb752c0e Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Mon, 14 Jul 2025 13:18:23 +0000 Subject: [PATCH 050/168] Creates unique database name --- grr/server/grr_response_server/databases/spanner_test_lib.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_test_lib.py b/grr/server/grr_response_server/databases/spanner_test_lib.py index 61a514ad3..51be52c37 100644 --- a/grr/server/grr_response_server/databases/spanner_test_lib.py +++ b/grr/server/grr_response_server/databases/spanner_test_lib.py @@ -1,7 +1,7 @@ """A library with utilities for testing the Spanner database implementation.""" import os +import random import unittest -import uuid from typing import Optional @@ -57,7 +57,7 @@ def Init(sdl_path: str, proto_bundle: bool) -> None: project_id = _GetEnvironOrSkip("SPANNER_PROJECT_ID") instance_id = _GetEnvironOrSkip("SPANNER_INSTANCE") - database_id = _GetEnvironOrSkip("SPANNER_DATABASE") + str(uuid.uuid4()) + database_id = _GetEnvironOrSkip("SPANNER_DATABASE") + "-" + random.randint(1, 100000) spanner_client = Client(project_id) database_admin_api = spanner_client.database_admin_api From ecce62c273899ccc714935eff51c59d55428127e Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Mon, 14 Jul 2025 13:32:48 +0000 Subject: [PATCH 051/168] Creates unique database name --- grr/server/grr_response_server/databases/spanner_test_lib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_test_lib.py b/grr/server/grr_response_server/databases/spanner_test_lib.py index 51be52c37..b011d3f1b 100644 --- a/grr/server/grr_response_server/databases/spanner_test_lib.py +++ b/grr/server/grr_response_server/databases/spanner_test_lib.py @@ -57,7 +57,7 @@ def Init(sdl_path: str, proto_bundle: bool) -> None: project_id = _GetEnvironOrSkip("SPANNER_PROJECT_ID") instance_id = _GetEnvironOrSkip("SPANNER_INSTANCE") - database_id = _GetEnvironOrSkip("SPANNER_DATABASE") + "-" + random.randint(1, 100000) + database_id = _GetEnvironOrSkip("SPANNER_DATABASE") + "-" + str(random.randint(1, 100000)) spanner_client = Client(project_id) database_admin_api = spanner_client.database_admin_api From 3bd6006f07169d1dc72e474bb2c0b4b90b347619 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Mon, 14 Jul 2025 14:06:00 +0000 Subject: [PATCH 052/168] Reset the test db --- grr/server/grr_response_server/databases/spanner_test_lib.py | 1 + 1 file changed, 1 insertion(+) diff --git a/grr/server/grr_response_server/databases/spanner_test_lib.py b/grr/server/grr_response_server/databases/spanner_test_lib.py index b011d3f1b..d0a09699e 100644 --- a/grr/server/grr_response_server/databases/spanner_test_lib.py +++ b/grr/server/grr_response_server/databases/spanner_test_lib.py @@ -101,6 +101,7 @@ def TearDown() -> None: if _TEST_DB is not None: # Create a client _TEST_DB.drop() + _TEST_DB = None class TestCase(absltest.TestCase): From b5e4a457079fe25f473fe3463cf2c31f40e2e7c6 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Mon, 14 Jul 2025 14:27:52 +0000 Subject: [PATCH 053/168] Reset the test db --- grr/server/grr_response_server/databases/spanner_test_lib.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/grr/server/grr_response_server/databases/spanner_test_lib.py b/grr/server/grr_response_server/databases/spanner_test_lib.py index d0a09699e..9bda12f15 100644 --- a/grr/server/grr_response_server/databases/spanner_test_lib.py +++ b/grr/server/grr_response_server/databases/spanner_test_lib.py @@ -98,6 +98,8 @@ def TearDown() -> None: This must be called once per process after all the tests. A `tearDownModule` is a perfect place for it. """ + global _TEST_DB + if _TEST_DB is not None: # Create a client _TEST_DB.drop() From 098cf516a7779097ae82985d3b217ca7ed29e182 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Fri, 18 Jul 2025 13:46:23 +0000 Subject: [PATCH 054/168] Several cosmetic improvements --- .../grr_response_server/databases/spanner.sdl | 4 +-- .../databases/spanner_clients.py | 2 +- .../databases/spanner_flows.py | 3 +- .../databases/spanner_flows_large_test.py | 36 +++---------------- .../databases/spanner_flows_test.py | 3 +- .../databases/spanner_signed_commands.py | 6 ---- 6 files changed, 9 insertions(+), 45 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner.sdl b/grr/server/grr_response_server/databases/spanner.sdl index a9eaaabd8..b23e021bb 100644 --- a/grr/server/grr_response_server/databases/spanner.sdl +++ b/grr/server/grr_response_server/databases/spanner.sdl @@ -496,8 +496,8 @@ CREATE TABLE HashBlobReferences( CREATE TABLE ApiAuditEntry ( Username STRING(256) NOT NULL, CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), - HttpRequestPath String(MAX) NOT NULL, - RouterMethodName String(256) NOT NULL, + HttpRequestPath STRING(MAX) NOT NULL, + RouterMethodName STRING(256) NOT NULL, ResponseCode `grr.APIAuditEntry.Code` NOT NULL, ) PRIMARY KEY (Username, CreationTime); diff --git a/grr/server/grr_response_server/databases/spanner_clients.py b/grr/server/grr_response_server/databases/spanner_clients.py index e09fbfbc9..61396a00e 100644 --- a/grr/server/grr_response_server/databases/spanner_clients.py +++ b/grr/server/grr_response_server/databases/spanner_clients.py @@ -383,7 +383,7 @@ def WriteClientRRGStartup( def Mutation(mut: spanner_utils.Mutation) -> None: mut.update( table="Clients", - columns=("ClientId", "LastRRGStartuptime"), + columns=("ClientId", "LastRRGStartupTime"), values=[(client_id, spanner_lib.COMMIT_TIMESTAMP)] ) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 41ecf43f5..44db616bd 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -2450,7 +2450,6 @@ def Txn2(txn) -> flows_pb2.Flow: columns=_READ_FLOW_OBJECT_COLS ).one() flow = _ParseReadFlowObjectRow(client_id, flow_id, row) - print(flow) except NotFound as error: raise db.UnknownFlowError(client_id, flow_id, cause=error) return flow @@ -2487,7 +2486,7 @@ def Txn(txn) -> bool: table="Flows", columns=["ClientId", "FlowId", "Flow", "State", "UserCpuTimeUsed", "SystemCpuTimeUsed", "NetworkBytesSent", "ProcessingWorker", - "ProcessingStartTime", "ProcessingEndTime", "NextRequesttoProcess", + "ProcessingStartTime", "ProcessingEndTime", "NextRequestToProcess", "UpdateTime", "ReplyCount"], values=[[flow_obj.client_id, flow_obj.flow_id,flow_obj, int(flow_obj.flow_state), float(flow_obj.cpu_time_used.user_cpu_time), diff --git a/grr/server/grr_response_server/databases/spanner_flows_large_test.py b/grr/server/grr_response_server/databases/spanner_flows_large_test.py index 143ac3c96..60ce031be 
100644 --- a/grr/server/grr_response_server/databases/spanner_flows_large_test.py +++ b/grr/server/grr_response_server/databases/spanner_flows_large_test.py @@ -14,40 +14,12 @@ def tearDownModule() -> None: class SpannerDatabaseFlowsTest( - db_flows_test.DatabaseLargeTestFlowMixin, spanner_test_lib.TestCase + db_flows_test.DatabaseLargeTestFlowMixin, + spanner_test_lib.TestCase ): # Test methods are defined in the base mixin class. - # To cleanup the database we use `DeleteWithPrefix` (to do multiple deletions - # within a single mutation) but this method is super slow for cleaning up huge - # amounts of data. Thus, for certain methods that populate the database with - # a lot of rows we manually clean up using the `DELETE` DML statement which is - # faster in such cases. - - def test40001RequestsCanBeWrittenAndRead(self): - super().test40001RequestsCanBeWrittenAndRead() - - db: spanner_utils.Database = self.db.delegate.db # pytype: disable=attribute-error - db.ExecutePartitioned("DELETE FROM FlowRequests WHERE TRUE") - - def test40001ResponsesCanBeWrittenAndRead(self): - super().test40001ResponsesCanBeWrittenAndRead() - - db: spanner_utils.Database = self.db.delegate.db # pytype: disable=attribute-error - db.ExecutePartitioned("DELETE FROM FlowResponses WHERE TRUE") - - def testWritesAndCounts40001FlowResults(self): - super().testWritesAndCounts40001FlowResults() - - db: spanner_utils.Database = self.db.delegate.db # pytype: disable=attribute-error - db.ExecutePartitioned("DELETE FROM FlowResults WHERE TRUE") - - def testWritesAndCounts40001FlowErrors(self): - super().testWritesAndCounts40001FlowErrors() - - db: spanner_utils.Database = self.db.delegate.db # pytype: disable=attribute-error - db.ExecutePartitioned("DELETE FROM FlowErrors WHERE TRUE") - + pass # Test methods are defined in the base mixin class. if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() diff --git a/grr/server/grr_response_server/databases/spanner_flows_test.py b/grr/server/grr_response_server/databases/spanner_flows_test.py index 1e981ef08..fe5462247 100644 --- a/grr/server/grr_response_server/databases/spanner_flows_test.py +++ b/grr/server/grr_response_server/databases/spanner_flows_test.py @@ -24,6 +24,5 @@ class SpannerDatabaseFlowsTest( """Spanner flow tests.""" pass # Test methods are defined in the base mixin class. 
- if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() diff --git a/grr/server/grr_response_server/databases/spanner_signed_commands.py b/grr/server/grr_response_server/databases/spanner_signed_commands.py index bd1c7a06a..5b05c9a66 100644 --- a/grr/server/grr_response_server/databases/spanner_signed_commands.py +++ b/grr/server/grr_response_server/databases/spanner_signed_commands.py @@ -97,12 +97,6 @@ def ReadSignedCommands( FROM SignedCommands AS c """ - query = """ - SELECT - c.Id, c.OperatingSystem, c.Ed25519Signature, c.Command - FROM - SignedCommands AS c - """ signed_commands = [] for ( command_id, From 4d461018ee9705953b98de09b4e188d08f45b235 Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Wed, 27 Aug 2025 15:10:12 +0000 Subject: [PATCH 055/168] Adds fixes to pass Spanner tests and installation instructions --- .github/workflows/build.yml | 10 +-- .gitignore | 1 + grr/proto/grr_response_proto/flows.proto | 1 - grr/proto/grr_response_proto/objects.proto | 1 - .../grr_response_server/databases/spanner.md | 80 +++++++++++++++++++ .../grr_response_server/databases/spanner.py | 2 +- .../databases/spanner_flows.py | 10 +++ 7 files changed, 97 insertions(+), 8 deletions(-) create mode 100644 grr/server/grr_response_server/databases/spanner.md diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ff6a926ff..18b209b6e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -4,11 +4,11 @@ env: GCS_BUCKET: autobuilds.grr-response.com GCS_BUCKET_OPENAPI: autobuilds-grr-openapi DOCKER_REPOSITORY: ghcr.io/${{ github.repository_owner }}/grr - SPANNER_EMULATOR_HOST: localhost:9010 - SPANNER_DATABASE: grr-database - SPANNER_INSTANCE: grr-instance - SPANNER_PROJECT_ID: spanner-emulator-project - PROJECT_ID: spanner-emulator-project + #SPANNER_EMULATOR_HOST: localhost:9010 + #SPANNER_DATABASE: grr-database + #SPANNER_INSTANCE: grr-instance + #SPANNER_PROJECT_ID: spanner-emulator-project + #PROJECT_ID: spanner-emulator-project jobs: test-devenv: runs-on: ubuntu-22.04 diff --git a/.gitignore b/.gitignore index b9e317e2d..767dcd311 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,4 @@ compose.watch.yaml Dockerfile.client grr/server/grr_response_server/databases/spanner_grr.pb +spanner-setup.txt diff --git a/grr/proto/grr_response_proto/flows.proto b/grr/proto/grr_response_proto/flows.proto index 5586a0577..52b6cadff 100644 --- a/grr/proto/grr_response_proto/flows.proto +++ b/grr/proto/grr_response_proto/flows.proto @@ -2287,7 +2287,6 @@ message FlowProcessingRequest { optional uint64 creation_time = 4 [(sem_type) = { type: "RDFDatetime", }]; - optional string ack_id = 5; } message FlowRequest { diff --git a/grr/proto/grr_response_proto/objects.proto b/grr/proto/grr_response_proto/objects.proto index 318a486c7..58440aafd 100644 --- a/grr/proto/grr_response_proto/objects.proto +++ b/grr/proto/grr_response_proto/objects.proto @@ -331,7 +331,6 @@ message MessageHandlerRequest { optional uint64 leased_until = 5 [(sem_type) = { type: "RDFDatetime" }]; optional string leased_by = 6; optional EmbeddedRDFValue request = 7; - optional string ack_id = 8; } message SerializedValueOfUnrecognizedType { diff --git a/grr/server/grr_response_server/databases/spanner.md b/grr/server/grr_response_server/databases/spanner.md new file mode 100644 index 000000000..65c0465e4 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner.md @@ -0,0 +1,80 @@ +# GRR on Spanner + +When operating GRR you need to decide on a 
persistence store.
+
+This document describes how to use [Spanner](https://cloud.google.com/spanner)
+as the GRR datastore.
+
+Spanner is a fully managed, mission-critical database service on
+[Google Cloud](https://cloud.google.com) that brings together relational, graph,
+key-value, and search. It offers transactional consistency at global scale and
+automatic, synchronous replication for high availability.
+
+## 1. Google Cloud Resources
+
+Running GRR on Spanner requires that you create and configure a
+[Spanner instance](https://cloud.google.com/spanner/docs/instances) before you
+run GRR.
+
+Furthermore, you also need to create a
+[Google Cloud Storage](https://cloud.google.com/storage)
+[Bucket](https://cloud.google.com/storage/docs/buckets) that
+will serve as the GRR blobstore.
+
+## 2. Google Cloud Spanner Instance
+
+You can follow the instructions in the Google Cloud online documentation to
+[create a Spanner instance](https://cloud.google.com/spanner/docs/create-query-database-console#create-instance).
+
+> [!NOTE] You only need to create the
+> [Spanner instance](https://cloud.google.com/spanner/docs/instances). The
+> GRR [Spanner database](https://cloud.google.com/spanner/docs/databases)
+> and its tables are created by running the provided [spanner_setup.sh](./spanner_setup.sh)
+> script. The script assumes that you use `grr-instance` as the
+> GRR instance name and `grr-database` as the GRR database name. In
+> case you want to use different values, you need to update the
+> [spanner_setup.sh](./spanner_setup.sh) script accordingly.
+> The script assumes that you have the
+> [gcloud](https://cloud.google.com/sdk/docs/install) and the
+> [protoc](https://protobuf.dev/installation/) binaries installed on your machine.
+
+Run the following commands to create the GRR database and its tables:
+
+```bash
+export PROJECT_ID=
+export SPANNER_INSTANCE=grr-instance
+export SPANNER_DATABASE=grr-database
+./spanner_setup.sh
+```
+
+## 3. GRR Configuration
+
+To run GRR on Spanner you need to configure the component settings with
+the values of the Google Cloud Spanner and the GCS Bucket resources mentioned above.
+
+The snippet below illustrates a sample GRR `server.yaml` configuration.
+
+```bash
+  Database.implementation: SpannerDB
+  Spanner.project:
+  Spanner.instance: grr-instance
+  Spanner.database: grr-database
+  Blobstore.implementation: GCSBlobStore
+  Blobstore.gcs.project:
+  Blobstore.gcs.bucket:
+```
+
+> [!NOTE] Make sure you remove all the `Mysql` related configuration items.
+
+## 4. IAM Permissions
+
+This guide assumes that your GRR instance is running on [Google Kubernetes Engine](https://cloud.google.com/kubernetes-engine) (GKE)
+and that you can leverage [Workload Identity Federation for GKE](https://cloud.google.com/kubernetes-engine/docs/how-to/workload-identity) (WIF).
+
+Using WIF you can assign the required IAM roles to the WIF principal
+`principal://iam.googleapis.com/projects//locations/global/workloadIdentityPools//sa/`
+where `K8S_NAMESPACE` is the value of the Kubernetes Namespace and `K8S_SERVICE_ACCOUNT` is the value of the Kubernetes Service Account that your GRR Pods are running under.
+
+The two IAM roles that are required are:
+- `roles/spanner.databaseUser` on your Spanner Database and
+- `roles/storage.objectUser` on your GCS Bucket (the GRR Blobstore mentioned above).
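As a quick sanity check after granting the roles, a minimal Python snippet along the following lines (assuming the `grr-instance` and `grr-database` names from above; the project id is a placeholder) can confirm that the Workload Identity service account is able to reach the database:

```python
# Illustrative connectivity check; "example-project" is a placeholder project id.
from google.cloud import spanner

client = spanner.Client(project="example-project")
database = client.instance("grr-instance").database("grr-database")

with database.snapshot() as snapshot:
  # This read only succeeds if the caller has been granted a role that
  # allows reading from the database (e.g. roles/spanner.databaseUser above).
  for row in snapshot.execute_sql("SELECT 1"):
    print(row)
```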
\ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner.py b/grr/server/grr_response_server/databases/spanner.py index 9a67e957d..7f9732abd 100644 --- a/grr/server/grr_response_server/databases/spanner.py +++ b/grr/server/grr_response_server/databases/spanner.py @@ -45,7 +45,7 @@ class SpannerDB( def __init__(self, db: spanner_utils.Database) -> None: """Initializes the database.""" self.db = db - self._write_rows_batch_size = 10000 + self._write_rows_batch_size = 100 self.handler_thread = None self.handler_stop = True diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 44db616bd..75fc7217d 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -480,6 +480,11 @@ def UpdateFlow( @db_utils.CallAccounted def WriteFlowResults(self, results: Sequence[flows_pb2.FlowResult]) -> None: """Writes flow results for a given flow.""" + for batch in collection.Batch(results, self._write_rows_batch_size): + self._WriteFlowResults(batch) + + def _WriteFlowResults(self, results: Sequence[flows_pb2.FlowError]) -> None: + """Writes flow errors for a given flow.""" def Mutation(mut) -> None: rows = [] @@ -503,6 +508,11 @@ def Mutation(mut) -> None: @db_utils.CallAccounted def WriteFlowErrors(self, errors: Sequence[flows_pb2.FlowError]) -> None: """Writes flow errors for a given flow.""" + for batch in collection.Batch(errors, self._write_rows_batch_size): + self._WriteFlowErrors(batch) + + def _WriteFlowErrors(self, errors: Sequence[flows_pb2.FlowError]) -> None: + """Writes flow errors for a given flow.""" def Mutation(mut) -> None: rows = [] From 26480894ba9994333be85fa09d7289ae355f726f Mon Sep 17 00:00:00 2001 From: Dan Aschwanden Date: Wed, 27 Aug 2025 15:31:58 +0000 Subject: [PATCH 056/168] Removes emulator test instructions --- .github/workflows/build.yml | 42 ++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 18b209b6e..542d900fe 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -49,27 +49,27 @@ jobs: fi python3 -m venv --system-site-packages "${HOME}/INSTALL" travis/install.sh - - name: 'Install Cloud SDK' - uses: google-github-actions/setup-gcloud@v2 - with: - install_components: 'beta,pubsub-emulator,cloud-spanner-emulator' - - name: 'Start Spanner emulator' - run: | - gcloud config configurations create emulator - gcloud config set auth/disable_credentials true - gcloud config set project ${{ env.SPANNER_PROJECT_ID }} - gcloud config set api_endpoint_overrides/spanner http://localhost:9020/ - gcloud emulators spanner start & - gcloud spanner instances create ${{ env.SPANNER_INSTANCE }} \ - --config=emulator-config --description="Spanner Test Instance" --nodes=1 - sudo apt-get update - sudo apt install -y protobuf-compiler libprotobuf-dev - # Verify that we can create the Spanner Instance and Database - cd grr/server/grr_response_server/databases - ./spanner_setup.sh - cd - - # Remove the Database for a clean test env - gcloud spanner databases delete ${{ env.SPANNER_DATABASE }} --instance=${{ env.SPANNER_INSTANCE }} --quiet + #- name: 'Install Cloud SDK' + # uses: google-github-actions/setup-gcloud@v2 + # with: + # install_components: 'beta,pubsub-emulator,cloud-spanner-emulator' + #- name: 'Start Spanner emulator' + # run: | + # gcloud config configurations 
create emulator + # gcloud config set auth/disable_credentials true + # gcloud config set project ${{ env.SPANNER_PROJECT_ID }} + # gcloud config set api_endpoint_overrides/spanner http://localhost:9020/ + # gcloud emulators spanner start & + # gcloud spanner instances create ${{ env.SPANNER_INSTANCE }} \ + # --config=emulator-config --description="Spanner Test Instance" --nodes=1 + # sudo apt-get update + # sudo apt install -y protobuf-compiler libprotobuf-dev + # # Verify that we can create the Spanner Instance and Database + # cd grr/server/grr_response_server/databases + # ./spanner_setup.sh + # cd - + # # Remove the Database for a clean test env + # gcloud spanner databases delete ${{ env.SPANNER_DATABASE }} --instance=${{ env.SPANNER_INSTANCE }} --quiet - name: Test run: | source "${HOME}/INSTALL/bin/activate" From 6734876628cce742279a9b584b9be1ed4a7de185 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 11 Nov 2025 14:23:53 +0100 Subject: [PATCH 057/168] Update grr/server/grr_response_server/databases/spanner.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/grr/server/grr_response_server/databases/spanner.py b/grr/server/grr_response_server/databases/spanner.py index 7f9732abd..6530f2cf0 100644 --- a/grr/server/grr_response_server/databases/spanner.py +++ b/grr/server/grr_response_server/databases/spanner.py @@ -1,4 +1,8 @@ +#!/usr/bin/env python +"""Spanner implementation of the GRR relational database abstraction. +See grr/server/db.py for interface. +""" from google.cloud.spanner import Client from grr_response_core import config From bcbb208d3e33fa734f2bec2228cf6e9003280d3f Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 11 Nov 2025 14:24:24 +0100 Subject: [PATCH 058/168] Update grr/server/grr_response_server/databases/spanner.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner.py b/grr/server/grr_response_server/databases/spanner.py index 6530f2cf0..bbfeb1db2 100644 --- a/grr/server/grr_response_server/databases/spanner.py +++ b/grr/server/grr_response_server/databases/spanner.py @@ -63,7 +63,7 @@ def __init__(self, db: spanner_utils.Database) -> None: ) @classmethod - def FromConfig(cls) -> "Database": + def FromConfig(cls) -> "SpannerDB": """Creates a GRR database instance for Spanner path specified in the config. Returns: From 6fcb253bdde6f15f6aa963376aab686d8e15f612 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 11 Nov 2025 14:24:36 +0100 Subject: [PATCH 059/168] Update grr/server/grr_response_server/databases/spanner_artifacts.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_artifacts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_artifacts.py b/grr/server/grr_response_server/databases/spanner_artifacts.py index 8dda15bc4..12874301e 100644 --- a/grr/server/grr_response_server/databases/spanner_artifacts.py +++ b/grr/server/grr_response_server/databases/spanner_artifacts.py @@ -97,7 +97,7 @@ def DeleteArtifact(self, name: str) -> None: UnknownArtifactError when the artifact does not exist. 
""" def Transaction(txn) -> None: - # Spanner does not raise if we attept to delete a non-existing row so + # Spanner does not raise if we attempt to delete a non-existing row so # we check it exists ourselves. keyset = spanner_lib.KeySet(keys=[[name],]) From d4c6a592c3d97cf4fa0d273d9a1cebdb00071a81 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 11 Nov 2025 14:24:58 +0100 Subject: [PATCH 060/168] Update grr/server/grr_response_server/databases/spanner_artifacts.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_artifacts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_artifacts.py b/grr/server/grr_response_server/databases/spanner_artifacts.py index 12874301e..b797242f3 100644 --- a/grr/server/grr_response_server/databases/spanner_artifacts.py +++ b/grr/server/grr_response_server/databases/spanner_artifacts.py @@ -108,4 +108,4 @@ def Transaction(txn) -> None: txn.delete("Artifacts", keyset) - self.db.Transact(Transaction, txn_tag="DeleteArtifact") \ No newline at end of file + self.db.Transact(Transaction, txn_tag="DeleteArtifact") From d04b82cd511a22440ca7b6d73edf0ff355cc99ae Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 11 Nov 2025 14:25:12 +0100 Subject: [PATCH 061/168] Update grr/server/grr_response_server/databases/spanner_artifacts.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_artifacts.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_artifacts.py b/grr/server/grr_response_server/databases/spanner_artifacts.py index b797242f3..0afcc496d 100644 --- a/grr/server/grr_response_server/databases/spanner_artifacts.py +++ b/grr/server/grr_response_server/databases/spanner_artifacts.py @@ -38,7 +38,6 @@ def WriteArtifact(self, artifact: artifact_pb2.Artifact) -> None: self.db.Insert(table="Artifacts", row=row, txn_tag="WriteArtifact") except AlreadyExists as error: raise db.DuplicatedArtifactError(name) from error - @db_utils.CallLogged @db_utils.CallAccounted From 5c967a49a27cfed4ee42ecd370b1654c71df5822 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 11 Nov 2025 14:25:26 +0100 Subject: [PATCH 062/168] Update grr/server/grr_response_server/databases/spanner_yara_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_yara_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_yara_test.py b/grr/server/grr_response_server/databases/spanner_yara_test.py index 0bf0ddb38..616191b1c 100644 --- a/grr/server/grr_response_server/databases/spanner_yara_test.py +++ b/grr/server/grr_response_server/databases/spanner_yara_test.py @@ -20,4 +20,4 @@ class SpannerDatabaseYaraTest( if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From 5c5b2dc61d8e7bcc861ba824f31fa7f0b3503abe Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 07:07:59 +0100 Subject: [PATCH 063/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py 
b/grr/server/grr_response_server/databases/spanner_utils.py index 2889d43bb..9b611bf4d 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -178,7 +178,7 @@ def QuerySingle(self, query: str, txn_tag: Optional[str] = None) -> Row: def ParamQuery( self, query: str, params: Mapping[str, Any], - param_type: Optional[dict] = {}, txn_tag: Optional[str] = None + param_type: Optional[dict] = None, txn_tag: Optional[str] = None ) -> Cursor: """Queries PySpanner database using the given query string with params. From 2611ad8bacf99f9e7b86802c23aa3c3e96487210 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 07:08:08 +0100 Subject: [PATCH 064/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 9b611bf4d..049b829e7 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -205,6 +205,8 @@ def ParamQuery( ValueError: If the query contains disallowed sequences. KeyError: If some parameter is not specified. """ + if not param_type: + param_type = {} names, values = collection.Unzip(params.items()) query = self._parametrize(query, names) From aff64ce6f7a6538cfccb850fe373ce1f60446afb Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 07:08:20 +0100 Subject: [PATCH 065/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 049b829e7..faee9ff8b 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -230,7 +230,7 @@ def ParamQuery( def ParamQuerySingle( self, query: str, params: Mapping[str, Any], - param_type: Optional[dict] = {}, txn_tag: Optional[str] = None + param_type: Optional[dict] = None, txn_tag: Optional[str] = None ) -> Row: """Queries the database for a single row using with a query with params. 
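The changes above (patches 063-065) replace the mutable `{}` default for `param_type` with a `None` sentinel that is filled in inside the function body. A minimal sketch of the pitfall being avoided; `query_bad` and `query_good` are illustrative names, not functions from the GRR code base:

    # Minimal sketch of the mutable-default-argument pitfall; hypothetical names.
    def query_bad(sql, param_type={}):
        # The single default dict is shared across every call, so state from
        # one caller leaks into the next.
        param_type.setdefault("name", "STRING")
        return param_type

    def query_good(sql, param_type=None):
        # A fresh dict is created whenever the caller passes nothing.
        if param_type is None:
            param_type = {}
        param_type.setdefault("name", "STRING")
        return param_type

    assert query_bad("SELECT 1") is query_bad("SELECT 2")        # same shared object
    assert query_good("SELECT 1") is not query_good("SELECT 2")  # independent objects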
From ef4c12d85ddc3565b30d6cdffd042d4b9b45ae7e Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 07:08:43 +0100 Subject: [PATCH 066/168] Update grr/server/grr_response_server/databases/spanner_flows.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_flows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 75fc7217d..1db61ea1c 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -980,7 +980,7 @@ def Txn(txn) -> None: rows.append([r.client_id, r.flow_id, str(r.request_id), r.needs_processing, str(r.next_response_id), r.callback_state, r, spanner_lib.COMMIT_TIMESTAMP, start_time]) - txn.insert_or_update(table="FlowRequests", columns=columns, values=rows) + txn.insert_or_update(table="FlowRequests", columns=columns, values=rows) if needs_processing: flow_processing_requests = [] From 1e256bcc534864df3ecd824f253fc108934abea6 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 07:09:08 +0100 Subject: [PATCH 067/168] Update grr/server/grr_response_server/databases/spanner_flows.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_flows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 1db61ea1c..cfbd3b909 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -483,7 +483,7 @@ def WriteFlowResults(self, results: Sequence[flows_pb2.FlowResult]) -> None: for batch in collection.Batch(results, self._write_rows_batch_size): self._WriteFlowResults(batch) - def _WriteFlowResults(self, results: Sequence[flows_pb2.FlowError]) -> None: + def _WriteFlowResults(self, results: Sequence[flows_pb2.FlowResult]) -> None: """Writes flow errors for a given flow.""" def Mutation(mut) -> None: From 540ef4d820ea423561dd9d27dd2e983ad8273849 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 07:09:29 +0100 Subject: [PATCH 068/168] Update grr/server/grr_response_server/databases/spanner_flows.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_flows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index cfbd3b909..1907f7d3e 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -1151,7 +1151,7 @@ def _BuildResponseWrites( rows.append([r.client_id, r.flow_id, str(r.request_id), str(r.response_id), response,status,iterator,spanner_lib.COMMIT_TIMESTAMP]) - txn.insert_or_update(table="FlowResponses", columns=columns, values=rows) + txn.insert_or_update(table="FlowResponses", columns=columns, values=rows) def _BuildExpectedUpdates( self, updates: dict[_RequestKey, int], txn From 1e45c725e799628d40771b1dbfa1680886c29561 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 07:09:41 +0100 Subject: [PATCH 069/168] Update grr/server/grr_response_server/databases/spanner_events.py Co-authored-by: coperni 
<5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_events.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_events.py b/grr/server/grr_response_server/databases/spanner_events.py index 6975ddb30..1a64d8b22 100644 --- a/grr/server/grr_response_server/databases/spanner_events.py +++ b/grr/server/grr_response_server/databases/spanner_events.py @@ -96,7 +96,7 @@ def CountAPIAuditEntriesByUserAndDay( a.Username, TIMESTAMP_TRUNC(a.CreationTime, DAY, "UTC") AS day, COUNT(*) - FROM APIAuditEntry AS a + FROM ApiAuditEntry AS a """ params = {} From 99e17ffffe04bd996b4e63db0c95ca7e35121ffd Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 07:10:00 +0100 Subject: [PATCH 070/168] Update grr/server/grr_response_server/databases/spanner_paths.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_paths.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_paths.py b/grr/server/grr_response_server/databases/spanner_paths.py index 22e78ac66..6205e5a4a 100644 --- a/grr/server/grr_response_server/databases/spanner_paths.py +++ b/grr/server/grr_response_server/databases/spanner_paths.py @@ -120,7 +120,7 @@ def ReadPathInfos( query = """ SELECT p.Path, p.CreationTime, p.IsDir, ps.CreationTime, ps.Stat, - ph.CreationTime, ph.FileHash, + ph.CreationTime, ph.FileHash FROM Paths AS p LEFT JOIN PathFileStats AS ps ON p.ClientId = ps.ClientId From 81ee1a7e18d321017089caa8a3e82689d794c57d Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 07:10:25 +0100 Subject: [PATCH 071/168] Update grr/server/grr_response_server/databases/spanner_flows.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_flows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 1907f7d3e..737f3be3e 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -2332,7 +2332,7 @@ def Mutation(mut) -> None: ]) mut.insert(table="MessageHandlerRequests", columns=columns, values=rows) - self.db.Transact(Mutation, txn_tag="WriteMessageHandlerRequests") + self.db.Mutate(Mutation, txn_tag="WriteMessageHandlerRequests") @db_utils.CallLogged @db_utils.CallAccounted From 840b849aed495b3ec5612cb23d56a222112de1dd Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 07:10:48 +0100 Subject: [PATCH 072/168] Update grr/server/grr_response_server/databases/spanner_flows.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_flows.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 737f3be3e..420ff1def 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -2263,7 +2263,6 @@ def _LeaseMessageHandlerRequests( def Txn(txn) -> None: # Read the message handler requests waiting for leases - keyset = spanner_lib.KeySet(all_=True) params = { "limit": limit, "now": now.AsDatetime() From 704a5d4a25b37a7a7f6be5038cc25884cd2a3b41 Mon Sep 17 
00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 07:10:59 +0100 Subject: [PATCH 073/168] Update grr/server/grr_response_server/databases/spanner_cron_jobs.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_cron_jobs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_cron_jobs.py b/grr/server/grr_response_server/databases/spanner_cron_jobs.py index 5a6f61697..e24260152 100644 --- a/grr/server/grr_response_server/databases/spanner_cron_jobs.py +++ b/grr/server/grr_response_server/databases/spanner_cron_jobs.py @@ -181,7 +181,7 @@ def DeleteCronJob(self, cronjob_id: str) -> None: UnknownCronJobError: A cron job with the given id does not exist. """ def Transaction(txn) -> None: - # Spanner does not raise if we attept to delete a non-existing row so + # Spanner does not raise if we attempt to delete a non-existing row so # we check it exists ourselves. keyset = spanner_lib.KeySet(keys=[[cronjob_id]]) try: From 60668c6dc3e815cc0838f890c42dad057f08fbf6 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 07:11:13 +0100 Subject: [PATCH 074/168] Update grr/server/grr_response_server/databases/spanner_flows.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_flows.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 420ff1def..e9b4f0b3a 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -839,7 +839,6 @@ def _LeaseFlowProcessingRequests( expiry = now + rdfvalue.Duration.From(10, rdfvalue.MINUTES) def Txn(txn) -> None: - keyset = spanner_lib.KeySet(all_=True) params = { "limit": limit, "now": now.AsDatetime() From 9cbe6e1159d62098ebc49e070e8cbaf401b6b7e9 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 15:46:52 +0100 Subject: [PATCH 075/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index faee9ff8b..854bc01fe 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -205,7 +205,7 @@ def ParamQuery( ValueError: If the query contains disallowed sequences. KeyError: If some parameter is not specified. 
""" - if not param_type: + if not param_type: param_type = {} names, values = collection.Unzip(params.items()) query = self._parametrize(query, names) From b02c6e211f8e4379bf573b1dbfb8f6fc9432bcc7 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 15:47:04 +0100 Subject: [PATCH 076/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 854bc01fe..3085d5048 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -282,10 +282,11 @@ def ParamExecute( param_type[key] = None # Or re-raise, or handle differently def param_execute(transaction): - row_ct = transaction.execute_update( + transaction.execute_update( query, params=params, - param_types=param_type + param_types=param_type, + request_options={"request_tag": txn_tag}, ) self._pyspanner.run_in_transaction(param_execute) From d37bd6a1777429f18a16a876abff64addef247a3 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 15:47:14 +0100 Subject: [PATCH 077/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 3085d5048..600741819 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -207,7 +207,7 @@ def ParamQuery( """ if not param_type: param_type = {} - names, values = collection.Unzip(params.items()) + names, _ = collection.Unzip(params.items()) query = self._parametrize(query, names) for key, value in params.items(): From 743105ce5cf89c7b5bed7745e280912cc36532d9 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 15:47:28 +0100 Subject: [PATCH 078/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 600741819..cb5347729 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -41,7 +41,13 @@ Cursor = Iterator[Row] _T = TypeVar("_T") +class Mutation(_Mutation): + """A wrapper around the PySpanner Mutation class.""" + pass +class Transaction(_Transaction): + """A wrapper around the PySpanner Transaction class.""" + pass class Database: """A wrapper around the PySpanner class. 
From ebab5aadfdfb23549662e5632e5c338c9dfa8795 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 15:47:39 +0100 Subject: [PATCH 079/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index cb5347729..78627f60d 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -59,7 +59,7 @@ class Database: _PYSPANNER_PARAM_REGEX = re.compile(r"@p\d+") - def __init__(self, pyspanner: spanner_lib.database, project_id: str) -> None: + def __init__(self, pyspanner: spanner_lib.Database, project_id: str) -> None: super().__init__() self._pyspanner = pyspanner self.project_id = project_id From c3c0faddbb9ba43f9736ebfc1193535396f3f18d Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 15:47:56 +0100 Subject: [PATCH 080/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 78627f60d..ed11a17b7 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -22,7 +22,9 @@ from concurrent import futures -from google.cloud import spanner_v1 as spanner_lib +from google.cloud.spanner_v1 import database as spanner_lib +from google.cloud.spanner_v1.transaction import Transaction as _Transaction +from google.cloud.spanner_v1.batch import _BatchBase as _Mutation from google.cloud.spanner import KeyRange, KeySet from google.cloud.spanner_admin_database_v1.types import spanner_database_admin From 8a92ea767b9cb8588806f06080c1fc3565603989 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 15:48:10 +0100 Subject: [PATCH 081/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index ed11a17b7..01878820a 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -278,7 +278,7 @@ def ParamExecute( ValueError: If the query contains disallowed sequences. KeyError: If some parameter is not specified. 
""" - names, values = collection.Unzip(params.items()) + names, _ = collection.Unzip(params.items()) query = self._parametrize(query, names) param_type = {} From 6f3efafe7bb1284195f26c3257c26d95a9fc50b4 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 15:48:23 +0100 Subject: [PATCH 082/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 01878820a..862f02075 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -40,7 +40,7 @@ from grr_response_proto import objects_pb2 Row = Tuple[Any, ...] -Cursor = Iterator[Row] +Cursor = Iterable[Row] _T = TypeVar("_T") class Mutation(_Mutation): From 03db3b60af7cacf8b7aaf4027a05362bd9a8fcaf Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 18 Nov 2025 15:48:37 +0100 Subject: [PATCH 083/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 862f02075..3f6e7a7d2 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -28,7 +28,7 @@ from google.cloud.spanner import KeyRange, KeySet from google.cloud.spanner_admin_database_v1.types import spanner_database_admin -from google.cloud.spanner_v1 import Mutation, param_types +from google.cloud.spanner_v1 import param_types from google.rpc.code_pb2 import OK From 785168c3e5e9882c8ad083fe0262f322d39a4ed0 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 25 Nov 2025 10:01:19 +0100 Subject: [PATCH 084/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 3f6e7a7d2..d792087e6 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -471,6 +471,7 @@ def ReadSet( table: A name of the table to read from. rows: A set of keys specifying which rows to read. cols: Columns of the row to read. + txn_tag: Spanner transaction tag. Returns: Mappings from columns to values of the rows read. 
From 30a66dd08839cafee75546ba2b9a8a2ac127d0ac Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 25 Nov 2025 10:01:30 +0100 Subject: [PATCH 085/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index d792087e6..a6475100f 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -444,6 +444,7 @@ def Read( table: A name of the table to read from. key: A key of the row to read. cols: Columns of the row to read. + txn_tag: Spanner transaction tag. Returns: A mapping from columns to values of the read row. From 6ad847d5c876ae1baf4704a8c1da547649f0a54d Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 25 Nov 2025 10:01:44 +0100 Subject: [PATCH 086/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index a6475100f..8ab36c383 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -420,7 +420,8 @@ def DeleteWithPrefix(self, table: str, key_prefix: Sequence[Any], Args: table: A table from which rows are to be deleted. - key: A sequence of value denoting the prefix of the key of rows to delete. + key_prefix: A sequence of value denoting the prefix of the key of rows to delete. + txn_tag: Spanner transaction tag. Returns: Nothing. From 731ba1e85b168f5b2e5540270995e5bf0ebf68f9 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 25 Nov 2025 10:02:19 +0100 Subject: [PATCH 087/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 8ab36c383..0b329c87e 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -279,7 +279,7 @@ def ParamExecute( KeyError: If some parameter is not specified. 
""" names, _ = collection.Unzip(params.items()) - query = self._parametrize(query, names) + query = self._parametrize(query, names) param_type = {} for key, value in params.items(): From 97b7b1a8aae16dfa5b62069c467aa859b6250faa Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 25 Nov 2025 10:02:29 +0100 Subject: [PATCH 088/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 0b329c87e..b45917544 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -289,8 +289,8 @@ def ParamExecute( print(f"Warning for key '{key}': {e}. Setting type to None.") param_type[key] = None # Or re-raise, or handle differently - def param_execute(transaction): - transaction.execute_update( + def param_execute(txn: Transaction): + txn.execute_update( query, params=params, param_types=param_type, From fc66c22db3315cc40fdf862a5d6bff249391ce14 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 25 Nov 2025 10:02:39 +0100 Subject: [PATCH 089/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index b45917544..b0d907fdc 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -216,7 +216,7 @@ def ParamQuery( if not param_type: param_type = {} names, _ = collection.Unzip(params.items()) - query = self._parametrize(query, names) + query = self._parametrize(query, names) for key, value in params.items(): if key not in param_type: From ce1324046d7c56a912897e36f64be0e239030dac Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 25 Nov 2025 10:02:50 +0100 Subject: [PATCH 090/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index b0d907fdc..d53cedc63 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -148,7 +148,7 @@ def Mutate( Args: func: A mutation function to execute. - txn_tag: Optional[str] = None, + txn_tag: Spanner transaction tag. 
""" self.Transact(func, txn_tag=txn_tag) From 3dbec745618a6beb9586758989adce357eaaf82a Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 25 Nov 2025 10:03:16 +0100 Subject: [PATCH 091/168] Update grr/server/grr_response_server/databases/spanner_flows.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_flows.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index e9b4f0b3a..962164495 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -2371,9 +2371,8 @@ def DeleteMessageHandlerRequests( for r in requests: request_ids.append(str(r.request_id)) params={"request_ids": request_ids} - param_type={"request_ids": param_types.Array(param_types.STRING)} - self.db.ParamExecute(query, params, param_type) + self.db.ParamExecute(query, params, txn_tag="DeleteMessageHandlerRequests") def _ReadHuntState( self, txn, hunt_id: str From 85a887887127d9179f9e35130f39b4c5607c321c Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 25 Nov 2025 10:03:28 +0100 Subject: [PATCH 092/168] Update grr/server/grr_response_server/databases/spanner_utils_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_utils_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils_test.py b/grr/server/grr_response_server/databases/spanner_utils_test.py index 8975d502e..c7e2bd36b 100644 --- a/grr/server/grr_response_server/databases/spanner_utils_test.py +++ b/grr/server/grr_response_server/databases/spanner_utils_test.py @@ -305,8 +305,8 @@ def Mutation(txn) -> None: def testMutateException(self): - def Mutation(txn) -> None: - txn.insert( + def Mutation(mut) -> None: + mut.insert( table="Table", columns=("Key",), values=[("foo",)] From b1d7cbdf38e1843aafe45201ea09be6746dea3d5 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 25 Nov 2025 10:03:43 +0100 Subject: [PATCH 093/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index d53cedc63..8696986e0 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -88,6 +88,9 @@ def _get_param_type(self, value): A google.cloud.spanner_v1.types.Type object, or None if the type cannot be reliably inferred (e.g., for a standalone None value or an empty list). + + Raises: + TypeError: Raised for any unsupported type or empty container value. """ if value is None: # Cannot determine a specific Spanner type from a None value alone. 
From 428cd1248f11c416eb99878a4b5b8d638ded5cea Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 25 Nov 2025 10:03:57 +0100 Subject: [PATCH 094/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 8696986e0..c6460c81a 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -131,7 +131,7 @@ def Transact( self, func: Callable[["Transaction"], _T], txn_tag: Optional[str] = None, - ) -> List[Any]: + ) -> _T: """Execute the given callback function in a Spanner transaction. From 21cbfc729dbb953c9c55f35e11911af7be960b77 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 25 Nov 2025 10:04:04 +0100 Subject: [PATCH 095/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index c6460c81a..badfccb38 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -441,7 +441,7 @@ def Read( key: Sequence[Any], cols: Sequence[str], txn_tag: Optional[str] = None - ) -> Mapping[str, Any]: + ) -> Row: """Read a single row with the given key from the specified table. Args: From e2f5cbbdde602a1bffc6be7089a2ab7e2eaa9c08 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 25 Nov 2025 10:04:12 +0100 Subject: [PATCH 096/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index badfccb38..4a3c5566b 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -469,7 +469,7 @@ def ReadSet( rows: KeySet, cols: Sequence[str], txn_tag: Optional[str] = None - ) -> Iterator[Mapping[str, Any]]: + ) -> Cursor: """Read a set of rows from the specified table. 
Args: From ada8f5f3a79d4a6f9a17ea692af73df1efaa4f12 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 25 Nov 2025 10:04:26 +0100 Subject: [PATCH 097/168] Update grr/server/grr_response_server/databases/spanner_utils_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils_test.py b/grr/server/grr_response_server/databases/spanner_utils_test.py index c7e2bd36b..d8dd795ef 100644 --- a/grr/server/grr_response_server/databases/spanner_utils_test.py +++ b/grr/server/grr_response_server/databases/spanner_utils_test.py @@ -252,7 +252,7 @@ def testReadSimple(self): self.raw_db.Insert(table="Table", row={"Key": "foo", "Column": "foo@x.com"}) result = self.raw_db.Read(table="Table", key=("foo",), cols=("Column",)) - self.assertEqual(result, ['foo@x.com']) + self.assertEqual(result, ["foo@x.com"]) def testReadNotExisting(self): with self.assertRaises(NotFound): From 71f87054eae725d1919abda528308dc73b009ce3 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 25 Nov 2025 10:04:36 +0100 Subject: [PATCH 098/168] Update grr/server/grr_response_server/databases/spanner_utils_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_utils_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils_test.py b/grr/server/grr_response_server/databases/spanner_utils_test.py index d8dd795ef..17e0b6b20 100644 --- a/grr/server/grr_response_server/databases/spanner_utils_test.py +++ b/grr/server/grr_response_server/databases/spanner_utils_test.py @@ -286,13 +286,13 @@ def testReadSetSimple(self): ####################################### def testMutateSimple(self): - def Mutation(txn) -> None: - txn.insert( + def Mutation(mut) -> None: + mut.insert( table="Table", columns=("Key",), values=[("foo",)] ) - txn.insert( + mut.insert( table="Table", columns=("Key",), values=[("bar",)] From f35b1c86e51ec45c780fa51803d7801c5644e601 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Wed, 26 Nov 2025 05:51:21 +0100 Subject: [PATCH 099/168] Update grr/server/grr_response_server/databases/spanner_cron_jobs.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_cron_jobs.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_cron_jobs.py b/grr/server/grr_response_server/databases/spanner_cron_jobs.py index e24260152..a59ac47af 100644 --- a/grr/server/grr_response_server/databases/spanner_cron_jobs.py +++ b/grr/server/grr_response_server/databases/spanner_cron_jobs.py @@ -560,8 +560,10 @@ def Transaction(txn) -> int: for job_id, run_id in rows: keyset = spanner_lib.KeySet(keys=[[job_id, run_id]]) - txn.delete(table="CronJobRuns", keyset=keyset, - request_options={"request_tag": "DeleteOldCronJobRuns:CronJobRuns:delete"}) + txn.delete( + table="CronJobRuns", + keyset=keyset, + ) return len(rows) From 51865aa8536a1cc534696ca9561a13c2021e724e Mon Sep 17 00:00:00 2001 From: daschwanden Date: Wed, 26 Nov 2025 05:51:33 +0100 Subject: [PATCH 100/168] Update grr/server/grr_response_server/databases/spanner_cron_jobs.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_cron_jobs.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_cron_jobs.py b/grr/server/grr_response_server/databases/spanner_cron_jobs.py index a59ac47af..3b7bf5ac9 100644 --- a/grr/server/grr_response_server/databases/spanner_cron_jobs.py +++ b/grr/server/grr_response_server/databases/spanner_cron_jobs.py @@ -567,7 +567,7 @@ def Transaction(txn) -> int: return len(rows) - return self.db.Transact(Transaction, txn_tag="DeleteOldCronJobRuns").value + return self.db.Transact(Transaction, txn_tag="DeleteOldCronJobRuns") def _SelectCronJobsWith( self, From b4efb659d13c8aeb4ed673da2ad6a4c3e7a48871 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Thu, 27 Nov 2025 07:15:17 +0100 Subject: [PATCH 101/168] Update grr/server/grr_response_server/databases/spanner_flows.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_flows.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 962164495..a06d12841 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -951,6 +951,13 @@ def UnregisterFlowProcessingHandler( def WriteFlowRequests( self, requests: Collection[flows_pb2.FlowRequest], + ) -> None: + for batch in collection.Batch(requests, self._write_rows_batch_size): + self._WriteFlowRequests(batch) + + def _WriteFlowRequests( + self, + requests: Collection[flows_pb2.FlowRequest], ) -> None: """Writes a list of flow requests to the database.""" From 3d3f7dfa9e62ba444102da399a0c83516f4d6a1b Mon Sep 17 00:00:00 2001 From: daschwanden Date: Thu, 4 Dec 2025 07:23:37 +0100 Subject: [PATCH 102/168] Update grr/server/grr_response_server/databases/db_blob_references_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/db_blob_references_test.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/grr/server/grr_response_server/databases/db_blob_references_test.py b/grr/server/grr_response_server/databases/db_blob_references_test.py index 20f85fed6..babf33c0d 100644 --- a/grr/server/grr_response_server/databases/db_blob_references_test.py +++ b/grr/server/grr_response_server/databases/db_blob_references_test.py @@ -82,10 +82,7 @@ def testMultipleHashBlobReferencesCanBeWrittenAndReadBack(self): def testWriteHashBlobHandlesLargeAmountsOfData(self): hash_id_blob_refs = {} - # Limit to 16k records to stay within Spanner 80k mutation/commit limit - # https://cloud.google.com/spanner/quotas#limits-for - # 16k records * 5 columns = 80k mutations/commit - for _ in range(16000): + for _ in range(50000): hash_id = rdf_objects.SHA256HashID(os.urandom(32)) blob_ref = objects_pb2.BlobReference() From bbdff8b162b3080f30f91b29714eb4b5ed4a9b19 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Thu, 4 Dec 2025 07:25:23 +0100 Subject: [PATCH 103/168] Update grr/server/grr_response_server/databases/spanner_blob_references.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../databases/spanner_blob_references.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/grr/server/grr_response_server/databases/spanner_blob_references.py b/grr/server/grr_response_server/databases/spanner_blob_references.py index 769bbb1b9..01db91248 100644 --- 
a/grr/server/grr_response_server/databases/spanner_blob_references.py +++ b/grr/server/grr_response_server/databases/spanner_blob_references.py @@ -16,6 +16,7 @@ class BlobReferencesMixin: """A Spanner database mixin with implementation of blob references methods.""" db: spanner_utils.Database + BATCH_SIZE = 16000 @db_utils.CallLogged @db_utils.CallAccounted @@ -24,6 +25,27 @@ def WriteHashBlobReferences( references_by_hash: Mapping[ rdf_objects.SHA256HashID, Collection[objects_pb2.BlobReference] ], + ) -> None: + """Writes blob references for a given set of hashes.""" + batch = dict() + for k, v in references_by_hash.items(): + batch[k] = v + if len(batch) == self.BATCH_SIZE: + self._WriteHashBlobReferences(batch) + batch = dict() + if batch: + self._WriteHashBlobReferences(batch) + + def _WriteHashBlobReferences( + self, + references_by_hash: Mapping[ + rdf_objects.SHA256HashID, Collection[objects_pb2.BlobReference] + ], + ) -> None: + self, + references_by_hash: Mapping[ + rdf_objects.SHA256HashID, Collection[objects_pb2.BlobReference] + ], ) -> None: """Writes blob references for a given set of hashes.""" def Mutation(mut) -> None: From 8815af2b6641c59e9fee7feca8edef9ec35c449c Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:52:21 +0100 Subject: [PATCH 104/168] Update grr/server/grr_response_server/databases/spanner_blob_references.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_blob_references.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_blob_references.py b/grr/server/grr_response_server/databases/spanner_blob_references.py index 01db91248..d9b6b63aa 100644 --- a/grr/server/grr_response_server/databases/spanner_blob_references.py +++ b/grr/server/grr_response_server/databases/spanner_blob_references.py @@ -41,11 +41,6 @@ def _WriteHashBlobReferences( references_by_hash: Mapping[ rdf_objects.SHA256HashID, Collection[objects_pb2.BlobReference] ], - ) -> None: - self, - references_by_hash: Mapping[ - rdf_objects.SHA256HashID, Collection[objects_pb2.BlobReference] - ], ) -> None: """Writes blob references for a given set of hashes.""" def Mutation(mut) -> None: From 8550c9da684db5846da9293ef6c7dfcb8da5746d Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:52:34 +0100 Subject: [PATCH 105/168] Update grr/server/grr_response_server/databases/spanner_blob_references.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_blob_references.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_blob_references.py b/grr/server/grr_response_server/databases/spanner_blob_references.py index d9b6b63aa..83235a981 100644 --- a/grr/server/grr_response_server/databases/spanner_blob_references.py +++ b/grr/server/grr_response_server/databases/spanner_blob_references.py @@ -57,7 +57,6 @@ def Mutation(mut) -> None: columns=("HashId", "BlobId", "Offset", "Size",), values=[(hash_id_b64, base64.b64encode(bytes(ref.blob_id)), ref.offset, ref.size,)] ) - self.db.Mutate(Mutation, txn_tag="WriteHashBlobReferences") From 09499b5ea955fd11b27fc12cac6d6cdcb3c3769d Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:52:45 +0100 Subject: [PATCH 106/168] Update grr/server/grr_response_server/databases/spanner_blob_references.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- 
.../grr_response_server/databases/spanner_blob_references.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_blob_references.py b/grr/server/grr_response_server/databases/spanner_blob_references.py index 83235a981..780176262 100644 --- a/grr/server/grr_response_server/databases/spanner_blob_references.py +++ b/grr/server/grr_response_server/databases/spanner_blob_references.py @@ -70,7 +70,7 @@ def ReadHashBlobReferences( """Reads blob references of a given set of hashes.""" result = {} key_ranges = [] - + for h in hashes: hash_id_b64 = base64.b64encode(bytes(h.AsBytes())) key_ranges.append(spanner_lib.KeyRange(start_closed=[hash_id_b64,], end_closed=[hash_id_b64,])) From e19ed85fe3fac82b1272b68d3804ae7067f18fdc Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:52:55 +0100 Subject: [PATCH 107/168] Update grr/server/grr_response_server/databases/spanner_blob_references.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_blob_references.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_blob_references.py b/grr/server/grr_response_server/databases/spanner_blob_references.py index 780176262..22eedf612 100644 --- a/grr/server/grr_response_server/databases/spanner_blob_references.py +++ b/grr/server/grr_response_server/databases/spanner_blob_references.py @@ -75,7 +75,6 @@ def ReadHashBlobReferences( hash_id_b64 = base64.b64encode(bytes(h.AsBytes())) key_ranges.append(spanner_lib.KeyRange(start_closed=[hash_id_b64,], end_closed=[hash_id_b64,])) result[h] = [] - rows = spanner_lib.KeySet(ranges=key_ranges) hashes_left = set(hashes) From b746fd36b504de50cd548d8171c4a8d3a26d4407 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:53:28 +0100 Subject: [PATCH 108/168] Update grr/server/grr_response_server/databases/spanner_artifacts_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_artifacts_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_artifacts_test.py b/grr/server/grr_response_server/databases/spanner_artifacts_test.py index 12a6a9e4e..4e1eeb9f1 100644 --- a/grr/server/grr_response_server/databases/spanner_artifacts_test.py +++ b/grr/server/grr_response_server/databases/spanner_artifacts_test.py @@ -19,4 +19,4 @@ class SpannerDatabaseArtifactsTest( if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From a3cfda39ad1cf639a27c4d3e2128d40ca6fac74c Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:53:39 +0100 Subject: [PATCH 109/168] Update grr/server/grr_response_server/databases/spanner_test_lib.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_test_lib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_test_lib.py b/grr/server/grr_response_server/databases/spanner_test_lib.py index 9bda12f15..963f2c8ec 100644 --- a/grr/server/grr_response_server/databases/spanner_test_lib.py +++ b/grr/server/grr_response_server/databases/spanner_test_lib.py @@ -150,4 +150,4 @@ def _clean_database() -> None: for table_name in table_names: batch.delete(table_name, keyset) -_TEST_DB: spanner_lib.database = None \ No newline at end of file +_TEST_DB: 
spanner_lib.database = None From 12021e397592c33369ddc35446eead374e520a5f Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:53:48 +0100 Subject: [PATCH 110/168] Update grr/server/grr_response_server/databases/spanner_test_lib.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_test_lib.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_test_lib.py b/grr/server/grr_response_server/databases/spanner_test_lib.py index 963f2c8ec..58d7b2ae9 100644 --- a/grr/server/grr_response_server/databases/spanner_test_lib.py +++ b/grr/server/grr_response_server/databases/spanner_test_lib.py @@ -90,7 +90,6 @@ def Init(sdl_path: str, proto_bundle: bool) -> None: instance = spanner_client.instance(instance_id) _TEST_DB = instance.database(database_id) - def TearDown() -> None: """Tears down the Spanner testing environment. From 7a6e7bbccce68519b1fbd78c3027d43004d9c035 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:53:59 +0100 Subject: [PATCH 111/168] Update grr/server/grr_response_server/databases/spanner_signed_commands_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../databases/spanner_signed_commands_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_signed_commands_test.py b/grr/server/grr_response_server/databases/spanner_signed_commands_test.py index 15ff6a872..647d06130 100644 --- a/grr/server/grr_response_server/databases/spanner_signed_commands_test.py +++ b/grr/server/grr_response_server/databases/spanner_signed_commands_test.py @@ -20,4 +20,4 @@ class SpannerDatabaseSignedCommandsTest( if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From 752deb450382a7b307b89e574fc2c2ba07af30ce Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:54:08 +0100 Subject: [PATCH 112/168] Update grr/server/grr_response_server/databases/spanner_signed_commands.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_signed_commands.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_signed_commands.py b/grr/server/grr_response_server/databases/spanner_signed_commands.py index 5b05c9a66..7ec913fb4 100644 --- a/grr/server/grr_response_server/databases/spanner_signed_commands.py +++ b/grr/server/grr_response_server/databases/spanner_signed_commands.py @@ -7,7 +7,6 @@ from google.api_core.exceptions import AlreadyExists, InvalidArgument, NotFound from google.cloud import spanner as spanner_lib -from grr_response_core.lib.util import iterator from grr_response_proto import signed_commands_pb2 from grr_response_server.databases import db from grr_response_server.databases import db_utils From f8b755d7d7dc93c646f91abf206b321cde080318 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:54:17 +0100 Subject: [PATCH 113/168] Update grr/server/grr_response_server/databases/spanner_signed_commands.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_signed_commands.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_signed_commands.py b/grr/server/grr_response_server/databases/spanner_signed_commands.py index 7ec913fb4..bfbdd7bba 100644 --- 
a/grr/server/grr_response_server/databases/spanner_signed_commands.py +++ b/grr/server/grr_response_server/databases/spanner_signed_commands.py @@ -128,4 +128,4 @@ def Mutation(mut) -> None: for command in to_delete: mut.delete("SignedCommands", spanner_lib.KeySet(keys=[[command.id, int(command.operating_system)]])) - self.db.Mutate(Mutation, txn_tag="DeleteAllSignedCommands") \ No newline at end of file + self.db.Mutate(Mutation, txn_tag="DeleteAllSignedCommands") From 62e08255a63aef8339c755f23a796cdab930a890 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:54:32 +0100 Subject: [PATCH 114/168] Update grr/server/grr_response_server/databases/spanner_blob_keys_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_blob_keys_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_blob_keys_test.py b/grr/server/grr_response_server/databases/spanner_blob_keys_test.py index 65935bf85..6bcd81b9a 100644 --- a/grr/server/grr_response_server/databases/spanner_blob_keys_test.py +++ b/grr/server/grr_response_server/databases/spanner_blob_keys_test.py @@ -19,4 +19,4 @@ class SpannerDatabaseBlobKeysTest( if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From 60d5bf0e1856278e9889b008bcd0adf1df77384b Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:54:49 +0100 Subject: [PATCH 115/168] Update grr/server/grr_response_server/databases/spanner_blob_references_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../databases/spanner_blob_references_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_blob_references_test.py b/grr/server/grr_response_server/databases/spanner_blob_references_test.py index c4130afef..1bb27228a 100644 --- a/grr/server/grr_response_server/databases/spanner_blob_references_test.py +++ b/grr/server/grr_response_server/databases/spanner_blob_references_test.py @@ -20,4 +20,4 @@ class SpannerDatabaseBlobReferencesTest( if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From 646c11cb3e4dbb04d842c10b2a59d42d2d2eb9eb Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:55:06 +0100 Subject: [PATCH 116/168] Update grr/server/grr_response_server/databases/spanner_clients_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_clients_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_clients_test.py b/grr/server/grr_response_server/databases/spanner_clients_test.py index 450b8e295..4611e75b2 100644 --- a/grr/server/grr_response_server/databases/spanner_clients_test.py +++ b/grr/server/grr_response_server/databases/spanner_clients_test.py @@ -30,4 +30,4 @@ def testLabelWriteToUnknownUser(self): if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From 42930981275ba35c2f20919cc20c539b286a09d6 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:55:24 +0100 Subject: [PATCH 117/168] Update grr/server/grr_response_server/databases/spanner_cron_jobs_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_cron_jobs_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/grr/server/grr_response_server/databases/spanner_cron_jobs_test.py b/grr/server/grr_response_server/databases/spanner_cron_jobs_test.py index 9c4d52a0d..c033b6a20 100644 --- a/grr/server/grr_response_server/databases/spanner_cron_jobs_test.py +++ b/grr/server/grr_response_server/databases/spanner_cron_jobs_test.py @@ -19,4 +19,4 @@ class SpannerDatabaseCronJobsTest( if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From b0586f515abe13c60e26da3d6a5a6da7f40c8bcd Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:55:56 +0100 Subject: [PATCH 118/168] Update grr/server/grr_response_server/databases/spanner_time_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_time_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_time_test.py b/grr/server/grr_response_server/databases/spanner_time_test.py index 989058cc4..4618c497f 100644 --- a/grr/server/grr_response_server/databases/spanner_time_test.py +++ b/grr/server/grr_response_server/databases/spanner_time_test.py @@ -19,4 +19,4 @@ class SpannerDatabaseTimeTest( if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From 0f4fb601df4a0889532e61f03134923eadef4b6f Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:56:12 +0100 Subject: [PATCH 119/168] Update grr/server/grr_response_server/databases/spanner_cron_jobs.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_cron_jobs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_cron_jobs.py b/grr/server/grr_response_server/databases/spanner_cron_jobs.py index 3b7bf5ac9..9aa89a267 100644 --- a/grr/server/grr_response_server/databases/spanner_cron_jobs.py +++ b/grr/server/grr_response_server/databases/spanner_cron_jobs.py @@ -185,7 +185,7 @@ def Transaction(txn) -> None: # we check it exists ourselves. 
keyset = spanner_lib.KeySet(keys=[[cronjob_id]]) try: - txn.read(table="CronJobs", keyset=keyset, columns=['JobId']).one() + txn.read(table="CronJobs", keyset=keyset, columns=["JobId"]).one() except NotFound as error: raise db.UnknownCronJobError(cronjob_id) from error From 8011d068380875967c0be73dae4c23c6387a4c20 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:56:33 +0100 Subject: [PATCH 120/168] Update grr/server/grr_response_server/databases/spanner_cron_jobs.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_cron_jobs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_cron_jobs.py b/grr/server/grr_response_server/databases/spanner_cron_jobs.py index 9aa89a267..e39fac2ae 100644 --- a/grr/server/grr_response_server/databases/spanner_cron_jobs.py +++ b/grr/server/grr_response_server/databases/spanner_cron_jobs.py @@ -559,7 +559,6 @@ def Transaction(txn) -> int: for job_id, run_id in rows: keyset = spanner_lib.KeySet(keys=[[job_id, run_id]]) - txn.delete( table="CronJobRuns", keyset=keyset, From 11f63c07eb786f3c5232fa16f1a26b0190b842ea Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:56:53 +0100 Subject: [PATCH 121/168] Update grr/server/grr_response_server/databases/spanner_cron_jobs.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_cron_jobs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_cron_jobs.py b/grr/server/grr_response_server/databases/spanner_cron_jobs.py index e39fac2ae..ab3177995 100644 --- a/grr/server/grr_response_server/databases/spanner_cron_jobs.py +++ b/grr/server/grr_response_server/databases/spanner_cron_jobs.py @@ -722,4 +722,4 @@ def _CronJobRunFromRow( if backtrace is not None: job_run.backtrace = backtrace - return job_run \ No newline at end of file + return job_run From d6ba41f18b580bd990b9c9ebae4a5de8b88650fa Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:57:11 +0100 Subject: [PATCH 122/168] Update grr/server/grr_response_server/databases/spanner_events_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_events_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_events_test.py b/grr/server/grr_response_server/databases/spanner_events_test.py index 44d5cd10c..db21a0582 100644 --- a/grr/server/grr_response_server/databases/spanner_events_test.py +++ b/grr/server/grr_response_server/databases/spanner_events_test.py @@ -19,4 +19,4 @@ class SpannerDatabaseEventsTest( if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From 84f19a732d4552a55cdb535b4f34acfb83534846 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:57:27 +0100 Subject: [PATCH 123/168] Update grr/server/grr_response_server/databases/spanner_flows_large_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_flows_large_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows_large_test.py b/grr/server/grr_response_server/databases/spanner_flows_large_test.py index 60ce031be..b021bbef3 100644 --- a/grr/server/grr_response_server/databases/spanner_flows_large_test.py 
+++ b/grr/server/grr_response_server/databases/spanner_flows_large_test.py @@ -2,7 +2,6 @@ from grr_response_server.databases import db_flows_test from grr_response_server.databases import spanner_test_lib -from grr_response_server.databases import spanner_utils def setUpModule() -> None: From a6ef39d05cedc653ce68c528c4ebed53e4045fdc Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:58:04 +0100 Subject: [PATCH 124/168] Update grr/server/grr_response_server/databases/spanner_flows_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_flows_test.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows_test.py b/grr/server/grr_response_server/databases/spanner_flows_test.py index fe5462247..4eebb522b 100644 --- a/grr/server/grr_response_server/databases/spanner_flows_test.py +++ b/grr/server/grr_response_server/databases/spanner_flows_test.py @@ -1,6 +1,3 @@ -import random -from unittest import mock - from absl.testing import absltest from google.cloud import spanner as spanner_lib From fe4c4784d701829ae00596970bd71ddd20db9cd5 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:58:18 +0100 Subject: [PATCH 125/168] Update grr/server/grr_response_server/databases/spanner_flows_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_flows_test.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows_test.py b/grr/server/grr_response_server/databases/spanner_flows_test.py index 4eebb522b..1b7e1c764 100644 --- a/grr/server/grr_response_server/databases/spanner_flows_test.py +++ b/grr/server/grr_response_server/databases/spanner_flows_test.py @@ -1,9 +1,6 @@ from absl.testing import absltest -from google.cloud import spanner as spanner_lib -from grr_response_proto import flows_pb2 from grr_response_server.databases import db_flows_test -from grr_response_server.databases import db_test_utils from grr_response_server.databases import spanner_test_lib From 487be8dbe5ae421879e79aa502d4177c8d2ec33f Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:58:34 +0100 Subject: [PATCH 126/168] Update grr/server/grr_response_server/databases/spanner.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner.py b/grr/server/grr_response_server/databases/spanner.py index bbfeb1db2..b2e6f04c4 100644 --- a/grr/server/grr_response_server/databases/spanner.py +++ b/grr/server/grr_response_server/databases/spanner.py @@ -24,8 +24,6 @@ from grr_response_server.databases import spanner_users from grr_response_server.databases import spanner_utils from grr_response_server.databases import spanner_yara -from grr_response_server.models import blobs as models_blobs -from grr_response_server.rdfvalues import objects as rdf_objects class SpannerDB( spanner_artifacts.ArtifactsMixin, From c8d733446ceeb623a1c32b79e691ccb8db7a4030 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:58:46 +0100 Subject: [PATCH 127/168] Update grr/server/grr_response_server/databases/spanner_yara.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_yara.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_yara.py b/grr/server/grr_response_server/databases/spanner_yara.py index 882dc2954..350e6ecfa 100644 --- a/grr/server/grr_response_server/databases/spanner_yara.py +++ b/grr/server/grr_response_server/databases/spanner_yara.py @@ -58,4 +58,4 @@ def VerifyYaraSignatureReference( except NotFound: return False - return True \ No newline at end of file + return True From 70dcdd2d188e9e014707d20ef94c36fdf9440235 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:58:58 +0100 Subject: [PATCH 128/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 4a3c5566b..316322a83 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -488,5 +488,4 @@ def ReadSet( keyset=rows, request_options={"request_tag": txn_tag} ) - return results From 8054c7f592ea119e96e3b30466aa16da047ab8f9 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:59:13 +0100 Subject: [PATCH 129/168] Update grr/server/grr_response_server/databases/spanner_users.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_users.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_users.py b/grr/server/grr_response_server/databases/spanner_users.py index 8ae7cd570..e95daf225 100644 --- a/grr/server/grr_response_server/databases/spanner_users.py +++ b/grr/server/grr_response_server/databases/spanner_users.py @@ -3,7 +3,6 @@ import base64 import datetime import logging -import sys import uuid from typing import Optional, Sequence, Tuple From db7041348f638302ef0942dad73412c40dd523a2 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:59:25 +0100 Subject: [PATCH 130/168] Update grr/server/grr_response_server/databases/spanner_users.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_users.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_users.py b/grr/server/grr_response_server/databases/spanner_users.py index e95daf225..6978cbd71 100644 --- a/grr/server/grr_response_server/databases/spanner_users.py +++ b/grr/server/grr_response_server/databases/spanner_users.py @@ -12,7 +12,6 @@ from google.cloud import spanner as spanner_lib from grr_response_core.lib import rdfvalue -from grr_response_core.lib.util import iterator from grr_response_proto import jobs_pb2 from grr_response_proto import objects_pb2 from grr_response_proto import user_pb2 From f4ec7c41cbf1a4ccd2b81fcbf2093341fd615532 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Mon, 8 Dec 2025 18:59:37 +0100 Subject: [PATCH 131/168] Update grr/server/grr_response_server/databases/spanner_users_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_users_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_users_test.py b/grr/server/grr_response_server/databases/spanner_users_test.py index 2ca44bb93..ca0692686 
100644 --- a/grr/server/grr_response_server/databases/spanner_users_test.py +++ b/grr/server/grr_response_server/databases/spanner_users_test.py @@ -19,4 +19,4 @@ class SpannerDatabaseUsersTest( if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From 016490ce095120998414cb22528a33faecac612b Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:48:44 +0100 Subject: [PATCH 132/168] Update grr/server/grr_response_server/databases/spanner_blob_references.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_blob_references.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_blob_references.py b/grr/server/grr_response_server/databases/spanner_blob_references.py index 22eedf612..eabb80e9a 100644 --- a/grr/server/grr_response_server/databases/spanner_blob_references.py +++ b/grr/server/grr_response_server/databases/spanner_blob_references.py @@ -70,7 +70,6 @@ def ReadHashBlobReferences( """Reads blob references of a given set of hashes.""" result = {} key_ranges = [] - for h in hashes: hash_id_b64 = base64.b64encode(bytes(h.AsBytes())) key_ranges.append(spanner_lib.KeyRange(start_closed=[hash_id_b64,], end_closed=[hash_id_b64,])) From d98f50bf38d174a8b050bbf20270391a56a74780 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:48:58 +0100 Subject: [PATCH 133/168] Update grr/server/grr_response_server/databases/spanner_flows.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_flows.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index a06d12841..8502f0fee 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -1,15 +1,14 @@ #!/usr/bin/env python """A module with flow methods of the Spanner database implementation.""" +from collections.abc import Callable, Collection, Iterable, Mapping, Sequence import dataclasses -import datetime import logging import threading import time +from typing import Any, Optional, Union import uuid -from typing import Any, Callable, Collection, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Tuple, Union - from google.api_core.exceptions import AlreadyExists, NotFound from google.cloud import spanner as spanner_lib from google.cloud.spanner_v1 import param_types From 4072bd937d6d4f858f483b13b6365cf140abb9c6 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:49:10 +0100 Subject: [PATCH 134/168] Update grr/server/grr_response_server/databases/spanner_flows.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_flows.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 8502f0fee..337269e27 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -2375,7 +2375,7 @@ def DeleteMessageHandlerRequests( query = "DELETE FROM MessageHandlerRequests WHERE RequestId IN UNNEST(@request_ids)" request_ids = [] for r in requests: - request_ids.append(str(r.request_id)) + request_ids.append(str(r.request_id)) 
params={"request_ids": request_ids} self.db.ParamExecute(query, params, txn_tag="DeleteMessageHandlerRequests") From 78cf38d353cbe7dfa151caba78deec49ed289ddb Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:49:20 +0100 Subject: [PATCH 135/168] Update grr/server/grr_response_server/databases/spanner_hunts_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_hunts_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_hunts_test.py b/grr/server/grr_response_server/databases/spanner_hunts_test.py index 607721ae3..fc7a56afa 100644 --- a/grr/server/grr_response_server/databases/spanner_hunts_test.py +++ b/grr/server/grr_response_server/databases/spanner_hunts_test.py @@ -22,4 +22,4 @@ class SpannerDatabaseHuntsTest( if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From 42b409d9401776382cf42627e2cc2a96b0814468 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:49:29 +0100 Subject: [PATCH 136/168] Update grr/server/grr_response_server/databases/spanner_signed_binaries.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_signed_binaries.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_signed_binaries.py b/grr/server/grr_response_server/databases/spanner_signed_binaries.py index 0fffca22a..0a3f9979f 100644 --- a/grr/server/grr_response_server/databases/spanner_signed_binaries.py +++ b/grr/server/grr_response_server/databases/spanner_signed_binaries.py @@ -7,7 +7,6 @@ from google.cloud import spanner as spanner_lib from grr_response_core.lib import rdfvalue -from grr_response_core.lib.util import iterator from grr_response_proto import objects_pb2 from grr_response_server.databases import db from grr_response_server.databases import db_utils From 71ca03208b38eb62bde7ffaf2019651f81c9a778 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:49:37 +0100 Subject: [PATCH 137/168] Update grr/server/grr_response_server/databases/spanner_hunts.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_hunts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_hunts.py b/grr/server/grr_response_server/databases/spanner_hunts.py index 0dee90e85..1529f1255 100644 --- a/grr/server/grr_response_server/databases/spanner_hunts.py +++ b/grr/server/grr_response_server/databases/spanner_hunts.py @@ -1,8 +1,8 @@ #!/usr/bin/env python """A module with hunt methods of the Spanner database implementation.""" import base64 - -from typing import AbstractSet, Callable, Collection, Iterable, List, Mapping, Optional, Sequence, Set +from collections.abc import Callable, Collection, Iterable, Mapping, Sequence, Set +from typing import Optional from google.api_core.exceptions import AlreadyExists, NotFound From cf6f18bcf2d4ac1b9a5eca7a0d3cd246ec0a8214 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:49:49 +0100 Subject: [PATCH 138/168] Update grr/server/grr_response_server/databases/spanner_paths.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_paths.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_paths.py 
b/grr/server/grr_response_server/databases/spanner_paths.py index 6205e5a4a..f84ad8f6d 100644 --- a/grr/server/grr_response_server/databases/spanner_paths.py +++ b/grr/server/grr_response_server/databases/spanner_paths.py @@ -7,7 +7,6 @@ from google.cloud import spanner as spanner_lib from grr_response_core.lib import rdfvalue -from grr_response_core.lib.util import iterator from grr_response_proto import objects_pb2 from grr_response_server.databases import db as abstract_db from grr_response_server.databases import db_utils From 950ae5910dfdf2f4581c3c33a7aea1e3b03f9e06 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:49:57 +0100 Subject: [PATCH 139/168] Update grr/server/grr_response_server/databases/spanner_paths.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_paths.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_paths.py b/grr/server/grr_response_server/databases/spanner_paths.py index f84ad8f6d..eb7e12727 100644 --- a/grr/server/grr_response_server/databases/spanner_paths.py +++ b/grr/server/grr_response_server/databases/spanner_paths.py @@ -10,7 +10,6 @@ from grr_response_proto import objects_pb2 from grr_response_server.databases import db as abstract_db from grr_response_server.databases import db_utils -from grr_response_server.databases import spanner_clients from grr_response_server.databases import spanner_utils from grr_response_server.models import paths as models_paths From 9ccda4c5aa9d3d4e9f09d8b06d9468a0a8509389 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:50:08 +0100 Subject: [PATCH 140/168] Update grr/server/grr_response_server/databases/spanner_users.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_users.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_users.py b/grr/server/grr_response_server/databases/spanner_users.py index 6978cbd71..ea21adbe1 100644 --- a/grr/server/grr_response_server/databases/spanner_users.py +++ b/grr/server/grr_response_server/databases/spanner_users.py @@ -504,8 +504,8 @@ def UpdateUserNotifications( param_placeholders = ", ".join([f"{{ts{i}}}" for i in range(len(timestamps))]) for i, timestamp in enumerate(timestamps): - param_name = f"ts{i}" - params[param_name] = timestamp.AsDatetime() + param_name = f"ts{i}" + params[param_name] = timestamp.AsDatetime() query = f""" UPDATE UserNotifications n From 44f4545d340bb99c651ac926962c6f830c947aae Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:50:42 +0100 Subject: [PATCH 141/168] Update grr/server/grr_response_server/databases/spanner_paths.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_paths.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_paths.py b/grr/server/grr_response_server/databases/spanner_paths.py index eb7e12727..c821fb021 100644 --- a/grr/server/grr_response_server/databases/spanner_paths.py +++ b/grr/server/grr_response_server/databases/spanner_paths.py @@ -1,7 +1,8 @@ #!/usr/bin/env python """A module with path methods of the Spanner database implementation.""" import base64 -from typing import Collection, Dict, Iterable, Optional, Sequence +from collections.abc import Collection, Iterable, Sequence +from 
typing import Optional from google.api_core.exceptions import NotFound from google.cloud import spanner as spanner_lib From 579aab9250b1cfb0ea9f2c71f33e52884a8a025e Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:50:55 +0100 Subject: [PATCH 142/168] Update grr/server/grr_response_server/databases/spanner_utils_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_utils_test.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils_test.py b/grr/server/grr_response_server/databases/spanner_utils_test.py index 17e0b6b20..e838e862c 100644 --- a/grr/server/grr_response_server/databases/spanner_utils_test.py +++ b/grr/server/grr_response_server/databases/spanner_utils_test.py @@ -1,10 +1,5 @@ import datetime -import time -from typing import Any -from typing import Iterator -from typing import List -from typing import Mapping -from unittest import mock +from typing import Any, List from absl.testing import absltest From 6ef67bd66b810f845ebb4c28002cbdaa4f93ca27 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:51:21 +0100 Subject: [PATCH 143/168] Update grr/server/grr_response_server/databases/spanner_utils_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils_test.py b/grr/server/grr_response_server/databases/spanner_utils_test.py index e838e862c..647562941 100644 --- a/grr/server/grr_response_server/databases/spanner_utils_test.py +++ b/grr/server/grr_response_server/databases/spanner_utils_test.py @@ -6,7 +6,6 @@ from google.cloud import spanner as spanner_lib from google.api_core.exceptions import NotFound -from grr_response_core.lib.util import iterator from grr_response_server.databases import spanner_test_lib from grr_response_server.databases import spanner_utils From 983120a9a88f9d99934db628f1bdfb8b95220164 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:51:29 +0100 Subject: [PATCH 144/168] Update grr/server/grr_response_server/databases/spanner_utils_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils_test.py b/grr/server/grr_response_server/databases/spanner_utils_test.py index 647562941..7387edd8c 100644 --- a/grr/server/grr_response_server/databases/spanner_utils_test.py +++ b/grr/server/grr_response_server/databases/spanner_utils_test.py @@ -156,7 +156,6 @@ def testExecutePartitioned(self): self.raw_db.Insert(table="Table", row={"Key": "baz"}) self.raw_db.ExecutePartitioned("DELETE FROM Table AS t WHERE t.Key LIKE 'ba%'") - results = list(self.raw_db.Query("SELECT t.Key FROM Table AS t")) self.assertLen(results, 1) self.assertEqual(results[0], ["foo",]) From e01cec26aa1a158049d9362ee0a348ae6e438c26 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:51:42 +0100 Subject: [PATCH 145/168] Update grr/server/grr_response_server/databases/spanner_utils_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils_test.py 
b/grr/server/grr_response_server/databases/spanner_utils_test.py index 7387edd8c..e507a3c4a 100644 --- a/grr/server/grr_response_server/databases/spanner_utils_test.py +++ b/grr/server/grr_response_server/databases/spanner_utils_test.py @@ -250,7 +250,6 @@ def testReadSimple(self): def testReadNotExisting(self): with self.assertRaises(NotFound): self.raw_db.Read(table="Table", key=("foo",), cols=("Column",)) - ####################################### # ReadSet Tests ####################################### From ed6d0b080738b3376ee2a1449ad492cd0ef18333 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:51:53 +0100 Subject: [PATCH 146/168] Update grr/server/grr_response_server/databases/spanner_utils_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils_test.py b/grr/server/grr_response_server/databases/spanner_utils_test.py index e507a3c4a..bf92c7bb3 100644 --- a/grr/server/grr_response_server/databases/spanner_utils_test.py +++ b/grr/server/grr_response_server/databases/spanner_utils_test.py @@ -8,7 +8,6 @@ from grr_response_server.databases import spanner_test_lib -from grr_response_server.databases import spanner_utils def setUpModule() -> None: spanner_test_lib.Init(spanner_test_lib.TEST_SCHEMA_SDL_PATH, False) From 9ecf5b11db55ba846e3a48b1f4c693c475237318 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:52:02 +0100 Subject: [PATCH 147/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 316322a83..4e29be99c 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -167,7 +167,7 @@ def Query(self, query: str, txn_tag: Optional[str] = None) -> Cursor: A cursor over the query results. 
""" with self._pyspanner.snapshot() as snapshot: - results = snapshot.execute_sql(query, request_options={"request_tag": txn_tag}) + results = snapshot.execute_sql(query, request_options={"request_tag": txn_tag}) return results From 64859cddd54e1bc41e3414c8229cea1aef48e482 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:52:13 +0100 Subject: [PATCH 148/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 4e29be99c..f64c65838 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -230,12 +230,12 @@ def ParamQuery( param_type[key] = None # Or re-raise, or handle differently with self._pyspanner.snapshot() as snapshot: - results = snapshot.execute_sql( - query, - params=params, - param_types=param_type, - request_options={"request_tag": txn_tag} - ) + results = snapshot.execute_sql( + query, + params=params, + param_types=param_type, + request_options={"request_tag": txn_tag} + ) return results From 3e9a8125b7903915b932f8cf6d1c578b0edfa968 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:52:20 +0100 Subject: [PATCH 149/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index f64c65838..92f7e57fd 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -455,12 +455,12 @@ def Read( """ keyset = KeySet(keys=[key]) with self._pyspanner.snapshot() as snapshot: - results = snapshot.read( - table=table, - columns=cols, - keyset=keyset, - request_options={"request_tag": txn_tag} - ) + results = snapshot.read( + table=table, + columns=cols, + keyset=keyset, + request_options={"request_tag": txn_tag} + ) return results.one() def ReadSet( From 1b39610aa0861ccd3dbe44fff97b1de490bd20ca Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:52:30 +0100 Subject: [PATCH 150/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 92f7e57fd..9c42503ea 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -482,10 +482,10 @@ def ReadSet( Mappings from columns to values of the rows read. 
""" with self._pyspanner.snapshot() as snapshot: - results = snapshot.read( - table=table, - columns=cols, - keyset=rows, - request_options={"request_tag": txn_tag} - ) + results = snapshot.read( + table=table, + columns=cols, + keyset=rows, + request_options={"request_tag": txn_tag} + ) return results From a55c9ec50c1b3f71afb1de9d3c2dde87f53acd65 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:52:39 +0100 Subject: [PATCH 151/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 9c42503ea..ce4344764 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -293,12 +293,12 @@ def ParamExecute( param_type[key] = None # Or re-raise, or handle differently def param_execute(txn: Transaction): - txn.execute_update( - query, - params=params, - param_types=param_type, - request_options={"request_tag": txn_tag}, - ) + txn.execute_update( + query, + params=params, + param_types=param_type, + request_options={"request_tag": txn_tag}, + ) self._pyspanner.run_in_transaction(param_execute) From ba61ddfd54762e102a39fa89b8cd149db1036a51 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:53:14 +0100 Subject: [PATCH 152/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../databases/spanner_utils.py | 20 ++----------------- 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index ce4344764..c6e150fd8 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -1,26 +1,10 @@ """Spanner-related helpers and other utilities.""" -import contextlib +from collections.abc import Callable, Iterable, Mapping, Sequence import datetime import decimal -import pytz import re -import time - -from typing import Any -from typing import Callable -from typing import Generic -from typing import Iterable -from typing import Iterator -from typing import List -from typing import Mapping -from typing import Optional -from typing import Sequence -from typing import Tuple -from typing import Type -from typing import TypeVar - -from concurrent import futures +from typing import Any, Optional, Tuple, TypeVar from google.cloud.spanner_v1 import database as spanner_lib from google.cloud.spanner_v1.transaction import Transaction as _Transaction From a18f0449644033990f60ced1ce729d1090b33bd1 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:53:30 +0100 Subject: [PATCH 153/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index c6e150fd8..e20bf073a 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ 
b/grr/server/grr_response_server/databases/spanner_utils.py @@ -11,7 +11,6 @@ from google.cloud.spanner_v1.batch import _BatchBase as _Mutation from google.cloud.spanner import KeyRange, KeySet -from google.cloud.spanner_admin_database_v1.types import spanner_database_admin from google.cloud.spanner_v1 import param_types from google.rpc.code_pb2 import OK From 67ace6a14fe16d66f0538a04a4660f4f729b980d Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:53:47 +0100 Subject: [PATCH 154/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index e20bf073a..986967bba 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -13,9 +13,6 @@ from google.cloud.spanner import KeyRange, KeySet from google.cloud.spanner_v1 import param_types -from google.rpc.code_pb2 import OK - -from grr_response_core.lib import rdfvalue from grr_response_core.lib.util import collection from grr_response_core.lib.util import iterator From 798baef9bfc352686b49d370c914caf4f02b3df4 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:54:05 +0100 Subject: [PATCH 155/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 986967bba..d50c920c9 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -14,7 +14,6 @@ from google.cloud.spanner_v1 import param_types from grr_response_core.lib.util import collection -from grr_response_core.lib.util import iterator from grr_response_proto import flows_pb2 from grr_response_proto import objects_pb2 From 89cfdeb595d9e8e2fe74159ee084e454954ed11e Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:54:19 +0100 Subject: [PATCH 156/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index d50c920c9..9adca8a9b 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -15,8 +15,6 @@ from grr_response_core.lib.util import collection -from grr_response_proto import flows_pb2 -from grr_response_proto import objects_pb2 Row = Tuple[Any, ...] 
Cursor = Iterable[Row] From 87afd04ab8415d08a4fbc038053032cc4bd612a6 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:54:37 +0100 Subject: [PATCH 157/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 9adca8a9b..46093b2fc 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -36,7 +36,7 @@ class Database: queries through a transaction runner handling all brittle logic for the user. """ - _PYSPANNER_PARAM_REGEX = re.compile(r"@p\d+") + _PYSPANNER_PARAM_REGEX = re.compile(r"@p\d+") def __init__(self, pyspanner: spanner_lib.Database, project_id: str) -> None: super().__init__() From 0ce830ec73125754c09aa4d1c229ece1625510ce Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:54:54 +0100 Subject: [PATCH 158/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 46093b2fc..55cbada45 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -42,7 +42,6 @@ def __init__(self, pyspanner: spanner_lib.Database, project_id: str) -> None: super().__init__() self._pyspanner = pyspanner self.project_id = project_id - def _parametrize(self, query: str, names: Iterable[str]) -> str: match = self._PYSPANNER_PARAM_REGEX.search(query) if match is not None: From 79b3dd82cf42c65c65ed38836ebc0f49460d9fc3 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:55:09 +0100 Subject: [PATCH 159/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 55cbada45..9e772655b 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -64,7 +64,6 @@ def _get_param_type(self, value): A google.cloud.spanner_v1.types.Type object, or None if the type cannot be reliably inferred (e.g., for a standalone None value or an empty list). - Raises: TypeError: Raised for any unsupported type or empty container value. 
""" From b4fff3fdedea952af8b19cba16f147d67c33f0a3 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:55:24 +0100 Subject: [PATCH 160/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_utils.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 9e772655b..75e73a6af 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -68,9 +68,10 @@ def _get_param_type(self, value): TypeError: Raised for any unsupported type or empty container value. """ if value is None: - # Cannot determine a specific Spanner type from a None value alone. - # This indicates that the type is ambiguous without further schema context. - return None + # Cannot determine a specific Spanner type from a None value alone. + # This indicates that the type is ambiguous without further schema + # context. + return None py_type = type(value) From 7b02e87484e3f1a3724d5710a6d6b09a24844912 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 04:55:40 +0100 Subject: [PATCH 161/168] Update grr/server/grr_response_server/databases/spanner_utils.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../databases/spanner_utils.py | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py index 75e73a6af..647bf4337 100644 --- a/grr/server/grr_response_server/databases/spanner_utils.py +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -76,32 +76,33 @@ def _get_param_type(self, value): py_type = type(value) if py_type is int: - return param_types.INT64 + return param_types.INT64 elif py_type is float: - return param_types.FLOAT64 + return param_types.FLOAT64 elif py_type is str: - return param_types.STRING + return param_types.STRING elif py_type is bool: - return param_types.BOOL + return param_types.BOOL elif py_type is bytes: - return param_types.BYTES + return param_types.BYTES elif py_type is datetime.date: - return param_types.DATE + return param_types.DATE elif py_type is datetime.datetime: - # Note: Spanner TIMESTAMPs are stored in UTC. Ensure datetime objects - # are timezone-aware (UTC) when writing data. This function only maps the type. - return param_types.TIMESTAMP + # Note: Spanner TIMESTAMPs are stored in UTC. Ensure datetime objects + # are timezone-aware (UTC) when writing data. This function only maps the + # type. + return param_types.TIMESTAMP elif py_type is decimal.Decimal: - return param_types.NUMERIC + return param_types.NUMERIC elif py_type is list: - if len(value) > 0: - return param_types.Array(self._get_param_type(value[0])) - else: - raise TypeError(f"Empty value for Python type: {py_type.__name__} for Spanner type conversion.") + if len(value) > 0: + return param_types.Array(self._get_param_type(value[0])) + else: + raise TypeError(f"Empty value for Python type: {py_type.__name__} for Spanner type conversion.") else: - # Potentially raise an error for unsupported types or return None - # For a generic solution, raising an error for unknown types is often safer. 
- raise TypeError(f"Unsupported Python type: {py_type.__name__} for Spanner type conversion.") + # Potentially raise an error for unsupported types or return None + # For a generic solution, raising an error for unknown types is often safer. + raise TypeError(f"Unsupported Python type: {py_type.__name__} for Spanner type conversion.") def Transact( self, From 474ef8b38ed23f03c15c2ddd664b79045676dc38 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 15:51:35 +0100 Subject: [PATCH 162/168] Update grr/server/grr_response_server/databases/spanner_signed_binaries_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../databases/spanner_signed_binaries_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_signed_binaries_test.py b/grr/server/grr_response_server/databases/spanner_signed_binaries_test.py index 871918ff2..edba473ac 100644 --- a/grr/server/grr_response_server/databases/spanner_signed_binaries_test.py +++ b/grr/server/grr_response_server/databases/spanner_signed_binaries_test.py @@ -20,4 +20,4 @@ class SpannerDatabaseSignedBinariesTest( if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From 41a495de1680807b2141e853539e60c6eace056c Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 15:51:44 +0100 Subject: [PATCH 163/168] Update grr/server/grr_response_server/databases/spanner_signed_binaries.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_signed_binaries.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_signed_binaries.py b/grr/server/grr_response_server/databases/spanner_signed_binaries.py index 0a3f9979f..31baa51c0 100644 --- a/grr/server/grr_response_server/databases/spanner_signed_binaries.py +++ b/grr/server/grr_response_server/databases/spanner_signed_binaries.py @@ -117,4 +117,4 @@ def Mutation(mut: spanner_utils.Mutation) -> None: try: self.db.Mutate(Mutation, txn_tag="DeleteSignedBinaryReferences") except NotFound as error: - raise db.UnknownSignedBinaryError(binary_id) from error \ No newline at end of file + raise db.UnknownSignedBinaryError(binary_id) from error From 28f1e940b1d34b562c3d618997d651ecebab51c3 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 15:51:54 +0100 Subject: [PATCH 164/168] Update grr/server/grr_response_server/databases/spanner_paths_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_paths_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_paths_test.py b/grr/server/grr_response_server/databases/spanner_paths_test.py index e662defed..e5bb60b91 100644 --- a/grr/server/grr_response_server/databases/spanner_paths_test.py +++ b/grr/server/grr_response_server/databases/spanner_paths_test.py @@ -55,4 +55,4 @@ def _testComponents(self, components: Sequence[str]): # pylint: disable=invalid if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From 3336ee9d2400896de1ea20029a7b421e4ba52d27 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 15:52:02 +0100 Subject: [PATCH 165/168] Update grr/server/grr_response_server/databases/spanner_message_handler_test.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- 
.../databases/spanner_message_handler_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_message_handler_test.py b/grr/server/grr_response_server/databases/spanner_message_handler_test.py index 4b930b57c..636b39f15 100644 --- a/grr/server/grr_response_server/databases/spanner_message_handler_test.py +++ b/grr/server/grr_response_server/databases/spanner_message_handler_test.py @@ -20,4 +20,4 @@ class SpannerDatabaseHandlerTest( if __name__ == "__main__": - absltest.main() \ No newline at end of file + absltest.main() From d60235fc99bc957eb3da4e187b4842cd98009c89 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 15:52:11 +0100 Subject: [PATCH 166/168] Update grr/server/grr_response_server/databases/spanner_hunts.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_hunts.py | 1 - 1 file changed, 1 deletion(-) diff --git a/grr/server/grr_response_server/databases/spanner_hunts.py b/grr/server/grr_response_server/databases/spanner_hunts.py index 1529f1255..00817a160 100644 --- a/grr/server/grr_response_server/databases/spanner_hunts.py +++ b/grr/server/grr_response_server/databases/spanner_hunts.py @@ -18,7 +18,6 @@ from grr_response_proto import output_plugin_pb2 from grr_response_server.databases import db as abstract_db from grr_response_server.databases import db_utils -from grr_response_server.databases import spanner_flows from grr_response_server.databases import spanner_utils from grr_response_server.models import hunts as models_hunts From b8f498a71487b67ddb14b7d4935940581bf6298d Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 15:52:23 +0100 Subject: [PATCH 167/168] Update grr/server/grr_response_server/databases/spanner_flows.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- grr/server/grr_response_server/databases/spanner_flows.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py index 337269e27..35852a7c8 100644 --- a/grr/server/grr_response_server/databases/spanner_flows.py +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -20,12 +20,11 @@ from grr_response_proto import flows_pb2 from grr_response_proto import jobs_pb2 from grr_response_proto import objects_pb2 +from grr_response_proto import rrg_pb2 from grr_response_server.databases import db from grr_response_server.databases import db_utils -from grr_response_server.databases import spanner_clients from grr_response_server.databases import spanner_utils from grr_response_server.models import hunts as models_hunts -from grr_response_proto import rrg_pb2 SPANNER_DELETE_FLOW_REQUESTS_FAILURES = metrics.Counter( From 993eb017eaee2dc6daef29edabd81e6d57567c32 Mon Sep 17 00:00:00 2001 From: daschwanden Date: Tue, 9 Dec 2025 15:52:34 +0100 Subject: [PATCH 168/168] Update grr/server/grr_response_server/databases/spanner_clients.py Co-authored-by: coperni <5769938+coperni@users.noreply.github.com> --- .../grr_response_server/databases/spanner_clients.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/grr/server/grr_response_server/databases/spanner_clients.py b/grr/server/grr_response_server/databases/spanner_clients.py index 61396a00e..c3dda9a22 100644 --- a/grr/server/grr_response_server/databases/spanner_clients.py +++ 
b/grr/server/grr_response_server/databases/spanner_clients.py @@ -1,24 +1,22 @@ #!/usr/bin/env python """A module with client methods of the Spanner database implementation.""" +from collections.abc import Collection, Iterator, Mapping, Sequence import datetime -import logging -import re -from typing import Collection, Iterator, Mapping, Optional, Sequence, Tuple +from typing import Optional from google.api_core.exceptions import NotFound from google.cloud import spanner as spanner_lib from grr_response_core.lib import rdfvalue -from grr_response_core.lib.util import iterator from grr_response_proto import jobs_pb2 from grr_response_proto import objects_pb2 +from grr_response_proto.rrg import startup_pb2 as rrg_startup_pb2 # Aliasing the import since the name db clashes with the db annotation. from grr_response_server.databases import db as db_lib from grr_response_server.databases import db_utils from grr_response_server.databases import spanner_utils from grr_response_server.models import clients as models_clients -from grr_response_proto.rrg import startup_pb2 as rrg_startup_pb2 class ClientsMixin:
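For context on how the parameter-type inference re-indented in the spanner_utils.py hunks above is typically consumed, the sketch below shows a plain-Python helper in the spirit of Database._get_param_type being passed to snapshot.execute_sql together with the params dict and a request tag. This is an illustrative sketch only, not part of the patch series: the names infer_param_types and run_query, the Clients/ClientId query, and the "ExampleQuery" tag are assumptions made for the example, while the param_types constants, Array(), and the execute_sql/request_options keywords are standard google-cloud-spanner client APIs.

#!/usr/bin/env python
"""Illustrative sketch (not from the GRR code base): inferring Spanner param types."""
import datetime

from google.cloud.spanner_v1 import param_types


def infer_param_types(params):
  """Maps plain Python values to Spanner types, mirroring the patched helper."""
  mapping = {
      int: param_types.INT64,
      float: param_types.FLOAT64,
      str: param_types.STRING,
      bool: param_types.BOOL,
      bytes: param_types.BYTES,
      datetime.date: param_types.DATE,
      # Spanner TIMESTAMP values are stored in UTC; pass timezone-aware datetimes.
      datetime.datetime: param_types.TIMESTAMP,
  }
  inferred = {}
  for name, value in params.items():
    if isinstance(value, list) and value:
      inferred[name] = param_types.Array(mapping[type(value[0])])
    elif type(value) in mapping:
      inferred[name] = mapping[type(value)]
    # None values and unknown types are skipped here; the caller would then
    # have to supply an explicit type, as the patched docstring notes.
  return inferred


def run_query(database, client_id):
  """Runs a parameterized query in a read-only snapshot (table name is illustrative)."""
  params = {"client_id": client_id}
  with database.snapshot() as snapshot:
    rows = snapshot.execute_sql(
        "SELECT c.ClientId FROM Clients AS c WHERE c.ClientId = @client_id",
        params=params,
        param_types=infer_param_types(params),
        request_options={"request_tag": "ExampleQuery"},
    )
    return list(rows)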