diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 18b46c74a..542d900fe 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -3,7 +3,12 @@ on: [push, pull_request] env: GCS_BUCKET: autobuilds.grr-response.com GCS_BUCKET_OPENAPI: autobuilds-grr-openapi - DOCKER_REPOSITORY: ghcr.io/google/grr + DOCKER_REPOSITORY: ghcr.io/${{ github.repository_owner }}/grr + #SPANNER_EMULATOR_HOST: localhost:9010 + #SPANNER_DATABASE: grr-database + #SPANNER_INSTANCE: grr-instance + #SPANNER_PROJECT_ID: spanner-emulator-project + #PROJECT_ID: spanner-emulator-project jobs: test-devenv: runs-on: ubuntu-22.04 @@ -44,6 +49,27 @@ jobs: fi python3 -m venv --system-site-packages "${HOME}/INSTALL" travis/install.sh + #- name: 'Install Cloud SDK' + # uses: google-github-actions/setup-gcloud@v2 + # with: + # install_components: 'beta,pubsub-emulator,cloud-spanner-emulator' + #- name: 'Start Spanner emulator' + # run: | + # gcloud config configurations create emulator + # gcloud config set auth/disable_credentials true + # gcloud config set project ${{ env.SPANNER_PROJECT_ID }} + # gcloud config set api_endpoint_overrides/spanner http://localhost:9020/ + # gcloud emulators spanner start & + # gcloud spanner instances create ${{ env.SPANNER_INSTANCE }} \ + # --config=emulator-config --description="Spanner Test Instance" --nodes=1 + # sudo apt-get update + # sudo apt install -y protobuf-compiler libprotobuf-dev + # # Verify that we can create the Spanner Instance and Database + # cd grr/server/grr_response_server/databases + # ./spanner_setup.sh + # cd - + # # Remove the Database for a clean test env + # gcloud spanner databases delete ${{ env.SPANNER_DATABASE }} --instance=${{ env.SPANNER_INSTANCE }} --quiet - name: Test run: | source "${HOME}/INSTALL/bin/activate" diff --git a/.gitignore b/.gitignore index 66aacf16b..767dcd311 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,6 @@ grr/server/grr_response_server/gui/ui/.angular/ docker_config_files/*.pem compose.watch.yaml Dockerfile.client + +grr/server/grr_response_server/databases/spanner_grr.pb +spanner-setup.txt diff --git a/compose.testing.yaml b/compose.testing.yaml index 818901d9b..4bdcca486 100644 --- a/compose.testing.yaml +++ b/compose.testing.yaml @@ -1,14 +1,14 @@ services: grr-admin-ui: - image: ghcr.io/google/grr:testing + image: ghcr.io/google/grr:test grr-fleetspeak-frontend: - image: ghcr.io/google/grr:testing + image: ghcr.io/google/grr:test grr-worker: - image: ghcr.io/google/grr:testing + image: ghcr.io/google/grr:test grr-client: - image: ghcr.io/google/grr:testing + image: ghcr.io/google/grr:test privileged: true diff --git a/grr/core/grr_response_core/config/data_store.py b/grr/core/grr_response_core/config/data_store.py index 7c96e8d10..9fe1619b5 100644 --- a/grr/core/grr_response_core/config/data_store.py +++ b/grr/core/grr_response_core/config/data_store.py @@ -121,3 +121,33 @@ "Only used when Blobstore.implementation is GCSBlobStore." ), ) + +# Spanner configuration. +config_lib.DEFINE_string( + "Spanner.project", + default=None, + help=( + "GCP Project where the Spanner Instance is located." 
+ ), +) +config_lib.DEFINE_string( + "Spanner.instance", + default="grr-instance", + help="The Spanner Instance for GRR.") + +config_lib.DEFINE_string( + "Spanner.database", + default="grr-database", + help="The Spanner Database for GRR.") + +config_lib.DEFINE_integer( + "Spanner.flow_processing_threads_min", + default=1, + help="The minimum number of flow-processing worker threads.", +) + +config_lib.DEFINE_integer( + "Spanner.flow_processing_threads_max", + default=20, + help="The maximum number of flow-processing worker threads.", +) diff --git a/grr/server/grr_response_server/databases/db_cronjob_test.py b/grr/server/grr_response_server/databases/db_cronjob_test.py index 7203a67c6..2ea502736 100644 --- a/grr/server/grr_response_server/databases/db_cronjob_test.py +++ b/grr/server/grr_response_server/databases/db_cronjob_test.py @@ -191,7 +191,7 @@ def testCronJobLeasing(self): job = self._CreateCronJob() self.db.WriteCronJob(job) - current_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10000) + current_time = rdfvalue.RDFDatetime.Now() lease_time = rdfvalue.Duration.From(5, rdfvalue.MINUTES) with test_lib.FakeTime(current_time): leased = self.db.LeaseCronJobs(lease_time=lease_time) @@ -232,7 +232,7 @@ def testCronJobLeasingByID(self): for j in jobs: self.db.WriteCronJob(j) - current_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10000) + current_time = rdfvalue.RDFDatetime.Now() lease_time = rdfvalue.Duration.From(5, rdfvalue.MINUTES) with test_lib.FakeTime(current_time): leased = self.db.LeaseCronJobs( @@ -254,7 +254,7 @@ def testCronJobReturning(self): with self.assertRaises(ValueError): self.db.ReturnLeasedCronJobs([job]) - current_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10000) + current_time = rdfvalue.RDFDatetime.Now() with test_lib.FakeTime(current_time): leased = self.db.LeaseCronJobs( cronjob_ids=[leased_job.cron_job_id], @@ -276,14 +276,14 @@ def testCronJobReturningMultiple(self): for j in jobs: self.db.WriteCronJob(j) - current_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10000) + current_time = rdfvalue.RDFDatetime.Now() with test_lib.FakeTime(current_time): leased = self.db.LeaseCronJobs( lease_time=rdfvalue.Duration.From(5, rdfvalue.MINUTES) ) self.assertLen(leased, 3) - current_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10001) + current_time = rdfvalue.RDFDatetime.Now() with test_lib.FakeTime(current_time): unleased_jobs = self.db.LeaseCronJobs( lease_time=rdfvalue.Duration.From(5, rdfvalue.MINUTES) @@ -292,7 +292,7 @@ def testCronJobReturningMultiple(self): self.db.ReturnLeasedCronJobs(leased) - current_time = rdfvalue.RDFDatetime.FromSecondsSinceEpoch(10002) + current_time = rdfvalue.RDFDatetime.Now() with test_lib.FakeTime(current_time): leased = self.db.LeaseCronJobs( lease_time=rdfvalue.Duration.From(5, rdfvalue.MINUTES) diff --git a/grr/server/grr_response_server/databases/db_flows_test.py b/grr/server/grr_response_server/databases/db_flows_test.py index 8993b5e87..f723b3858 100644 --- a/grr/server/grr_response_server/databases/db_flows_test.py +++ b/grr/server/grr_response_server/databases/db_flows_test.py @@ -2270,7 +2270,8 @@ def testWritesAndReadsSingleFlowResultOfSingleType(self): sample_result = flows_pb2.FlowResult(client_id=client_id, flow_id=flow_id) sample_result.payload.Pack(jobs_pb2.ClientSummary(client_id=client_id)) - with test_lib.FakeTime(42): + current_time = rdfvalue.RDFDatetime.Now() + with test_lib.FakeTime(current_time): self.db.WriteFlowResults([sample_result]) results = self.db.ReadFlowResults(client_id, flow_id, 
0, 100) @@ -2278,7 +2279,7 @@ def testWritesAndReadsSingleFlowResultOfSingleType(self): self.assertEqual(results[0].payload, sample_result.payload) self.assertEqual( rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch(results[0].timestamp), - rdfvalue.RDFDatetime.FromSecondsSinceEpoch(42), + current_time, ) def testWritesAndReadsRDFStringFlowResult(self): @@ -2335,7 +2336,8 @@ def testReadResultsRestoresAllFlowResultsFields(self): ) sample_result.payload.Pack(jobs_pb2.ClientSummary(client_id=client_id)) - with test_lib.FakeTime(42): + current_time = rdfvalue.RDFDatetime.Now() + with test_lib.FakeTime(current_time): self.db.WriteFlowResults([sample_result]) results = self.db.ReadFlowResults(client_id, flow_id, 0, 100) @@ -2346,7 +2348,7 @@ def testReadResultsRestoresAllFlowResultsFields(self): self.assertEqual(results[0].payload, sample_result.payload) self.assertEqual( rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch(results[0].timestamp), - rdfvalue.RDFDatetime.FromSecondsSinceEpoch(42), + current_time, ) def testWritesAndReadsMultipleFlowResultsOfSingleType(self): @@ -2743,7 +2745,8 @@ def testWritesAndReadsSingleFlowErrorOfSingleType(self): sample_error = flows_pb2.FlowError(client_id=client_id, flow_id=flow_id) sample_error.payload.Pack(jobs_pb2.ClientSummary(client_id=client_id)) - with test_lib.FakeTime(42): + current_time = rdfvalue.RDFDatetime.Now() + with test_lib.FakeTime(current_time): self.db.WriteFlowErrors([sample_error]) errors = self.db.ReadFlowErrors(client_id, flow_id, 0, 100) @@ -2751,9 +2754,7 @@ def testWritesAndReadsSingleFlowErrorOfSingleType(self): self.assertEqual(errors[0].payload, sample_error.payload) self.assertEqual( errors[0].timestamp, - rdfvalue.RDFDatetime.FromSecondsSinceEpoch( - 42 - ).AsMicrosecondsSinceEpoch(), + current_time.AsMicrosecondsSinceEpoch(), ) def testWritesAndReadsMultipleFlowErrorsOfSingleType(self): diff --git a/grr/server/grr_response_server/databases/db_message_handler_test.py b/grr/server/grr_response_server/databases/db_message_handler_test.py index 68d321051..944a3d350 100644 --- a/grr/server/grr_response_server/databases/db_message_handler_test.py +++ b/grr/server/grr_response_server/databases/db_message_handler_test.py @@ -93,6 +93,7 @@ def testMessageHandlerRequestLeasing(self): m.ClearField("timestamp") got += l self.db.DeleteMessageHandlerRequests(got) + self.db.UnregisterMessageHandler() got.sort(key=lambda req: req.request_id) self.assertEqual(requests, got) diff --git a/grr/server/grr_response_server/databases/registry_init.py b/grr/server/grr_response_server/databases/registry_init.py index 7d576bd48..c2affb156 100644 --- a/grr/server/grr_response_server/databases/registry_init.py +++ b/grr/server/grr_response_server/databases/registry_init.py @@ -3,9 +3,11 @@ from grr_response_server.databases import mem from grr_response_server.databases import mysql +from grr_response_server.databases import spanner # All available databases go into this registry. REGISTRY = { "InMemoryDB": mem.InMemoryDB, "MysqlDB": mysql.MysqlDB, + "SpannerDB": spanner.SpannerDB.FromConfig, } diff --git a/grr/server/grr_response_server/databases/spanner.md b/grr/server/grr_response_server/databases/spanner.md new file mode 100644 index 000000000..65c0465e4 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner.md @@ -0,0 +1,80 @@ +# GRR on Spanner + +When operating GRR you need to decide on a persistence store. + +This document describes how to use [Spanner](https://cloud.google.com/spanner) +as the GRR datastore. 
+
+Spanner is a fully managed, mission-critical database service on
+[Google Cloud](https://cloud.google.com) that brings together relational, graph,
+key-value, and search capabilities. It offers transactional consistency at global
+scale and automatic, synchronous replication for high availability.
+
+## 1. Google Cloud Resources
+
+Running GRR on Spanner requires that you create and configure a
+[Spanner instance](https://cloud.google.com/spanner/docs/instances) before you
+run GRR.
+
+You also need to create a
+[Google Cloud Storage](https://cloud.google.com/storage)
+[Bucket](https://cloud.google.com/storage/docs/buckets) that
+will serve as the GRR blobstore.
+
+## 2. Google Cloud Spanner Instance
+
+You can follow the instructions in the Google Cloud online documentation to
+[create a Spanner instance](https://cloud.google.com/spanner/docs/create-query-database-console#create-instance).
+
+> [!NOTE] You only need to create the
+> [Spanner instance](https://cloud.google.com/spanner/docs/instances). The
+> GRR [Spanner database](https://cloud.google.com/spanner/docs/databases)
+> and its tables are created by running the provided [spanner_setup.sh](./spanner_setup.sh)
+> script. The script assumes that you use `grr-instance` as the
+> GRR instance name and `grr-database` as the GRR database name. If you
+> want to use different values, update the
+> [spanner_setup.sh](./spanner_setup.sh) script accordingly.
+> The script also assumes that the
+> [gcloud](https://cloud.google.com/sdk/docs/install) and
+> [protoc](https://protobuf.dev/installation/) binaries are installed on your machine.
+
+Run the following commands to create the GRR database and its tables:
+
+```bash
+export PROJECT_ID=
+export SPANNER_INSTANCE=grr-instance
+export SPANNER_DATABASE=grr-database
+./spanner_setup.sh
+```
+
+## 3. GRR Configuration
+
+To run GRR on Spanner, you need to configure the component settings with
+the values of the Google Cloud Spanner and GCS Bucket resources mentioned above.
+
+The snippet below illustrates a sample GRR `server.yaml` configuration.
+
+```yaml
+ Database.implementation: SpannerDB
+ Spanner.project: 
+ Spanner.instance: grr-instance
+ Spanner.database: grr-database
+ Blobstore.implementation: GCSBlobStore
+ Blobstore.gcs.project: 
+ Blobstore.gcs.bucket: 
+```
+
+> [!NOTE] Make sure you remove all `Mysql`-related configuration items.
+
+## 4. IAM Permissions
+
+This guide assumes that your GRR instance is running on [Google Kubernetes Engine](https://cloud.google.com/kubernetes-engine) (GKE)
+and that you can leverage [Workload Identity Federation for GKE](https://cloud.google.com/kubernetes-engine/docs/how-to/workload-identity) (WIF).
+
+Using WIF, you can assign the required IAM roles to the WIF principal
+`principal://iam.googleapis.com/projects//locations/global/workloadIdentityPools//sa/`,
+where `K8S_NAMESPACE` is the Kubernetes Namespace and `K8S_SERVICE_ACCOUNT` is the
+Kubernetes Service Account that your GRR Pods run under.
+
+The two IAM roles that are required (see the example bindings below) are:
+- `roles/spanner.databaseUser` on your Spanner Database and
+- `roles/storage.objectUser` on your GCS Bucket (the GRR Blobstore mentioned above).
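+
+For reference, the two bindings can typically be granted with `gcloud` along the
+following lines. This is an illustrative sketch, not part of the setup script: the
+database and instance names match the defaults used above, while the bucket name
+and the `WIF_PRINCIPAL` value are placeholders that you need to replace with your
+own values.
+
+```bash
+# Illustrative example -- substitute your own values.
+WIF_PRINCIPAL="principal://iam.googleapis.com/..."  # the WIF principal described above
+
+# Grant the Spanner role on the GRR database.
+gcloud spanner databases add-iam-policy-binding grr-database \
+  --instance=grr-instance \
+  --member="${WIF_PRINCIPAL}" \
+  --role=roles/spanner.databaseUser
+
+# Grant the storage role on the blobstore bucket (the GRR blobstore).
+gcloud storage buckets add-iam-policy-binding gs://YOUR_BLOBSTORE_BUCKET \
+  --member="${WIF_PRINCIPAL}" \
+  --role=roles/storage.objectUser
+```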
\ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner.py b/grr/server/grr_response_server/databases/spanner.py new file mode 100644 index 000000000..b2e6f04c4 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python +"""Spanner implementation of the GRR relational database abstraction. + +See grr/server/db.py for interface. +""" +from google.cloud.spanner import Client + +from grr_response_core import config +from grr_response_core.lib import rdfvalue +from grr_response_server import threadpool +from grr_response_server.databases import db as db_module +from grr_response_server.databases import spanner_artifacts +from grr_response_server.databases import spanner_blob_keys +from grr_response_server.databases import spanner_blob_references +from grr_response_server.databases import spanner_clients +from grr_response_server.databases import spanner_cron_jobs +from grr_response_server.databases import spanner_events +from grr_response_server.databases import spanner_flows +from grr_response_server.databases import spanner_foreman_rules +from grr_response_server.databases import spanner_hunts +from grr_response_server.databases import spanner_paths +from grr_response_server.databases import spanner_signed_binaries +from grr_response_server.databases import spanner_signed_commands +from grr_response_server.databases import spanner_users +from grr_response_server.databases import spanner_utils +from grr_response_server.databases import spanner_yara + +class SpannerDB( + spanner_artifacts.ArtifactsMixin, + spanner_blob_keys.BlobKeysMixin, + spanner_blob_references.BlobReferencesMixin, + spanner_clients.ClientsMixin, + spanner_cron_jobs.CronJobsMixin, + spanner_events.EventsMixin, + spanner_flows.FlowsMixin, + spanner_foreman_rules.ForemanRulesMixin, + spanner_hunts.HuntsMixin, + spanner_paths.PathsMixin, + spanner_signed_binaries.SignedBinariesMixin, + spanner_signed_commands.SignedCommandsMixin, + spanner_users.UsersMixin, + spanner_yara.YaraMixin, + db_module.Database, +): + """A Spanner implementation of the GRR database.""" + + def __init__(self, db: spanner_utils.Database) -> None: + """Initializes the database.""" + self.db = db + self._write_rows_batch_size = 100 + + self.handler_thread = None + self.handler_stop = True + + self.flow_processing_request_handler_thread = None + self.flow_processing_request_handler_stop = None + self.flow_processing_request_handler_pool = threadpool.ThreadPool.Factory( + "spanner_flow_processing_pool", + min_threads=config.CONFIG["Spanner.flow_processing_threads_min"], + max_threads=config.CONFIG["Spanner.flow_processing_threads_max"], + ) + + @classmethod + def FromConfig(cls) -> "SpannerDB": + """Creates a GRR database instance for Spanner path specified in the config. + + Returns: + A GRR database instance. 
+ """ + project_id = config.CONFIG["Spanner.project"] + spanner_client = Client(project_id) + spanner_instance = spanner_client.instance(config.CONFIG["Spanner.instance"]) + spanner_database = spanner_instance.database(config.CONFIG["Spanner.database"]) + + return cls(spanner_utils.Database(spanner_database, project_id)) + + def Now(self) -> rdfvalue.RDFDatetime: + """Retrieves current time as reported by the database.""" + (timestamp,) = self.db.QuerySingle("SELECT CURRENT_TIMESTAMP()") + return rdfvalue.RDFDatetime.FromDatetime(timestamp) + + def MinTimestamp(self) -> rdfvalue.RDFDatetime: + """Returns minimal timestamp allowed by the DB.""" + return rdfvalue.RDFDatetime.FromSecondsSinceEpoch(0) diff --git a/grr/server/grr_response_server/databases/spanner.sdl b/grr/server/grr_response_server/databases/spanner.sdl new file mode 100644 index 000000000..b23e021bb --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner.sdl @@ -0,0 +1,673 @@ +CREATE PROTO BUNDLE ( + `google.protobuf.Any`, + `google.protobuf.Timestamp`, + `grr.ApprovalRequest`, + `grr.ApprovalRequest.ApprovalType`, + `grr.ClientLabel`, + `grr.GRRUser`, + `grr.GRRUser.UserType`, + `grr.GUISettings`, + `grr.GUISettings.UIMode`, + `grr.FleetspeakValidationInfo`, + `grr.FleetspeakValidationInfoTag`, + `grr.Password`, + `grr.PathInfo`, + `grr.PathInfo.PathType`, + `grr.UserNotification`, + `grr.UserNotification.State`, + `grr.UserNotification.Type`, + `grr.AttributedDict`, + `grr.BlobArray`, + `grr.Dict`, + `grr.DataBlob`, + `grr.DataBlob.CompressionType`, + `grr.EmbeddedRDFValue`, + `grr.KeyValue`, + `grr.AmazonCloudInstance`, + `grr.ClientInformation`, + `grr.ClientCrash`, + `grr.ClientSnapshot`, + `grr.CloudInstance`, + `grr.CloudInstance.InstanceType`, + `grr.EdrAgent`, + `grr.Filesystem`, + `grr.GoogleCloudInstance`, + `grr.HardwareInfo`, + `grr.Interface`, + `grr.KnowledgeBase`, + `grr.NetworkAddress`, + `grr.NetworkAddress.Family`, + `grr.StartupInfo`, + `grr.StringMapEntry`, + `grr.UnixVolume`, + `grr.User`, + `grr.Volume`, + `grr.Volume.VolumeFileSystemFlagEnum`, + `grr.WindowsVolume`, + `grr.WindowsVolume.WindowsDriveTypeEnum`, + `grr.WindowsVolume.WindowsVolumeAttributeEnum`, + `grr.ApprovalRequestReference`, + `grr.ClientReference`, + `grr.CronJobReference`, + `grr.FlowLikeObjectReference`, + `grr.FlowLikeObjectReference.ObjectType`, + `grr.FlowReference`, + `grr.HuntReference`, + `grr.ObjectReference`, + `grr.ObjectReference.Type`, + `grr.VfsFileReference`, + `grr.CpuSeconds`, + `grr.Flow`, + `grr.Flow.FlowState`, + `grr.FlowIterator`, + `grr.FlowOutputPluginLogEntry`, + `grr.FlowOutputPluginLogEntry.LogEntryType`, + `grr.FlowProcessingRequest`, + `grr.FlowRequest`, + `grr.FlowResponse`, + `grr.FlowResultCount`, + `grr.FlowResultMetadata`, + `grr.FlowRunnerArgs`, + `grr.FlowStatus`, + `grr.FlowStatus.Status`, + `grr.GrrMessage`, + `grr.GrrMessage.AuthorizationState`, + `grr.GrrMessage.Type`, + `grr.GrrStatus`, + `grr.GrrStatus.ReturnedStatus`, + `grr.OutputPluginDescriptor`, + `grr.OutputPluginState`, + `grr.RequestState`, + `grr.APIAuditEntry`, + `grr.APIAuditEntry.Code`, + `grr.AuthenticodeSignedData`, + `grr.BlobImageChunkDescriptor`, + `grr.BlobImageDescriptor`, + `grr.BufferReference`, + `grr.Hash`, + `grr.FileFinderResult`, + `grr.PathSpec`, + `grr.PathSpec.ImplementationType`, + `grr.PathSpec.Options`, + `grr.PathSpec.PathType`, + `grr.PathSpec.tsk_fs_attr_type`, + `grr.StatEntry`, + `grr.StatEntry.ExtAttr`, + `grr.StatEntry.RegistryType`, + `grr.ForemanClientRule`, + `grr.ForemanClientRule.Type`, 
+ `grr.ForemanClientRuleSet`, + `grr.ForemanClientRuleSet.MatchMode`, + `grr.ForemanCondition`, + `grr.ForemanIntegerClientRule`, + `grr.ForemanIntegerClientRule.ForemanIntegerField`, + `grr.ForemanIntegerClientRule.Operator`, + `grr.ForemanLabelClientRule`, + `grr.ForemanLabelClientRule.MatchMode`, + `grr.ForemanOsClientRule`, + `grr.ForemanRegexClientRule`, + `grr.ForemanRegexClientRule.ForemanStringField`, + `grr.ForemanRule`, + `grr.ForemanRuleAction`, + `grr.Artifact`, + `grr.ArtifactSource`, + `grr.ArtifactSource.SourceType`, + `grr.ArtifactDescriptor`, + `grr.ClientActionResult`, + `grr.Hunt`, + `grr.Hunt.HuntState`, + `grr.Hunt.HuntStateReason`, + `grr.HuntArguments`, + `grr.HuntArguments.HuntType`, + `grr.HuntArgumentsStandard`, + `grr.HuntArgumentsVariable`, + `grr.VariableHuntFlowGroup`, + `grr.BlobReference`, + `grr.BlobReferences`, + `grr.SignedBinaryID`, + `grr.SignedBinaryID.BinaryType`, + `grr.CronJob`, + `grr.CronJobAction`, + `grr.CronJobAction.ActionType`, + `grr.CronJobRun`, + `grr.CronJobRun.CronJobRunStatus`, + `grr.SystemCronAction`, + `grr.HuntCronAction`, + `grr.HuntRunnerArgs`, + `grr.MessageHandlerRequest`, + `grr.Command`, + `grr.Command.EnvVar`, + `grr.SignedCommand`, + `grr.SignedCommand.OS`, + `rrg.Log`, + `rrg.Log.Level`, + `rrg.fs.Path`, + `rrg.startup.Metadata`, + `rrg.startup.Startup`, + `rrg.startup.Version`, + `rrg.action.execute_signed_command.Command`, +); + +CREATE TABLE Labels ( + Label STRING(128) NOT NULL, +) PRIMARY KEY (Label); + +CREATE TABLE Clients ( + ClientId STRING(18) NOT NULL, + LastSnapshotTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), + LastStartupTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), + LastRRGStartupTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), + LastCrashTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), + FirstSeenTime TIMESTAMP, + LastPingTime TIMESTAMP, + LastForemanTime TIMESTAMP, + Certificate BYTES(MAX), + FleetspeakEnabled BOOL, + FleetspeakValidationInfo `grr.FleetspeakValidationInfo`, +) PRIMARY KEY (ClientId); + +CREATE TABLE ClientSnapshots( + ClientId STRING(18) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Snapshot `grr.ClientSnapshot` NOT NULL, +) PRIMARY KEY (ClientId, CreationTime), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE TABLE ClientStartups( + ClientId STRING(18) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Startup `grr.StartupInfo` NOT NULL, +) PRIMARY KEY (ClientId, CreationTime), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE TABLE ClientRRGStartups( + ClientId STRING(18) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Startup `rrg.startup.Startup` NOT NULL, +) PRIMARY KEY (ClientId, CreationTime), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE TABLE ClientCrashes( + ClientId STRING(18) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Crash `grr.ClientCrash` NOT NULL, +) PRIMARY KEY (ClientId, CreationTime), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE TABLE Users ( + Username STRING(256) NOT NULL, + Email STRING(256), + Password `grr.Password`, + Type `grr.GRRUser.UserType`, + CanaryMode BOOL, + UiMode `grr.GUISettings.UIMode`, +) PRIMARY KEY (Username); + +CREATE TABLE UserNotifications( + Username STRING(256) NOT NULL, + NotificationId STRING(36) NOT NULL, + Type `grr.UserNotification.Type` NOT NULL, + State `grr.UserNotification.State` NOT NULL, + 
CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Message STRING(MAX), + Reference `grr.ObjectReference` NOT NULL, +) PRIMARY KEY (Username, NotificationId), + INTERLEAVE IN PARENT Users ON DELETE CASCADE; + +CREATE TABLE ApprovalRequests( + Requestor STRING(256) NOT NULL, + ApprovalId STRING(36) NOT NULL, + SubjectClientId STRING(18), + SubjectHuntId STRING(8), + SubjectCronJobId STRING(100), + Reason STRING(MAX) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + ExpirationTime TIMESTAMP NOT NULL, + NotifiedUsers ARRAY NOT NULL, + CcEmails ARRAY NOT NULL, + CONSTRAINT fk_approval_request_requestor_username + FOREIGN KEY (Requestor) + REFERENCES Users(Username), + + CONSTRAINT ck_subject_id_valid + CHECK ((IF(SubjectClientId IS NOT NULL, 1, 0) + + IF(SubjectHuntId IS NOT NULL, 1, 0) + + IF(SubjectCronJobId IS NOT NULL, 1, 0)) = 1), +) PRIMARY KEY (Requestor, ApprovalId); + +CREATE INDEX ApprovalRequestsByRequestor + ON ApprovalRequests(Requestor); + +CREATE INDEX ApprovalRequestsByRequestorSubjectClientId + ON ApprovalRequests(Requestor, SubjectClientId); + +CREATE INDEX ApprovalRequestsByRequestorSubjectHuntId + ON ApprovalRequests(Requestor, SubjectHuntId); + +CREATE INDEX ApprovalRequestsByRequestorSubjectCronJobId + ON ApprovalRequests(Requestor, SubjectCronJobId); + +CREATE TABLE ApprovalGrants( + Requestor STRING(256) NOT NULL, + ApprovalId STRING(36) NOT NULL, + Grantor STRING(256) NOT NULL, + GrantId STRING(36) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + + CONSTRAINT fk_approval_grant_grantor_username + FOREIGN KEY (Grantor) + REFERENCES Users(Username) ON DELETE CASCADE, +) PRIMARY KEY (Requestor, ApprovalId, Grantor, GrantId), + INTERLEAVE IN PARENT ApprovalRequests ON DELETE CASCADE; + +CREATE INDEX ApprovalGrantsByGrantor + ON ApprovalGrants(Grantor); + +CREATE TABLE ClientLabels( + ClientId STRING(18) NOT NULL, + Owner STRING(256) NOT NULL, + Label STRING(128) NOT NULL, + + CONSTRAINT fk_client_label_owner_username + FOREIGN KEY (Owner) + REFERENCES Users(Username), + + CONSTRAINT fk_client_label_label_label + FOREIGN KEY (Label) + REFERENCES Labels(Label), +) PRIMARY KEY (ClientId, Owner, Label), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE TABLE ClientKeywords( + ClientId STRING(18) NOT NULL, + Keyword STRING(256) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + + CONSTRAINT ck_client_keyword_not_empty + CHECK (Keyword <> ''), +) PRIMARY KEY (ClientId, Keyword), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE INDEX ClientKeywordsByKeywordCreationTime + ON ClientKeywords(Keyword, CreationTime); + +CREATE TABLE Flows( + ClientId STRING(18) NOT NULL, + FlowId STRING(16) NOT NULL, + ParentFlowId STRING(16), + ParentHuntId STRING(16), + LongFlowId STRING(256) NOT NULL, + Creator STRING(256) NOT NULL, + Name STRING(256) NOT NULL, + State `grr.Flow.FlowState` NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + UpdateTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Crash `grr.ClientCrash`, + NextRequestToProcess STRING(16), + ProcessingWorker STRING(MAX), + ProcessingStartTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), + ProcessingEndTime TIMESTAMP, + Flow `grr.Flow` NOT NULL, + ReplyCount INT64 NOT NULL, + NetworkBytesSent INT64 NOT NULL, + UserCpuTimeUsed FLOAT64 NOT NULL, + SystemCpuTimeUsed FLOAT64 NOT NULL, + + CONSTRAINT 
fk_flow_client_id_parent_id_flow + FOREIGN KEY (ClientId, ParentFlowId) + REFERENCES Flows(ClientId, FlowId), +) PRIMARY KEY (ClientId, FlowId), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE INDEX FlowsByParentHuntIdFlowIdUpdateTime + ON Flows(ParentHuntId, FlowId, UpdateTime) + STORING (State); + +CREATE INDEX FlowsByParentHuntIdFlowIdState + ON Flows(ParentHuntId, FlowId, State) + STORING (ReplyCount, NetworkBytesSent, UserCpuTimeUsed, SystemCpuTimeUsed); + +CREATE TABLE FlowResults( + ClientId STRING(18) NOT NULL, + FlowId STRING(16) NOT NULL, + HuntId STRING(16), + CreationTime TIMESTAMP NOT NULL, + Payload `google.protobuf.Any`, + RdfType STRING(MAX), + Tag STRING(MAX), +) PRIMARY KEY (ClientId, FlowId, CreationTime), + INTERLEAVE IN PARENT Flows ON DELETE CASCADE; + +CREATE INDEX FlowResultsByHuntIdCreationTime + ON FlowResults(HuntId, CreationTime); +CREATE INDEX FlowResultsByHuntIdFlowIdCreationTime + ON FlowResults(HuntId, FlowId, CreationTime); +CREATE INDEX FlowResultsByHuntIdFlowIdRdfTypeTagCreationTime + ON FlowResults(HuntId, FlowId, RdfType, Tag, CreationTime); + +CREATE TABLE FlowErrors( + ClientId STRING(18) NOT NULL, + FlowId STRING(16) NOT NULL, + HuntId STRING(16) NOT NULL, + CreationTime TIMESTAMP NOT NULL, + Payload `google.protobuf.Any`, + RdfType STRING(MAX), + Tag STRING(MAX), +) PRIMARY KEY (ClientId, FlowId, CreationTime), + INTERLEAVE IN PARENT Flows ON DELETE CASCADE; + +CREATE INDEX FlowErrorsByHuntIdFlowIdCreationTime + ON FlowErrors(HuntId, FlowId, CreationTime); +CREATE INDEX FlowErrorsByHuntIdFlowIdRdfTypeTagCreationTime + ON FlowErrors(HuntId, FlowId, RdfType, Tag, CreationTime); + +CREATE TABLE FlowRequests( + ClientId STRING(18) NOT NULL, + FlowId STRING(16) NOT NULL, + RequestId STRING(16) NOT NULL, + NeedsProcessing BOOL, + ExpectedResponseCount INT64, + NextResponseId STRING(16), + CallbackState STRING(256), + Payload `grr.FlowRequest` NOT NULL, + StartTime TIMESTAMP, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), +) PRIMARY KEY (ClientId, FlowId, RequestId), + INTERLEAVE IN PARENT Flows ON DELETE CASCADE, + ROW DELETION POLICY (OLDER_THAN(CreationTime, INTERVAL 84 DAY)); + +CREATE TABLE FlowResponses( + ClientId STRING(18) NOT NULL, + FlowId STRING(16) NOT NULL, + RequestId STRING(16) NOT NULL, + ResponseId STRING(16) NOT NULL, + Response `grr.FlowResponse`, + Status `grr.FlowStatus`, + Iterator `grr.FlowIterator`, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + + CONSTRAINT ck_flow_response_has_payload + CHECK ((IF(Response IS NOT NULL, 1, 0) + + IF(Status IS NOT NULL, 1, 0) + + IF(Iterator IS NOT NULL, 1, 0)) = 1), +) PRIMARY KEY (ClientId, FlowId, RequestId, ResponseId), + INTERLEAVE IN PARENT FlowRequests ON DELETE CASCADE; + +CREATE TABLE FlowLogEntries( + ClientId STRING(18) NOT NULL, + FlowId STRING(16) NOT NULL, + HuntId STRING(16), + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Message STRING(MAX) NOT NULL, +) PRIMARY KEY (ClientId, FlowId, CreationTime), + INTERLEAVE IN PARENT Flows ON DELETE CASCADE; + +CREATE INDEX FlowLogEntriesByHuntIdCreationTime + ON FlowLogEntries(HuntId, CreationTime); + +CREATE TABLE FlowRRGLogs( + ClientId STRING(18) NOT NULL, + FlowId STRING(16) NOT NULL, + RequestId STRING(16) NOT NULL, + ResponseId STRING(16) NOT NULL, + LogLevel `rrg.Log.Level` NOT NULL, + LogTime TIMESTAMP NOT NULL, + LogMessage STRING(MAX) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), +) PRIMARY KEY 
(ClientId, FlowId, RequestId, ResponseId), + INTERLEAVE IN PARENT Flows ON DELETE CASCADE; + +CREATE TABLE FlowOutputPluginLogEntries( + ClientId STRING(18) NOT NULL, + FlowId STRING(16) NOT NULL, + OutputPluginId STRING(16) NOT NULL, + HuntId STRING(16), + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Type `grr.FlowOutputPluginLogEntry.LogEntryType` NOT NULL, + Message STRING(MAX) NOT NULL, +) PRIMARY KEY (ClientId, FlowId, OutputPluginId, CreationTime), + INTERLEAVE IN PARENT Flows ON DELETE CASCADE; + +CREATE INDEX FlowOutputPluginLogEntriesByHuntIdCreationTime + ON FlowOutputPluginLogEntries(HuntId, CreationTime); + +CREATE TABLE ScheduledFlows( + ClientId STRING(18) NOT NULL, + Creator STRING(256) NOT NULL, + ScheduledFlowId STRING(16) NOT NULL, + FlowName STRING(256) NOT NULL, + FlowArgs `google.protobuf.Any` NOT NULL, + RunnerArgs `grr.FlowRunnerArgs` NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Error STRING(MAX), + + CONSTRAINT fk_creator_users_username + FOREIGN KEY (Creator) + REFERENCES Users(Username) ON DELETE CASCADE, +) PRIMARY KEY (ClientId, Creator, ScheduledFlowId), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE INDEX ScheduledFlowsByCreator + ON ScheduledFlows(Creator); + +CREATE TABLE Paths( + ClientId STRING(18) NOT NULL, + Type `grr.PathInfo.PathType` NOT NULL, + Path BYTES(MAX) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + LastFileStatTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), + LastFileHashTime TIMESTAMP OPTIONS (allow_commit_timestamp = true), + IsDir BOOL NOT NULL, + Depth INT64 NOT NULL, +) PRIMARY KEY (ClientId, Type, Path), + INTERLEAVE IN PARENT Clients ON DELETE CASCADE; + +CREATE INDEX PathsByClientIdTypePathDepth + ON Paths(ClientId, Type, Path, Depth); + +CREATE TABLE PathFileStats( + ClientId STRING(18) NOT NULL, + Type `grr.PathInfo.PathType` NOT NULL, + Path BYTES(MAX) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Stat `grr.StatEntry` NOT NULL, +) PRIMARY KEY (ClientId, Type, Path, CreationTime), + INTERLEAVE IN PARENT Paths ON DELETE CASCADE; + +CREATE TABLE PathFileHashes( + ClientId STRING(18) NOT NULL, + Type `grr.PathInfo.PathType` NOT NULL, + Path BYTES(MAX) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + FileHash `grr.Hash` NOT NULL, +) PRIMARY KEY (ClientId, Type, Path, CreationTime), + INTERLEAVE IN PARENT Paths ON DELETE CASCADE; + +CREATE TABLE HashBlobReferences( + HashId BYTES(32) NOT NULL, + Offset INT64 NOT NULL, + BlobId BYTES(32) NOT NULL, + Size INT64 NOT NULL, + + CONSTRAINT ck_hash_id_valid CHECK (BYTE_LENGTH(HashId) = 32), + CONSTRAINT ck_blob_id_valid CHECK (BYTE_LENGTH(BlobId) = 32), +) PRIMARY KEY (HashId, Offset); + +CREATE TABLE ApiAuditEntry ( + Username STRING(256) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + HttpRequestPath STRING(MAX) NOT NULL, + RouterMethodName STRING(256) NOT NULL, + ResponseCode `grr.APIAuditEntry.Code` NOT NULL, +) PRIMARY KEY (Username, CreationTime); + +CREATE TABLE Artifacts ( + Name STRING(256) NOT NULL, + Platforms ARRAY NOT NULL, + Payload `grr.Artifact` NOT NULL, +) PRIMARY KEY (Name); + +CREATE TABLE YaraSignatureReferences( + BlobId BYTES(32) NOT NULL, + Creator STRING(256) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + + CONSTRAINT ck_yara_signature_reference_blob_id_valid + CHECK 
(BYTE_LENGTH(BlobId) = 32), + CONSTRAINT fk_yara_signature_reference_creator_username + FOREIGN KEY (Creator) + REFERENCES Users(Username) +) PRIMARY KEY (BlobId); + +CREATE TABLE Hunts( + HuntId STRING(16) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + LastUpdateTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Creator STRING(256) NOT NULL, + DurationMicros INT64 NOT NULL, + Description STRING(MAX) NOT NULL, + ClientRate FLOAT64 NOT NULL, + ClientLimit INT64 NOT NULL, + State `grr.Hunt.HuntState` NOT NULL, + StateReason `grr.Hunt.HuntStateReason` NOT NULL, + StateComment STRING(MAX) NOT NULL, + InitStartTime TIMESTAMP, + LastStartTime TIMESTAMP, + ClientCountAtStartTime INT64 NOT NULL, + Hunt `grr.Hunt` NOT NULL, + + CONSTRAINT fk_hunt_creator_username + FOREIGN KEY (Creator) + REFERENCES Users(Username), +) PRIMARY KEY (HuntId); + +ALTER TABLE FlowLogEntries ADD CONSTRAINT fk_flow_log_entry_hunt_id_hunt + FOREIGN KEY (HuntId) + REFERENCES Hunts(HuntId); +ALTER TABLE FlowOutputPluginLogEntries ADD CONSTRAINT fk_flow_output_plugin_log_entry_hunt_id_hunt + FOREIGN KEY (HuntId) + REFERENCES Hunts(HuntId); +ALTER TABLE Flows ADD CONSTRAINT fk_flow_parent_hunt_id_hunt + FOREIGN KEY (ParentHuntId) + REFERENCES Hunts(HuntId); + +CREATE INDEX HuntsByCreationTime + ON Hunts(CreationTime DESC); + +CREATE INDEX HuntsByCreator ON Hunts(Creator); + +CREATE TABLE HuntOutputPlugins( + HuntId STRING(16) NOT NULL, + OutputPluginId INT64 NOT NULL, + Name STRING(256) NOT NULL, + Args `google.protobuf.Any`, + State `google.protobuf.Any` NOT NULL, +) PRIMARY KEY (HuntId, OutputPluginId), + INTERLEAVE IN PARENT Hunts ON DELETE CASCADE; + +CREATE TABLE ForemanRules( + HuntId STRING(16) NOT NULL, + ExpirationTime TIMESTAMP NOT NULL, + Payload `grr.ForemanCondition`, + + CONSTRAINT fk_foreman_rule_hunt_id_hunt + FOREIGN KEY (HuntId) + REFERENCES Hunts(HuntId), +) PRIMARY KEY (HuntId); + +CREATE INDEX ForemanRulesByExpirationTime + ON ForemanRules(ExpirationTime); + +CREATE TABLE SignedBinaries ( + Type `grr.SignedBinaryID.BinaryType` NOT NULL, + Path STRING(MAX) NOT NULL, + BlobReferences `grr.BlobReferences` NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), +) PRIMARY KEY (Type, Path); + +CREATE TABLE CronJobs ( + JobId STRING(256) NOT NULL, + Job `grr.CronJob` NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + Enabled BOOL NOT NULL, + CurrentRunId STRING(256), + ForcedRunRequested BOOL, + LastRunStatus `grr.CronJobRun.CronJobRunStatus`, + LastRunTime TIMESTAMP, + State `grr.AttributedDict`, + LeaseEndTime TIMESTAMP, + LeaseOwner STRING(256), +) PRIMARY KEY (JobId); + +CREATE TABLE CronJobRuns ( + JobId STRING(256) NOT NULL, + RunId STRING(256) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + FinishTime TIMESTAMP, + Status `grr.CronJobRun.CronJobRunStatus` NOT NULL, + LogMessage STRING(MAX), + Backtrace STRING(MAX), + Payload `grr.CronJobRun` NOT NULL, +) PRIMARY KEY (JobId, RunId), + INTERLEAVE IN PARENT CronJobs ON DELETE CASCADE; + +ALTER TABLE CronJobs ADD CONSTRAINT fk_cron_job_run_job_id_current_run_id_cron_job_run + FOREIGN KEY (JobId, CurrentRunId) + REFERENCES CronJobRuns(JobId, RunId); + +ALTER TABLE ApprovalRequests + ADD CONSTRAINT fk_approval_request_subject_client_id_client + FOREIGN KEY (SubjectClientId) + REFERENCES Clients(ClientId); + +ALTER TABLE ApprovalRequests + ADD CONSTRAINT fk_approval_request_subject_hunt_id_hunt + FOREIGN 
KEY (SubjectHuntId) + REFERENCES Hunts(HuntId) + ON DELETE CASCADE; + +ALTER TABLE ApprovalRequests + ADD CONSTRAINT fk_approval_request_subject_cron_job_id_cron_job + FOREIGN KEY (SubjectCronJobId) + REFERENCES CronJobs(JobId) + ON DELETE CASCADE; + +CREATE TABLE BlobEncryptionKeys( + BlobId BYTES(32) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + KeyName STRING(256) NOT NULL, +) PRIMARY KEY (BlobId, CreationTime); + +CREATE TABLE SignedCommands( + Id STRING(128) NOT NULL, + OperatingSystem `grr.SignedCommand.OS` NOT NULL, + Ed25519Signature BYTES(64) NOT NULL, + Command BYTES(MAX) NOT NULL, +) PRIMARY KEY (Id, OperatingSystem); + +CREATE TABLE MessageHandlerRequests( + RequestId STRING(16) NOT NULL, + HandlerName STRING(256) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + LeasedUntil TIMESTAMP, + LeasedBy STRING(128), + Payload `grr.MessageHandlerRequest` NOT NULL, +) PRIMARY KEY (RequestId); + +CREATE INDEX MessageHandlerRequestsByLease + ON MessageHandlerRequests(LeasedUntil, LeasedBy); + +CREATE TABLE FlowProcessingRequests( + ClientId STRING(18) NOT NULL, + FlowId STRING(16) NOT NULL, + CreationTime TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp = true), + RequestId STRING(36) NOT NULL, + DeliveryTime TIMESTAMP, + LeasedUntil TIMESTAMP, + LeasedBy STRING(128), + Payload `grr.FlowProcessingRequest` NOT NULL, +) PRIMARY KEY (ClientId, FlowId, CreationTime), + INTERLEAVE IN PARENT Flows ON DELETE CASCADE; + +CREATE INDEX FlowProcessingRequestsIds + ON FlowProcessingRequests(RequestId); diff --git a/grr/server/grr_response_server/databases/spanner_artifacts.py b/grr/server/grr_response_server/databases/spanner_artifacts.py new file mode 100644 index 000000000..0afcc496d --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_artifacts.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python +"""A module with artifacts methods of the Spanner backend.""" + +from typing import Optional, Sequence + +from google.api_core.exceptions import AlreadyExists, NotFound +from google.cloud import spanner as spanner_lib + +from grr_response_proto import artifact_pb2 +from grr_response_server.databases import db +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils + + +class ArtifactsMixin: + """A Spanner database mixin with implementation of artifacts.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteArtifact(self, artifact: artifact_pb2.Artifact) -> None: + """Writes new artifact to the database. + + Args: + artifact: Artifact to be stored. + + Raises: + DuplicatedArtifactError: when the artifact already exists. + """ + name = str(artifact.name) + row = { + "Name": name, + "Platforms": list(artifact.supported_os), + "Payload": artifact, + } + try: + self.db.Insert(table="Artifacts", row=row, txn_tag="WriteArtifact") + except AlreadyExists as error: + raise db.DuplicatedArtifactError(name) from error + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadArtifact(self, name: str) -> Optional[artifact_pb2.Artifact]: + """Looks up an artifact with given name from the database. + + Args: + name: Name of the artifact to be read. + + Returns: + The artifact object read from the database. + + Raises: + UnknownArtifactError: when the artifact does not exist. 
+ """ + try: + row = self.db.Read("Artifacts", + key=[name], + cols=("Platforms", "Payload"), + txn_tag="ReadArtifacts") + except NotFound as error: + raise db.UnknownArtifactError(name) from error + + artifact = artifact_pb2.Artifact.FromString(row[1]) + artifact.name = name + artifact.supported_os[:] = row[0] + return artifact + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadAllArtifacts(self) -> Sequence[artifact_pb2.Artifact]: + """Lists all artifacts that are stored in the database.""" + result = [] + + query = """ + SELECT a.Name, a.Platforms, a.Payload + FROM Artifacts AS a + """ + for [name, supported_os, payload] in self.db.Query(query): + artifact = artifact_pb2.Artifact.FromString(payload) + artifact.name = name + artifact.supported_os[:] = supported_os + result.append(artifact) + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteArtifact(self, name: str) -> None: + """Deletes an artifact with given name from the database. + + Args: + name: Name of the artifact to be deleted. + + Raises: + UnknownArtifactError when the artifact does not exist. + """ + def Transaction(txn) -> None: + # Spanner does not raise if we attempt to delete a non-existing row so + # we check it exists ourselves. + keyset = spanner_lib.KeySet(keys=[[name],]) + + try: + txn.read(table="Artifacts", columns=("Name",), keyset=keyset).one() + except NotFound as error: + raise db.UnknownArtifactError(name) from error + + txn.delete("Artifacts", keyset) + + self.db.Transact(Transaction, txn_tag="DeleteArtifact") diff --git a/grr/server/grr_response_server/databases/spanner_artifacts_test.py b/grr/server/grr_response_server/databases/spanner_artifacts_test.py new file mode 100644 index 000000000..4e1eeb9f1 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_artifacts_test.py @@ -0,0 +1,22 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_artifacts_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseArtifactsTest( + db_artifacts_test.DatabaseTestArtifactsMixin, spanner_test_lib.TestCase +): + pass # Test methods are defined in the base mixin class. + + +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/databases/spanner_blob_keys.py b/grr/server/grr_response_server/databases/spanner_blob_keys.py new file mode 100644 index 000000000..9b13a971b --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_blob_keys.py @@ -0,0 +1,79 @@ +#!/usr/bin/env python +"""Blob encryption key methods of Spanner database implementation.""" +import base64 + +from typing import Collection, Dict, Optional + +from google.cloud import spanner as spanner_lib +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils +from grr_response_server.models import blobs as models_blobs + + +class BlobKeysMixin: + """A Spanner mixin with implementation of blob encryption keys methods.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteBlobEncryptionKeys( + self, + key_names: Dict[models_blobs.BlobID, str], + ) -> None: + """Associates the specified blobs with the given encryption keys.""" + # A special case for empty list of blob identifiers to avoid issues with an + # empty mutation. 
+ if not key_names: + return + + def Mutation(mut) -> None: + for blob_id, key_name in key_names.items(): + mut.insert( + table="BlobEncryptionKeys", + columns=("BlobId", "CreationTime", "KeyName"), + values=[((base64.b64encode(bytes(blob_id))), spanner_lib.COMMIT_TIMESTAMP, key_name)] + ) + + self.db.Mutate(Mutation, txn_tag="WriteBlobEncryptionKeys") + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadBlobEncryptionKeys( + self, + blob_ids: Collection[models_blobs.BlobID], + ) -> Dict[models_blobs.BlobID, Optional[str]]: + """Retrieves encryption keys associated with blobs.""" + # A special case for empty list of blob identifiers to avoid syntax errors + # in the query below. + if not blob_ids: + return {} + + param_placeholders = ", ".join([f"{{blobId{i}}}" for i in range(len(blob_ids))]) + + params = {} + for i, blob_id_bytes in enumerate(blob_ids): + param_name = f"blobId{i}" + params[param_name] = base64.b64encode(bytes(blob_id_bytes)) + + query = f""" + SELECT k.BlobId, k.KeyName + FROM BlobEncryptionKeys AS k + INNER JOIN (SELECT k.BlobId, MAX(k.CreationTime) AS MaxCreationTime + FROM BlobEncryptionKeys AS k + WHERE k.BlobId IN ({param_placeholders}) + GROUP BY k.BlobId) AS last_k + ON k.BlobId = last_k.BlobId + AND k.CreationTime = last_k.MaxCreationTime + """ + + results = {blob_id: None for blob_id in blob_ids} + + for blob_id_bytes, key_name in self.db.ParamQuery( + query, params, txn_tag="ReadBlobEncryptionKeys" + ): + blob_id = models_blobs.BlobID(base64.b64decode(blob_id_bytes)) + results[blob_id] = key_name + + return results diff --git a/grr/server/grr_response_server/databases/spanner_blob_keys_test.py b/grr/server/grr_response_server/databases/spanner_blob_keys_test.py new file mode 100644 index 000000000..6bcd81b9a --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_blob_keys_test.py @@ -0,0 +1,22 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_blob_keys_test_lib +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseBlobKeysTest( + db_blob_keys_test_lib.DatabaseTestBlobKeysMixin, spanner_test_lib.TestCase +): + pass # Test methods are defined in the base mixin class. 
+ + +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/databases/spanner_blob_references.py b/grr/server/grr_response_server/databases/spanner_blob_references.py new file mode 100644 index 000000000..eabb80e9a --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_blob_references.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python +"""A library with blob references methods of Spanner database implementation.""" +import base64 + +from typing import Collection, Mapping, Optional + +from google.cloud import spanner as spanner_lib + +from grr_response_proto import objects_pb2 +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils +from grr_response_server.rdfvalues import objects as rdf_objects + + +class BlobReferencesMixin: + """A Spanner database mixin with implementation of blob references methods.""" + + db: spanner_utils.Database + BATCH_SIZE = 16000 + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteHashBlobReferences( + self, + references_by_hash: Mapping[ + rdf_objects.SHA256HashID, Collection[objects_pb2.BlobReference] + ], + ) -> None: + """Writes blob references for a given set of hashes.""" + batch = dict() + for k, v in references_by_hash.items(): + batch[k] = v + if len(batch) == self.BATCH_SIZE: + self._WriteHashBlobReferences(batch) + batch = dict() + if batch: + self._WriteHashBlobReferences(batch) + + def _WriteHashBlobReferences( + self, + references_by_hash: Mapping[ + rdf_objects.SHA256HashID, Collection[objects_pb2.BlobReference] + ], + ) -> None: + """Writes blob references for a given set of hashes.""" + def Mutation(mut) -> None: + for hash_id, refs in references_by_hash.items(): + hash_id_b64 = base64.b64encode(bytes(hash_id.AsBytes())) + key_range = spanner_lib.KeyRange(start_closed=[hash_id_b64,], end_closed=[hash_id_b64,]) + keyset = spanner_lib.KeySet(ranges=[key_range]) + # Make sure we delete any of the previously existing blob references. 
+ mut.delete("HashBlobReferences", keyset) + + for ref in refs: + mut.insert( + table="HashBlobReferences", + columns=("HashId", "BlobId", "Offset", "Size",), + values=[(hash_id_b64, base64.b64encode(bytes(ref.blob_id)), ref.offset, ref.size,)] + ) + self.db.Mutate(Mutation, txn_tag="WriteHashBlobReferences") + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHashBlobReferences( + self, hashes: Collection[rdf_objects.SHA256HashID] + ) -> Mapping[ + rdf_objects.SHA256HashID, Optional[Collection[objects_pb2.BlobReference]] + ]: + """Reads blob references of a given set of hashes.""" + result = {} + key_ranges = [] + for h in hashes: + hash_id_b64 = base64.b64encode(bytes(h.AsBytes())) + key_ranges.append(spanner_lib.KeyRange(start_closed=[hash_id_b64,], end_closed=[hash_id_b64,])) + result[h] = [] + rows = spanner_lib.KeySet(ranges=key_ranges) + + hashes_left = set(hashes) + for row in self.db.ReadSet( + table="HashBlobReferences", + rows=rows, + cols=("HashId", "BlobId", "Offset", "Size"), + txn_tag="ReadHashBlobReferences" + ): + hash_id = rdf_objects.SHA256HashID(base64.b64decode(row[0])) + + blob_ref = objects_pb2.BlobReference() + blob_ref.blob_id = base64.b64decode(row[1]) + blob_ref.offset = row[2] + blob_ref.size = row[3] + + result[hash_id].append(blob_ref) + hashes_left.discard(hash_id) + + for h in hashes_left: + result[h] = None + + return result diff --git a/grr/server/grr_response_server/databases/spanner_blob_references_test.py b/grr/server/grr_response_server/databases/spanner_blob_references_test.py new file mode 100644 index 000000000..1bb27228a --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_blob_references_test.py @@ -0,0 +1,23 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_blob_references_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseBlobReferencesTest( + db_blob_references_test.DatabaseTestBlobReferencesMixin, + spanner_test_lib.TestCase, +): + pass # Test methods are defined in the base mixin class. + + +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/databases/spanner_clients.py b/grr/server/grr_response_server/databases/spanner_clients.py new file mode 100644 index 000000000..c3dda9a22 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_clients.py @@ -0,0 +1,836 @@ +#!/usr/bin/env python +"""A module with client methods of the Spanner database implementation.""" + +from collections.abc import Collection, Iterator, Mapping, Sequence +import datetime +from typing import Optional + +from google.api_core.exceptions import NotFound +from google.cloud import spanner as spanner_lib + +from grr_response_core.lib import rdfvalue +from grr_response_proto import jobs_pb2 +from grr_response_proto import objects_pb2 +from grr_response_proto.rrg import startup_pb2 as rrg_startup_pb2 +# Aliasing the import since the name db clashes with the db annotation. 
+from grr_response_server.databases import db as db_lib +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils +from grr_response_server.models import clients as models_clients + + +class ClientsMixin: + """A Spanner database mixin with implementation of client methods.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def MultiWriteClientMetadata( + self, + client_ids: Collection[str], + first_seen: Optional[rdfvalue.RDFDatetime] = None, + last_ping: Optional[rdfvalue.RDFDatetime] = None, + last_foreman: Optional[rdfvalue.RDFDatetime] = None, + fleetspeak_validation_info: Optional[Mapping[str, str]] = None, + ) -> None: + """Writes metadata about the clients.""" + if not client_ids: + return + + row = {} + + if first_seen is not None: + row["FirstSeenTime"] = first_seen.AsDatetime() + if last_ping is not None: + row["LastPingTime"] = last_ping.AsDatetime() + if last_foreman is not None: + row["LastForemanTime"] = last_foreman.AsDatetime() + + if fleetspeak_validation_info is not None: + if fleetspeak_validation_info: + row["FleetspeakValidationInfo"] = ( + models_clients.FleetspeakValidationInfoFromDict( + fleetspeak_validation_info + ) + ) + else: + row["FleetspeakValidationInfo"] = None + + def Mutation(mut) -> None: + columns = [] + rows = [] + for client_id in client_ids: + client_row = {"ClientId": client_id, **row} + columns, values = zip(*client_row.items()) + rows.append(values) + mut.insert_or_update(table="Clients", columns=columns, values=rows) + + self.db.Mutate(Mutation, txn_tag="MultiWriteClientMetadata") + + @db_utils.CallLogged + @db_utils.CallAccounted + def MultiReadClientMetadata( + self, + client_ids: Collection[str], + ) -> Mapping[str, objects_pb2.ClientMetadata]: + """Reads ClientMetadata records for a list of clients.""" + result = {} + + keys = [] + for client_id in client_ids: + keys.append([client_id]) + keyset = spanner_lib.KeySet(keys=keys) + + cols = ( + "ClientId", + "LastStartupTime", + "LastCrashTime", + "Certificate", + "FirstSeenTime", + "LastPingTime", + "LastForemanTime", + "FleetspeakValidationInfo", + ) + + for row in self.db.ReadSet(table="Clients", + rows=keyset, + cols=cols, + txn_tag="MultiReadClientMetadata"): + client_id = row[0] + + metadata = objects_pb2.ClientMetadata() + + if row[1] is not None: + metadata.startup_info_timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(row[1]) + ) + if row[2] is not None: + metadata.last_crash_timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(row[2]) + ) + if row[3] is not None: + metadata.certificate = row[3] + if row[4] is not None: + metadata.first_seen = int( + rdfvalue.RDFDatetime.FromDatetime(row[4]) + ) + if row[5] is not None: + metadata.ping = int( + rdfvalue.RDFDatetime.FromDatetime(row[5]) + ) + if row[6] is not None: + metadata.last_foreman_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[6]) + ) + if row[7] is not None: + metadata.last_fleetspeak_validation_info.ParseFromString( + row[7] + ) + + result[client_id] = metadata + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def MultiAddClientLabels( + self, + client_ids: Collection[str], + owner: str, + labels: Collection[str], + ) -> None: + """Attaches user labels to the specified clients.""" + # Early return to avoid generating empty mutation. 
+ if not client_ids or not labels: + return + + def Mutation(mut) -> None: + label_rows = [] + for label in labels: + label_rows.append([label]) + mut.insert_or_update(table="Labels", columns=["Label"], values=label_rows) + + client_rows = [] + for client_id in client_ids: + for label in labels: + client_rows.append([client_id, owner, label]) + columns = ["ClientId", "Owner", "Label"] + mut.insert_or_update(table="ClientLabels", columns=columns, values=client_rows) + + try: + self.db.Mutate(Mutation, txn_tag="MultiAddClientLabels") + except Exception as error: + message = str(error) + if "Parent row for row [" in message: + raise db_lib.AtLeastOneUnknownClientError(client_ids) from error + elif "fk_client_label_owner_username" in message: + raise db_lib.UnknownGRRUserError(username=owner, cause=error) + else: + raise + + @db_utils.CallLogged + @db_utils.CallAccounted + def MultiReadClientLabels( + self, + client_ids: Collection[str], + ) -> Mapping[str, Collection[objects_pb2.ClientLabel]]: + """Reads the user labels for a list of clients.""" + result = {client_id: [] for client_id in client_ids} + + query = """ + SELECT l.ClientId, ARRAY_AGG((l.Owner, l.Label)) + FROM ClientLabels AS l + WHERE l.ClientId IN UNNEST({client_ids}) + GROUP BY l.ClientId + """ + params = {"client_ids": client_ids} + + for client_id, labels in self.db.ParamQuery( + query, params, txn_tag="MultiReadClientLabels" + ): + for owner, label in labels: + label_proto = objects_pb2.ClientLabel() + label_proto.name = label + label_proto.owner = owner + result[client_id].append(label_proto) + + for labels in result.values(): + labels.sort(key=lambda label: (label.owner, label.name)) + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def RemoveClientLabels( + self, + client_id: str, + owner: str, + labels: Collection[str], + ) -> None: + """Removes a list of user labels from a given client.""" + + def Mutation(mut) -> None: + keys = [] + for label in labels: + keys.append([client_id, owner, label]) + mut.delete(table="ClientLabels", keyset=spanner_lib.KeySet(keys=keys)) + + self.db.Mutate(Mutation, txn_tag="RemoveClientLabels") + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadAllClientLabels(self) -> Collection[str]: + """Lists all client labels known to the system.""" + result = [] + + query = """ + SELECT l.Label + FROM Labels AS l + """ + + for (label,) in self.db.Query(query, txn_tag="ReadAllClientLabels"): + result.append(label) + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteClientSnapshot(self, snapshot: objects_pb2.ClientSnapshot) -> None: + """Writes new client snapshot.""" + + startup = snapshot.startup_info + snapshot_without_startup_info = objects_pb2.ClientSnapshot() + snapshot_without_startup_info.CopyFrom(snapshot) + snapshot_without_startup_info.ClearField("startup_info") + + def Mutation(mut) -> None: + clients_rows = [] + clients_rows.append([snapshot.client_id,spanner_lib.COMMIT_TIMESTAMP,spanner_lib.COMMIT_TIMESTAMP]) + clients_columns = ["ClientId", "LastSnapshotTime", "LastStartupTime"] + mut.update(table="Clients", columns=clients_columns, values=clients_rows) + + snapshots_rows = [] + snapshots_rows.append([snapshot.client_id, spanner_lib.COMMIT_TIMESTAMP, snapshot_without_startup_info]) + snapshots_columns= ["ClientId", "CreationTime", "Snapshot"] + mut.insert(table="ClientSnapshots", columns=snapshots_columns, values=snapshots_rows) + + startups_rows = [] + startups_rows.append([snapshot.client_id, spanner_lib.COMMIT_TIMESTAMP, 
startup]) + startups_columns = [ "ClientId", "CreationTime", "Startup"] + mut.insert(table="ClientStartups", columns=startups_columns, values=startups_rows) + + try: + self.db.Mutate(Mutation, txn_tag="WriteClientSnapshot") + except NotFound as error: + raise db_lib.UnknownClientError(snapshot.client_id, cause=error) + + @db_utils.CallLogged + @db_utils.CallAccounted + def MultiReadClientSnapshot( + self, + client_ids: Collection[str], + ) -> Mapping[str, Optional[objects_pb2.ClientSnapshot]]: + """Reads the latest client snapshots for a list of clients.""" + if not client_ids: + return {} + + result = {client_id: None for client_id in client_ids} + + query = """ + SELECT c.ClientId, ss.CreationTime, ss.Snapshot, su.Startup + FROM Clients AS c, ClientSnapshots AS ss, ClientStartups AS su + WHERE c.ClientId IN UNNEST({client_ids}) + AND ss.ClientId = c.ClientId + AND ss.CreationTime = c.LastSnapshotTime + AND su.ClientId = c.ClientId + AND su.CreationTime = c.LastStartupTime + """ + for row in self.db.ParamQuery( + query, {"client_ids": client_ids}, txn_tag="MultiReadClientSnapshot" + ): + client_id, creation_time, snapshot_bytes, startup_bytes = row + + snapshot = objects_pb2.ClientSnapshot() + snapshot.ParseFromString(snapshot_bytes) + snapshot.startup_info.ParseFromString(startup_bytes) + snapshot.timestamp = int(rdfvalue.RDFDatetime.FromDatetime(creation_time)) + + result[client_id] = snapshot + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadClientSnapshotHistory( + self, + client_id: str, + timerange: Optional[ + tuple[Optional[rdfvalue.RDFDatetime], Optional[rdfvalue.RDFDatetime]] + ] = None, + ) -> Sequence[objects_pb2.ClientSnapshot]: + """Reads the full history for a particular client.""" + result = [] + + query = """ + SELECT ss.CreationTime, ss.Snapshot, su.Startup + FROM ClientSnapshots AS ss, ClientStartups AS su + WHERE ss.ClientId = {client_id} + AND ss.ClientId = su.ClientId + AND ss.CreationTime = su.CreationTime + """ + params = {"client_id": client_id} + + if timerange is not None: + time_since, time_until = timerange + if time_since is not None: + query += " AND ss.CreationTime >= {time_since}" + params["time_since"] = time_since.AsDatetime() + if time_until is not None: + query += " AND ss.CreationTime <= {time_until}" + params["time_until"] = time_until.AsDatetime() + + query += " ORDER BY ss.CreationTime DESC" + + for row in self.db.ParamQuery( + query, params=params, txn_tag="ReadClientSnapshotHistory" + ): + creation_time, snapshot_bytes, startup_bytes = row + + snapshot = objects_pb2.ClientSnapshot() + snapshot.ParseFromString(snapshot_bytes) + snapshot.startup_info.ParseFromString(startup_bytes) + snapshot.timestamp = int(rdfvalue.RDFDatetime.FromDatetime(creation_time)) + + result.append(snapshot) + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteClientStartupInfo( + self, + client_id: str, + startup: jobs_pb2.StartupInfo, + ) -> None: + """Writes a new client startup record.""" + + def Mutation(mut: spanner_utils.Mutation) -> None: + mut.update(table="Clients", + columns=["ClientId", "LastStartupTime"], + values=[[client_id, spanner_lib.COMMIT_TIMESTAMP]]) + + mut.insert(table="ClientStartups", + columns=["ClientId", "CreationTime", "Startup"], + values=[[client_id ,spanner_lib.COMMIT_TIMESTAMP, startup]]) + + try: + self.db.Mutate(Mutation, txn_tag="WriteClientStartupInfo") + except NotFound as error: + raise db_lib.UnknownClientError(client_id, cause=error) + + @db_utils.CallLogged + 
@db_utils.CallAccounted + def WriteClientRRGStartup( + self, + client_id: str, + startup: rrg_startup_pb2.Startup, + ) -> None: + """Writes a new RRG startup entry to the database.""" + + def Mutation(mut: spanner_utils.Mutation) -> None: + mut.update( + table="Clients", + columns=("ClientId", "LastRRGStartupTime"), + values=[(client_id, spanner_lib.COMMIT_TIMESTAMP)] + ) + + mut.insert( + table="ClientRRGStartups", + columns=("ClientId", "CreationTime", "Startup"), + values=[(client_id, spanner_lib.COMMIT_TIMESTAMP, startup)] + ) + + try: + self.db.Mutate(Mutation) + except NotFound as error: + raise db_lib.UnknownClientError(client_id, cause=error) + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadClientRRGStartup( + self, + client_id: str, + ) -> Optional[rrg_startup_pb2.Startup]: + """Reads the latest RRG startup entry for the given client.""" + query = """ + SELECT su.Startup + FROM Clients AS c + LEFT JOIN ClientRRGStartups AS su + ON c.ClientId = su.ClientId + AND c.LastRRGStartupTime = su.CreationTime + WHERE c.ClientId = {client_id} + """ + params = { + "client_id": client_id, + } + + try: + (startup_bytes,) = self.db.ParamQuerySingle( + query, params, txn_tag="ReadClientRRGStartup" + ) + except NotFound: + raise db_lib.UnknownClientError(client_id) # pylint: disable=raise-missing-from + + if startup_bytes is None: + return None + + return rrg_startup_pb2.Startup.FromString(startup_bytes) + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadClientStartupInfo( + self, + client_id: str, + ) -> Optional[jobs_pb2.StartupInfo]: + """Reads the latest client startup record for a single client.""" + + query = """ + SELECT su.CreationTime, su.Startup + FROM Clients AS c, ClientStartups AS su + WHERE c.ClientId = {client_id} + AND c.ClientId = su.ClientId + AND c.LastStartupTime = su.CreationTime + """ + params = {"client_id": client_id} + + try: + (creation_time, startup_bytes) = self.db.ParamQuerySingle( + query, params, txn_tag="ReadClientStartupInfo" + ) + except NotFound: + return None + + startup = jobs_pb2.StartupInfo() + startup.ParseFromString(startup_bytes) + startup.timestamp = int(rdfvalue.RDFDatetime.FromDatetime(creation_time)) + return startup + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteClientCrashInfo( + self, + client_id: str, + crash: jobs_pb2.ClientCrash, + ) -> None: + """Writes a new client crash record.""" + + def Mutation(mut: spanner_utils.Mutation) -> None: + mut.update(table="Clients", + columns=["ClientId", "LastCrashTime"], + values=[[client_id, spanner_lib.COMMIT_TIMESTAMP]]) + + mut.insert(table="ClientCrashes", + columns=["ClientId", "CreationTime", "Crash"], + values=[[client_id, spanner_lib.COMMIT_TIMESTAMP, crash]]) + + try: + self.db.Mutate(Mutation, txn_tag="WriteClientCrashInfo") + except NotFound as error: + raise db_lib.UnknownClientError(client_id, cause=error) + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadClientCrashInfo( + self, + client_id: str, + ) -> Optional[jobs_pb2.ClientCrash]: + """Reads the latest client crash record for a single client.""" + + query = """ + SELECT cr.CreationTime, cr.Crash + FROM Clients AS c, ClientCrashes AS cr + WHERE c.ClientId = {client_id} + AND c.ClientId = cr.ClientId + AND c.LastCrashTime = cr.CreationTime + """ + params = {"client_id": client_id} + + try: + (creation_time, crash_bytes) = self.db.ParamQuerySingle( + query, params, txn_tag="ReadClientCrashInfo" + ) + except NotFound: + return None + + crash = jobs_pb2.ClientCrash() + crash = 
crash.FromString(crash_bytes) + crash.timestamp = int(rdfvalue.RDFDatetime.FromDatetime(creation_time)) + + return crash + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadClientCrashInfoHistory( + self, + client_id: str, + ) -> Sequence[jobs_pb2.ClientCrash]: + """Reads the full crash history for a particular client.""" + result = [] + + query = """ + SELECT cr.CreationTime, cr.Crash + FROM ClientCrashes AS cr + WHERE cr.ClientId = {client_id} + ORDER BY cr.CreationTime DESC + """ + params = {"client_id": client_id} + + for row in self.db.ParamQuery( + query, params, txn_tag="ReadClientCrashInfoHistory" + ): + creation_time, crash_bytes = row + + crash = jobs_pb2.ClientCrash() + crash.ParseFromString(crash_bytes) + crash.timestamp = int(rdfvalue.RDFDatetime.FromDatetime(creation_time)) + + result.append(crash) + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def MultiReadClientFullInfo( + self, + client_ids: Collection[str], + min_last_ping: Optional[rdfvalue.RDFDatetime] = None, + ) -> Mapping[str, objects_pb2.ClientFullInfo]: + """Reads full client information for a list of clients.""" + if not client_ids: + return {} + + result = {} + + query = """ + SELECT c.ClientId, + c.Certificate, + c.FleetspeakValidationInfo, + c.FirstSeenTime, c.LastPingTime, c.LastForemanTime, + c.LastSnapshotTime, c.LastStartupTime, c.LastCrashTime, + ss_last.Snapshot, + su_last.Startup, su_last_snapshot.Startup, + rrg_su_last.Startup, + ARRAY(SELECT AS STRUCT + l.Owner, l.Label + FROM ClientLabels AS l + WHERE l.ClientId = c.ClientId) + FROM Clients AS c + LEFT JOIN ClientSnapshots AS ss_last + ON (c.ClientId = ss_last.ClientId) + AND (c.LastSnapshotTime = ss_last.CreationTime) + LEFT JOIN ClientStartups AS su_last + ON (c.ClientId = su_last.ClientId) + AND (c.LastStartupTime = su_last.CreationTime) + LEFT JOIN ClientRrgStartups AS rrg_su_last + ON (c.ClientId = rrg_su_last.ClientId) + AND (c.LastRrgStartupTime = rrg_su_last.CreationTime) + LEFT JOIN ClientStartups AS su_last_snapshot + ON (c.ClientId = su_last_snapshot.ClientId) + AND (c.LastSnapshotTime = su_last_snapshot.CreationTime) + WHERE c.ClientId IN UNNEST({client_ids}) + """ + params = {"client_ids": client_ids} + + if min_last_ping is not None: + query += " AND c.LastPingTime >= {min_last_ping_time}" + params["min_last_ping_time"] = min_last_ping.AsDatetime() + + for row in self.db.ParamQuery( + query, params, txn_tag="MultiReadClientFullInfo" + ): + client_id, certificate_bytes, *row = row + fleetspeak_validation_info_bytes, *row = row + first_seen_time, last_ping_time, *row = row + last_foreman_time, *row = row + last_snapshot_time, last_startup_time, last_crash_time, *row = row + last_snapshot_bytes, *row = row + last_startup_bytes, last_snapshot_startup_bytes, *row = row + last_rrg_startup_bytes, *row = row + (label_rows,) = row + + info = objects_pb2.ClientFullInfo() + + if last_startup_bytes is not None: + info.last_startup_info.ParseFromString(last_startup_bytes) + info.last_startup_info.timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(last_startup_time) + ) + + if last_snapshot_bytes is not None: + info.last_snapshot.ParseFromString(last_snapshot_bytes) + info.last_snapshot.timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(last_snapshot_time) + ) + + if last_snapshot_startup_bytes is not None: + info.last_snapshot.startup_info.ParseFromString( + last_snapshot_startup_bytes + ) + info.last_snapshot.startup_info.timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(last_snapshot_time) + ) + + if 
certificate_bytes is not None: + info.metadata.certificate = certificate_bytes + if fleetspeak_validation_info_bytes is not None: + info.metadata.last_fleetspeak_validation_info.ParseFromString( + fleetspeak_validation_info_bytes + ) + if first_seen_time is not None: + info.metadata.first_seen = int( + rdfvalue.RDFDatetime.FromDatetime(first_seen_time) + ) + if last_ping_time is not None: + info.metadata.ping = int( + rdfvalue.RDFDatetime.FromDatetime(last_ping_time) + ) + if last_foreman_time is not None: + info.metadata.last_foreman_time = int( + rdfvalue.RDFDatetime.FromDatetime(last_foreman_time) + ) + if last_startup_time is not None: + info.metadata.startup_info_timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(last_startup_time) + ) + if last_crash_time is not None: + info.metadata.last_crash_timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(last_crash_time) + ) + + info.last_snapshot.client_id = client_id + + for owner, label in label_rows: + info.labels.add(owner=owner, name=label) + + if last_rrg_startup_bytes is not None: + info.last_rrg_startup.ParseFromString(last_rrg_startup_bytes) + + result[client_id] = info + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadClientLastPings( + self, + min_last_ping: Optional[rdfvalue.RDFDatetime] = None, + max_last_ping: Optional[rdfvalue.RDFDatetime] = None, + batch_size: int = 0, + ) -> Iterator[Mapping[str, Optional[rdfvalue.RDFDatetime]]]: + """Yields dicts of last-ping timestamps for clients in the DB.""" + if not batch_size: + batch_size = db_lib.CLIENT_IDS_BATCH_SIZE + + last_client_id = "0" + + while True: + client_last_pings_batch = self._ReadClientLastPingsBatch( + count=(batch_size or db_lib.CLIENT_IDS_BATCH_SIZE), + last_client_id=last_client_id, + min_last_ping_time=min_last_ping, + max_last_ping_time=max_last_ping, + ) + + if client_last_pings_batch: + yield client_last_pings_batch + if len(client_last_pings_batch) < batch_size: + break + + last_client_id = max(client_last_pings_batch.keys()) + + def _ReadClientLastPingsBatch( + self, + count: int, + last_client_id: str, + min_last_ping_time: Optional[rdfvalue.RDFDatetime], + max_last_ping_time: Optional[rdfvalue.RDFDatetime], + ) -> Mapping[str, Optional[rdfvalue.RDFDatetime]]: + """Reads a single batch of last client last ping times. + + Args: + count: The number of entries to read in the batch. + last_client_id: The identifier of the last client of the previous batch. + min_last_ping_time: An (optional) lower bound on the last ping time value. + max_last_ping_time: An (optional) upper bound on the last ping time value. + + Returns: + A mapping from client identifiers to client last ping times. 
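+      Clients whose LastPingTime column is NULL are mapped to None.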
+ """ + result = {} + + query = """ + SELECT c.ClientId, c.LastPingTime + FROM Clients AS c + WHERE c.ClientId > {last_client_id} + """ + params = {"last_client_id": last_client_id} + + if min_last_ping_time is not None: + query += " AND c.LastPingTime >= {min_last_ping_time}" + params["min_last_ping_time"] = min_last_ping_time.AsDatetime() + if max_last_ping_time is not None: + query += ( + " AND (c.LastPingTime IS NULL OR " + " c.LastPingTime <= {max_last_ping_time})" + ) + params["max_last_ping_time"] = max_last_ping_time.AsDatetime() + + query += """ + ORDER BY ClientId + LIMIT {count} + """ + params["count"] = count + + for client_id, last_ping_time in self.db.ParamQuery( + query, params, txn_tag="ReadClientLastPingsBatch" + ): + + if last_ping_time is not None: + result[client_id] = rdfvalue.RDFDatetime.FromDatetime(last_ping_time) + else: + result[client_id] = None + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteClient(self, client_id: str) -> None: + """Deletes a client with all associated metadata.""" + + def Transaction(txn) -> None: + # It looks like Spanner does not raise exception if we attempt to delete + # a non-existing row, so we have to verify row existence ourself. + keyrange = spanner_lib.KeyRange(start_closed=[client_id], end_closed=[client_id]) + keyset = spanner_lib.KeySet(ranges=[keyrange]) + try: + txn.read(table="Clients", keyset=keyset, columns=["ClientId"]).one() + except NotFound as error: + raise db_lib.UnknownClientError(client_id, cause=error) + + txn.delete(table="Clients", keyset=keyset) + + self.db.Transact(Transaction, txn_tag="DeleteClient") + + @db_utils.CallLogged + @db_utils.CallAccounted + def MultiAddClientKeywords( + self, + client_ids: Collection[str], + keywords: Collection[str], + ) -> None: + """Associates the provided keywords with the specified clients.""" + # Early return to avoid generating empty mutation. 
+    if not client_ids or not keywords:
+      return
+
+    def Mutation(mut: spanner_utils.Mutation) -> None:
+      for client_id in client_ids:
+        rows = []
+        for keyword in keywords:
+          row = [client_id, keyword, spanner_lib.COMMIT_TIMESTAMP]
+          rows.append(row)
+        columns = ["ClientId", "Keyword", "CreationTime"]
+        mut.insert_or_update(
+            table="ClientKeywords", columns=columns, values=rows
+        )
+
+    try:
+      self.db.Mutate(Mutation, txn_tag="MultiAddClientKeywords")
+    except NotFound as error:
+      raise db_lib.AtLeastOneUnknownClientError(client_ids) from error
+
+  @db_utils.CallLogged
+  @db_utils.CallAccounted
+  def ListClientsForKeywords(
+      self,
+      keywords: Collection[str],
+      start_time: Optional[rdfvalue.RDFDatetime] = None,
+  ) -> Mapping[str, Collection[str]]:
+    """Lists the clients associated with keywords."""
+    results = {keyword: [] for keyword in keywords}
+
+    query = """
+      SELECT k.Keyword, ARRAY_AGG(k.ClientId)
+      FROM ClientKeywords@{{FORCE_INDEX=ClientKeywordsByKeywordCreationTime}} AS k
+      WHERE k.Keyword IN UNNEST({keywords})
+    """
+    params = {
+        "keywords": list(keywords),
+    }
+
+    if start_time is not None:
+      query += " AND k.CreationTime >= {cutoff_time}"
+      params["cutoff_time"] = start_time.AsDatetime()
+
+    query += " GROUP BY k.Keyword"
+
+    for keyword, client_ids in self.db.ParamQuery(
+        query, params, txn_tag="ListClientsForKeywords"
+    ):
+      results[keyword].extend(client_ids)
+
+    return results
+
+  @db_utils.CallLogged
+  @db_utils.CallAccounted
+  def RemoveClientKeyword(self, client_id: str, keyword: str) -> None:
+    """Removes the association of a particular client to a keyword."""
+    self.db.Delete(
+        table="ClientKeywords",
+        key=(client_id, keyword),
+        txn_tag="RemoveClientKeyword",
+    )
+
+
+_EPOCH = datetime.datetime.fromtimestamp(0, datetime.timezone.utc)
+
+_DELETE_BATCH_SIZE = 5_000
diff --git a/grr/server/grr_response_server/databases/spanner_clients_test.py b/grr/server/grr_response_server/databases/spanner_clients_test.py
new file mode 100644
index 000000000..4611e75b2
--- /dev/null
+++ b/grr/server/grr_response_server/databases/spanner_clients_test.py
@@ -0,0 +1,33 @@
+from absl.testing import absltest
+
+from grr_response_server.databases import db
+from grr_response_server.databases import db_clients_test
+from grr_response_server.databases import db_test_utils
+from grr_response_server.databases import spanner_test_lib
+
+
+def setUpModule() -> None:
+  spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True)
+
+
+def tearDownModule() -> None:
+  spanner_test_lib.TearDown()
+
+
+class SpannerDatabaseClientsTest(
+    db_clients_test.DatabaseTestClientsMixin, spanner_test_lib.TestCase
+):
+  # Test methods are defined in the base mixin class.
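+  # The test below checks that adding labels on behalf of an unknown GRR user
+  # surfaces as db.UnknownGRRUserError.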
+ + def testLabelWriteToUnknownUser(self): + client_id = db_test_utils.InitializeClient(self.db) + + with self.assertRaises(db.UnknownGRRUserError) as ctx: + self.db.AddClientLabels(client_id, owner="foo", labels=["bar", "baz"]) + + self.assertIsInstance(ctx.exception, db.UnknownGRRUserError) + self.assertEqual(ctx.exception.username, "foo") + + +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/databases/spanner_cron_jobs.py b/grr/server/grr_response_server/databases/spanner_cron_jobs.py new file mode 100644 index 000000000..ab3177995 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_cron_jobs.py @@ -0,0 +1,725 @@ +#!/usr/bin/env python +"""A module with cronjobs methods of the Spanner backend.""" + +import datetime +from typing import Any, Mapping, Optional, Sequence + +from google.api_core.exceptions import NotFound +from google.cloud import spanner as spanner_lib + +from grr_response_core.lib import rdfvalue +from grr_response_core.lib import utils +from grr_response_proto import flows_pb2 +from grr_response_proto import jobs_pb2 +from grr_response_server.databases import db +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils + +_UNCHANGED = db.Database.UNCHANGED + + +class CronJobsMixin: + """A Spanner database mixin with implementation of cronjobs.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteCronJob(self, cronjob: flows_pb2.CronJob) -> None: + """Writes a cronjob to the database. + + Args: + cronjob: A flows_pb2.CronJob object. + """ + # We currently expect to reuse `created_at` if set. + rdf_created_at = rdfvalue.RDFDatetime().FromMicrosecondsSinceEpoch( + cronjob.created_at + ) + creation_time = rdf_created_at.AsDatetime() or spanner_lib.COMMIT_TIMESTAMP + + row = { + "JobId": cronjob.cron_job_id, + "Job": cronjob, + "Enabled": bool(cronjob.enabled), + "CreationTime": creation_time, + } + + self.db.InsertOrUpdate(table="CronJobs", row=row, txn_tag="WriteCronJob") + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadCronJobs( + self, cronjob_ids: Optional[Sequence[str]] = None + ) -> Sequence[flows_pb2.CronJob]: + """Reads all cronjobs from the database. + + Args: + cronjob_ids: A list of cronjob ids to read. If not set, returns all cron + jobs in the database. + + Returns: + A list of flows_pb2.CronJob objects. + + Raises: + UnknownCronJobError: A cron job for at least one of the given ids + does not exist. + """ + where_ids = "" + params = {} + if cronjob_ids: + where_ids = " WHERE cj.JobId IN UNNEST(@cronjob_ids)" + params["cronjob_ids"] = cronjob_ids + + def Transaction(txn) -> Sequence[flows_pb2.CronJob]: + return self._SelectCronJobsWith(txn, where_ids, params) + + res = self.db.Transact(Transaction, txn_tag="ReadCronJobs") + + if cronjob_ids and len(res) != len(cronjob_ids): + missing = set(cronjob_ids) - set([c.cron_job_id for c in res]) + raise db.UnknownCronJobError( + "CronJob(s) with id(s) %s not found." 
% missing + ) + return res + + @db_utils.CallLogged + @db_utils.CallAccounted + def UpdateCronJob( # pytype: disable=annotation-type-mismatch + self, + cronjob_id: str, + last_run_status: Optional[ + "flows_pb2.CronJobRun.CronJobRunStatus" + ] = _UNCHANGED, + last_run_time: Optional[rdfvalue.RDFDatetime] = _UNCHANGED, + current_run_id: Optional[str] = _UNCHANGED, + state: Optional[jobs_pb2.AttributedDict] = _UNCHANGED, + forced_run_requested: Optional[bool] = _UNCHANGED, + ): + """Updates run information for an existing cron job. + + Args: + cronjob_id: The id of the cron job to update. + last_run_status: A CronJobRunStatus object. + last_run_time: The last time a run was started for this cron job. + current_run_id: The id of the currently active run. + state: The state dict for stateful cron jobs. + forced_run_requested: A boolean indicating if a forced run is pending for + this job. + + Raises: + UnknownCronJobError: A cron job with the given id does not exist. + """ + row = {"JobId": cronjob_id} + if last_run_status is not _UNCHANGED: + row["LastRunStatus"] = int(last_run_status) + if last_run_time is not _UNCHANGED: + row["LastRunTime"] = ( + last_run_time.AsDatetime() if last_run_time else last_run_time + ) + if current_run_id is not _UNCHANGED: + row["CurrentRunId"] = current_run_id + if state is not _UNCHANGED: + row["State"] = state if state is not None else None + if forced_run_requested is not _UNCHANGED: + row["ForcedRunRequested"] = forced_run_requested + + try: + self.db.Update("CronJobs", row=row, txn_tag="UpdateCronJob") + except NotFound as error: + raise db.UnknownCronJobError(cronjob_id) from error + + @db_utils.CallLogged + @db_utils.CallAccounted + def EnableCronJob(self, cronjob_id: str) -> None: + """Enables a cronjob. + + Args: + cronjob_id: The id of the cron job to enable. + + Raises: + UnknownCronJobError: A cron job with the given id does not exist. + """ + row = { + "JobId": cronjob_id, + "Enabled": True, + } + try: + self.db.Update("CronJobs", row=row, txn_tag="EnableCronJob") + except NotFound as error: + raise db.UnknownCronJobError(cronjob_id) from error + + + @db_utils.CallLogged + @db_utils.CallAccounted + def DisableCronJob(self, cronjob_id: str) -> None: + """Deletes a cronjob along with all its runs. + + Args: + cronjob_id: The id of the cron job to delete. + + Raises: + UnknownCronJobError: A cron job with the given id does not exist. + """ + row = { + "JobId": cronjob_id, + "Enabled": False, + } + try: + self.db.Update("CronJobs", row=row, txn_tag="DisableCronJob") + except NotFound as error: + raise db.UnknownCronJobError(cronjob_id) from error + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteCronJob(self, cronjob_id: str) -> None: + """Deletes a cronjob along with all its runs. + + Args: + cronjob_id: The id of the cron job to delete. + + Raises: + UnknownCronJobError: A cron job with the given id does not exist. + """ + def Transaction(txn) -> None: + # Spanner does not raise if we attempt to delete a non-existing row so + # we check it exists ourselves. 
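+      # The existence check and the delete run in the same read-write
+      # transaction.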
+ keyset = spanner_lib.KeySet(keys=[[cronjob_id]]) + try: + txn.read(table="CronJobs", keyset=keyset, columns=["JobId"]).one() + except NotFound as error: + raise db.UnknownCronJobError(cronjob_id) from error + + txn.delete(table="CronJobs", keyset=keyset) + + self.db.Transact(Transaction, txn_tag="DeleteCronJob") + + + @db_utils.CallLogged + @db_utils.CallAccounted + def LeaseCronJobs( + self, + cronjob_ids: Optional[Sequence[str]] = None, + lease_time: Optional[rdfvalue.Duration] = None, + ) -> Sequence[flows_pb2.CronJob]: + """Leases all available cron jobs. + + Args: + cronjob_ids: A list of cronjob ids that should be leased. If None, all + available cronjobs will be leased. + lease_time: rdfvalue.Duration indicating how long the lease should be + valid. + + Returns: + A list of cronjobs.CronJob objects that were leased. + """ + + # We can't simply Update the rows because `UPDATE ... SET` will not return + # the affected rows. So, this transaction is broken up in three parts: + # 1. Identify the rows that will be updated + # 2. Update these rows with the new lease information + # 3. Read back the affected rows and return them + def Transaction(txn) -> Sequence[flows_pb2.CronJob]: + now = rdfvalue.RDFDatetime.Now() + lease_end_time = now + lease_time + lease_owner = utils.ProcessIdString() + + # --------------------------------------------------------------------- + # Query IDs to be updated on this transaction + # --------------------------------------------------------------------- + query_ids_to_update = """ + SELECT cj.JobId + FROM CronJobs as cj + WHERE (cj.LeaseEndTime IS NULL OR cj.LeaseEndTime < @now) + """ + params_ids_to_update = {"now": now.AsDatetime()} + + if cronjob_ids: + query_ids_to_update += "AND cj.JobId IN UNNEST(@cronjob_ids)" + params_ids_to_update["cronjob_ids"] = cronjob_ids + + response = txn.execute_sql(sql=query_ids_to_update, params=params_ids_to_update, + request_options={"request_tag": "LeaseCronJobs:CronJobs:execute_sql"}) + + ids_to_update = [] + for (job_id,) in response: + ids_to_update.append(job_id) + + if not ids_to_update: + return [] + + # --------------------------------------------------------------------- + # Effectively update them with this process as owner + # --------------------------------------------------------------------- + update_query = """ + UPDATE CronJobs as cj + SET cj.LeaseEndTime = @lease_end_time, cj.LeaseOwner = @lease_owner + WHERE cj.JobId IN UNNEST(@ids_to_update) + """ + update_params = { + "lease_end_time": lease_end_time.AsDatetime(), + "lease_owner": lease_owner, + "ids_to_update": ids_to_update, + } + + txn.execute_update(update_query, update_params, + request_options={"request_tag": "LeaseCronJobs:CronJobs:execute_update"}) + + # --------------------------------------------------------------------- + # Query (and return) jobs that were updated + # --------------------------------------------------------------------- + where_updated = """ + WHERE (cj.LeaseOwner = @lease_owner) + AND cj.JobId IN UNNEST(@updated_ids) + """ + updated_params = { + "lease_owner": lease_owner, + "updated_ids": ids_to_update, + } + + return self._SelectCronJobsWith(txn, where_updated, updated_params) + + return self.db.Transact(Transaction, txn_tag="LeaseCronJobs") + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReturnLeasedCronJobs(self, jobs: Sequence[flows_pb2.CronJob]) -> None: + """Makes leased cron jobs available for leasing again. + + Args: + jobs: A list of leased cronjobs. 
+ + Raises: + ValueError: If not all of the cronjobs are leased. + """ + if not jobs: + return + + # Identify jobs that are not lease (cannot be returned). If there are any + # leased jobs that need returning, then we'll go ahead and try to update + # them anyway. + unleased_jobs = [] + conditions = [] + jobs_to_update_args = {} + for i, job in enumerate(jobs): + if not job.leased_by or not job.leased_until: + unleased_jobs.append(job) + continue + + conditions.append( + "(cj.JobId=@job_%d AND " + "cj.LeaseEndTime=@ld_%d AND " + "cj.LeaseOwner=@lo_%d)" % (i, i, i) + ) + dt_leased_until = ( + rdfvalue.RDFDatetime() + .FromMicrosecondsSinceEpoch(job.leased_until) + .AsDatetime() + ) + jobs_to_update_args[i] = ( + job.cron_job_id, + dt_leased_until, + job.leased_by, + ) + + if not conditions: # all jobs are unleased. + raise ValueError("CronJobs to return are not leased: %s" % unleased_jobs) + + # We can't simply Update the rows because `UPDATE ... SET` will not return + # the affected rows. We need both the _already_ disabled jobs and the + # updated rows in order to raise the appropriate exceptions. + # 1. Identify the rows that need to be updated + # 2. Update the relevant rows with the new unlease information + # 3. Read back the affected rows and return them + def Transaction(txn) -> Sequence[flows_pb2.CronJob]: + # --------------------------------------------------------------------- + # Query IDs to be updated on this transaction + # --------------------------------------------------------------------- + query_job_ids_to_return = """ + SELECT cj.JobId + FROM CronJobs as cj + """ + params_job_ids_to_return = {} + query_job_ids_to_return += "WHERE" + " OR ".join(conditions) + for i, (job_id, ld, lo) in jobs_to_update_args.items(): + params_job_ids_to_return["job_%d" % i] = job_id + params_job_ids_to_return["ld_%d" % i] = ld + params_job_ids_to_return["lo_%d" % i] = lo + + response = txn.execute_sql( + query_job_ids_to_return, params_job_ids_to_return + ) + + ids_to_return = [] + for (job_id,) in response: + ids_to_return.append(job_id) + + if not ids_to_return: + return [] + + # --------------------------------------------------------------------- + # Effectively update them, removing owners + # --------------------------------------------------------------------- + update_query = """ + UPDATE CronJobs as cj + SET cj.LeaseEndTime = NULL, cj.LeaseOwner = NULL + WHERE cj.JobId IN UNNEST(@ids_to_return) + """ + update_params = { + "ids_to_return": ids_to_return, + } + + txn.execute_update(update_query, update_params, + request_options={"request_tag": "ReturnLeasedCronJobs:CronJobs:execute_update"}) + + # --------------------------------------------------------------------- + # Query (and return) jobs that were updated + # --------------------------------------------------------------------- + where_returned = """ + WHERE cj.JobId IN UNNEST(@updated_ids) + """ + returned_params = { + "updated_ids": ids_to_return, + } + + returned_jobs = self._SelectCronJobsWith( + txn, where_returned, returned_params + ) + + return returned_jobs + + returned_jobs = self.db.Transact( + Transaction, txn_tag="ReturnLeasedCronJobs" + ) + if unleased_jobs: + raise ValueError("CronJobs to return are not leased: %s" % unleased_jobs) + if len(returned_jobs) != len(jobs): + raise ValueError( + "%d cronjobs in %s could not be returned. 
Successfully returned: %s" + % ((len(jobs) - len(returned_jobs)), jobs, returned_jobs) + ) + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteCronJobRun(self, run_object: flows_pb2.CronJobRun) -> None: + """Stores a cron job run object in the database. + + Args: + run_object: A flows_pb2.CronJobRun object to store. + """ + # If created_at is set, we use that instead of the commit timestamp. + # This is important for overriding timestamps in testing. Ideally, we would + # have a better/easier way to mock CommitTimestamp instead. + if run_object.started_at: + creation_time = ( + rdfvalue.RDFDatetime() + .FromMicrosecondsSinceEpoch(run_object.started_at) + .AsDatetime() + ) + else: + creation_time = spanner_lib.COMMIT_TIMESTAMP + + row = { + "JobId": run_object.cron_job_id, + "RunId": run_object.run_id, + "CreationTime": creation_time, + "Payload": run_object, + "Status": int(run_object.status) or 0, + } + if run_object.finished_at: + row["FinishTime"] = ( + rdfvalue.RDFDatetime() + .FromMicrosecondsSinceEpoch(run_object.finished_at) + .AsDatetime() + ) + if run_object.log_message: + row["LogMessage"] = run_object.log_message + if run_object.backtrace: + row["Backtrace"] = run_object.backtrace + + try: + self.db.InsertOrUpdate( + table="CronJobRuns", row=row, txn_tag="WriteCronJobRun" + ) + except Exception as error: + if "Parent row for row [" in str(error): + # This error can be raised only when the parent cron job does not exist. + message = f"Cron job with id '{run_object.cron_job_id}' not found." + raise db.UnknownCronJobError(message) from error + else: + raise + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadCronJobRuns(self, job_id: str) -> Sequence[flows_pb2.CronJobRun]: + """Reads all cron job runs for a given job id. + + Args: + job_id: Runs will be returned for the job with the given id. + + Returns: + A list of flows_pb2.CronJobRun objects. + """ + cols = [ + "Payload", + "JobId", + "RunId", + "CreationTime", + "FinishTime", + "Status", + "LogMessage", + "Backtrace", + ] + rowrange = spanner_lib.KeyRange(start_closed=[job_id], end_closed=[job_id]) + rows = spanner_lib.KeySet(ranges=[rowrange]) + + res = [] + for row in self.db.ReadSet(table="CronJobRuns", + rows=rows, + cols=cols, + txn_tag="ReadCronJobRuns"): + res.append( + _CronJobRunFromRow( + job_run=row[0], + job_id=row[1], + run_id=row[2], + creation_time=row[3], + finish_time=row[4], + status=row[5], + log_message=row[6], + backtrace=row[7], + ) + ) + + return sorted(res, key=lambda run: run.started_at or 0, reverse=True) + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadCronJobRun(self, job_id: str, run_id: str) -> flows_pb2.CronJobRun: + """Reads a single cron job run from the db. + + Args: + job_id: The job_id of the run to be read. + run_id: The run_id of the run to be read. + + Returns: + An flows_pb2.CronJobRun object. + """ + cols = [ + "Payload", + "JobId", + "RunId", + "CreationTime", + "FinishTime", + "Status", + "LogMessage", + "Backtrace", + ] + try: + row = self.db.Read(table="CronJobRuns", + key=(job_id, run_id), + cols=cols, + txn_tag="ReadCronJobRun") + except NotFound as error: + raise db.UnknownCronJobRunError( + "Run with job id %s and run id %s not found." 
% (job_id, run_id) + ) from error + + return _CronJobRunFromRow( + job_run=row[0], + job_id=row[1], + run_id=row[2], + creation_time=row[3], + finish_time=row[4], + status=row[5], + log_message=row[6], + backtrace=row[7], + ) + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteOldCronJobRuns(self, cutoff_timestamp: rdfvalue.RDFDatetime) -> int: + """Deletes cron job runs that are older than cutoff_timestamp. + + Args: + cutoff_timestamp: This method deletes all runs that were started before + cutoff_timestamp. + + Returns: + The number of deleted runs. + """ + query = """ + SELECT cjr.JobId, cjr.RunId + FROM CronJobRuns AS cjr + WHERE cjr.CreationTime < @cutoff_timestamp + """ + params = {"cutoff_timestamp": cutoff_timestamp.AsDatetime()} + + def Transaction(txn) -> int: + rows = list(txn.execute_sql(sql=query, params=params)) + + for job_id, run_id in rows: + keyset = spanner_lib.KeySet(keys=[[job_id, run_id]]) + txn.delete( + table="CronJobRuns", + keyset=keyset, + ) + + return len(rows) + + return self.db.Transact(Transaction, txn_tag="DeleteOldCronJobRuns") + + def _SelectCronJobsWith( + self, + txn, + where_clause: str, + params: Mapping[str, Any], + ) -> Sequence[flows_pb2.CronJob]: + """Reads rows within the transaction and converts results into CronJobs. + + Args: + txn: a transaction that will param query and return a cursor for the rows. + where_clause: where clause for filtering the rows. + params: params to be applied to the where_clause. + + Returns: + A list of CronJobs read from the database. + """ + + query = """ + SELECT cj.Job, cj.JobId, cj.CreationTime, cj.Enabled, + cj.ForcedRunRequested, cj.LastRunStatus, cj.LastRunTime, + cj.CurrentRunId, cj.State, cj.LeaseEndTime, cj.LeaseOwner + FROM CronJobs as cj + """ + query += where_clause + + response = txn.execute_sql(sql=query, params=params) + + res = [] + for row in response: + ( + job, + job_id, + creation_time, + enabled, + forced_run_requested, + last_run_status, + last_run_time, + current_run_id, + state, + lease_end_time, + lease_owner, + ) = row + res.append( + _CronJobFromRow( + job=job, + job_id=job_id, + creation_time=creation_time, + enabled=enabled, + forced_run_requested=forced_run_requested, + last_run_status=last_run_status, + last_run_time=last_run_time, + current_run_id=current_run_id, + state=state, + lease_end_time=lease_end_time, + lease_owner=lease_owner, + ) + ) + + return res + + +def _CronJobFromRow( + job: Optional[bytes] = None, + job_id: Optional[str] = None, + creation_time: Optional[datetime.datetime] = None, + enabled: Optional[bool] = None, + forced_run_requested: Optional[bool] = None, + last_run_status: Optional[int] = None, + last_run_time: Optional[datetime.datetime] = None, + current_run_id: Optional[str] = None, + state: Optional[bytes] = None, + lease_end_time: Optional[datetime.datetime] = None, + lease_owner: Optional[str] = None, +) -> flows_pb2.CronJob: + """Creates a CronJob object from a database result row.""" + + if job is not None: + parsed = flows_pb2.CronJob() + parsed.ParseFromString(job) + job = parsed + else: + job = flows_pb2.CronJob( + created_at=rdfvalue.RDFDatetime.Now().AsMicrosecondsSinceEpoch(), + ) + + if job_id is not None: + job.cron_job_id = job_id + if current_run_id is not None: + job.current_run_id = current_run_id + if enabled is not None: + job.enabled = enabled + if forced_run_requested is not None: + job.forced_run_requested = forced_run_requested + if last_run_status is not None: + job.last_run_status = last_run_status + if last_run_time is 
not None: + job.last_run_time = rdfvalue.RDFDatetime.FromDatetime( + last_run_time + ).AsMicrosecondsSinceEpoch() + if state is not None: + read_state = jobs_pb2.AttributedDict() + read_state.ParseFromString(state) + job.state.CopyFrom(read_state) + if creation_time is not None: + job.created_at = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + if lease_end_time is not None: + job.leased_until = rdfvalue.RDFDatetime.FromDatetime( + lease_end_time + ).AsMicrosecondsSinceEpoch() + if lease_owner is not None: + job.leased_by = lease_owner + + return job + +def _CronJobRunFromRow( + job_run: Optional[bytes] = None, + job_id: Optional[str] = None, + run_id: Optional[str] = None, + creation_time: Optional[datetime.datetime] = None, + finish_time: Optional[datetime.datetime] = None, + status: Optional[int] = None, + log_message: Optional[str] = None, + backtrace: Optional[str] = None, +) -> flows_pb2.CronJobRun: + """Creates a CronJobRun object from a database result row.""" + + if job_run is not None: + parsed = flows_pb2.CronJobRun() + parsed.ParseFromString(job_run) + job_run = parsed + else: + job_run = flows_pb2.CronJobRun() + + if job_id is not None: + job_run.cron_job_id = job_id + if run_id is not None: + job_run.run_id = run_id + if creation_time is not None: + job_run.created_at = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + if finish_time is not None: + job_run.finished_at = rdfvalue.RDFDatetime.FromDatetime( + finish_time + ).AsMicrosecondsSinceEpoch() + if status is not None: + job_run.status = status + if log_message is not None: + job_run.log_message = log_message + if backtrace is not None: + job_run.backtrace = backtrace + + return job_run diff --git a/grr/server/grr_response_server/databases/spanner_cron_jobs_test.py b/grr/server/grr_response_server/databases/spanner_cron_jobs_test.py new file mode 100644 index 000000000..c033b6a20 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_cron_jobs_test.py @@ -0,0 +1,22 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_cronjob_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseCronJobsTest( + db_cronjob_test.DatabaseTestCronJobMixin, spanner_test_lib.TestCase +): + pass # Test methods are defined in the base mixin class. 
+ + +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/databases/spanner_events.py b/grr/server/grr_response_server/databases/spanner_events.py new file mode 100644 index 000000000..1a64d8b22 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_events.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python +"""A library with audit events methods of Spanner database implementation.""" + +from typing import Dict, List, Optional, Tuple + +from google.cloud import spanner as spanner_lib + +from grr_response_core.lib import rdfvalue +from grr_response_proto import objects_pb2 +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils + + +class EventsMixin: + """A Spanner database mixin with implementation of audit events methods.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadAPIAuditEntries( + self, + username: Optional[str] = None, + router_method_names: Optional[List[str]] = None, + min_timestamp: Optional[rdfvalue.RDFDatetime] = None, + max_timestamp: Optional[rdfvalue.RDFDatetime] = None, + ) -> List[objects_pb2.APIAuditEntry]: + """Returns audit entries stored in the database.""" + query = """ + SELECT + a.Username, + a.CreationTime, + a.HttpRequestPath, + a.RouterMethodName, + a.ResponseCode + FROM ApiAuditEntry AS a + """ + + params = {} + conditions = [] + + if username is not None: + conditions.append("a.Username = {username}") + params["username"] = username + + if router_method_names: + param_placeholders = ", ".join([f"{{rmn{i}}}" for i in range(len(router_method_names))]) + for i, rmn in enumerate(router_method_names): + param_name = f"rmn{i}" + params[param_name] = rmn + conditions.append(f"""a.RouterMethodName IN ({param_placeholders})""") + + if min_timestamp is not None: + conditions.append("a.CreationTime >= {min_timestamp}") + params["min_timestamp"] = min_timestamp.AsDatetime() + + if max_timestamp is not None: + conditions.append("a.CreationTime <= {max_timestamp}") + params["max_timestamp"] = max_timestamp.AsDatetime() + + if conditions: + query += " WHERE " + " AND ".join(conditions) + + result = [] + for ( + username, + ts, + http_request_path, + router_method_name, + response_code, + ) in self.db.ParamQuery(query, params, txn_tag="ReadAPIAuditEntries"): + result.append( + objects_pb2.APIAuditEntry( + username=username, + timestamp=rdfvalue.RDFDatetime.FromDatetime( + ts + ).AsMicrosecondsSinceEpoch(), + http_request_path=http_request_path, + router_method_name=router_method_name, + response_code=response_code, + ) + ) + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountAPIAuditEntriesByUserAndDay( + self, + min_timestamp: Optional[rdfvalue.RDFDatetime] = None, + max_timestamp: Optional[rdfvalue.RDFDatetime] = None, + ) -> Dict[Tuple[str, rdfvalue.RDFDatetime], int]: + """Returns audit entry counts grouped by user and calendar day.""" + query = """ + SELECT + a.Username, + TIMESTAMP_TRUNC(a.CreationTime, DAY, "UTC") AS day, + COUNT(*) + FROM ApiAuditEntry AS a + """ + + params = {} + conditions = [] + + if min_timestamp is not None: + conditions.append("a.CreationTime >= {min_timestamp}") + params["min_timestamp"] = min_timestamp.AsDatetime() + + if max_timestamp is not None: + conditions.append("a.CreationTime <= {max_timestamp}") + params["max_timestamp"] = max_timestamp.AsDatetime() + + if conditions: + query += " WHERE " + " AND ".join(conditions) + + query += " GROUP BY a.Username, day" + + result 
= {} + for username, day, count in self.db.ParamQuery( + query, params, txn_tag="CountAPIAuditEntriesByUserAndDay" + ): + result[(username, rdfvalue.RDFDatetime.FromDatetime(day))] = count + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteAPIAuditEntry(self, entry: objects_pb2.APIAuditEntry): + """Writes an audit entry to the database.""" + row = { + "HttpRequestPath": entry.http_request_path, + "RouterMethodName": entry.router_method_name, + "Username": entry.username, + "ResponseCode": entry.response_code, + } + + if not entry.HasField("timestamp"): + row["CreationTime"] = spanner_lib.COMMIT_TIMESTAMP + else: + row["CreationTime"] = rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + entry.timestamp + ).AsDatetime() + + self.db.InsertOrUpdate( + table="ApiAuditEntry", row=row, txn_tag="WriteAPIAuditEntry" + ) diff --git a/grr/server/grr_response_server/databases/spanner_events_test.py b/grr/server/grr_response_server/databases/spanner_events_test.py new file mode 100644 index 000000000..db21a0582 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_events_test.py @@ -0,0 +1,22 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_events_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseEventsTest( + db_events_test.DatabaseTestEventsMixin, spanner_test_lib.TestCase +): + pass # Test methods are defined in the base mixin class. + + +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/databases/spanner_flows.py b/grr/server/grr_response_server/databases/spanner_flows.py new file mode 100644 index 000000000..35852a7c8 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_flows.py @@ -0,0 +1,2514 @@ +#!/usr/bin/env python +"""A module with flow methods of the Spanner database implementation.""" + +from collections.abc import Callable, Collection, Iterable, Mapping, Sequence +import dataclasses +import logging +import threading +import time +from typing import Any, Optional, Union +import uuid + +from google.api_core.exceptions import AlreadyExists, NotFound +from google.cloud import spanner as spanner_lib +from google.cloud.spanner_v1 import param_types + +from grr_response_core.lib import rdfvalue +from grr_response_core.lib import utils +from grr_response_core.lib.util import collection +from grr_response_core.stats import metrics +from grr_response_proto import flows_pb2 +from grr_response_proto import jobs_pb2 +from grr_response_proto import objects_pb2 +from grr_response_proto import rrg_pb2 +from grr_response_server.databases import db +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils +from grr_response_server.models import hunts as models_hunts + + +SPANNER_DELETE_FLOW_REQUESTS_FAILURES = metrics.Counter( + name="spanner_delete_flow_requests_failures" +) + +_MESSAGE_HANDLER_MAX_KEEPALIVE_SECONDS = 300 +_MESSAGE_HANDLER_MAX_ACTIVE_CALLBACKS = 20 + +_MILLISECONDS = 1000 +_SECONDS = 1000 * _MILLISECONDS + +@dataclasses.dataclass(frozen=True) +class _FlowKey: + """Unique key identifying a flow in helper methods.""" + + client_id: str + flow_id: str + + +@dataclasses.dataclass(frozen=True) +class _RequestKey: + """Unique key identifying a flow request in helper methods.""" + + client_id: str + flow_id: 
str + request_id: int + + +@dataclasses.dataclass(frozen=True) +class _ResponseKey: + """Unique key identifying a flow response in helper methods.""" + + client_id: str + flow_id: str + request_id: int + response_id: int + + +_UNCHANGED = db.Database.UNCHANGED +_UNCHANGED_TYPE = db.Database.UNCHANGED_TYPE + + +def _BuildReadFlowResultsErrorsConditions( + table_name: str, + client_id: str, + flow_id: str, + offset: int, + count: int, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + with_substring: Optional[str] = None, +) -> tuple[str, Mapping[str, Any], Mapping[str, Any]]: + """Builds query string and params for results/errors reading queries.""" + params = {} + param_type = {} + + query = f""" + SELECT t.Payload, t.RdfType, t.CreationTime, t.Tag, t.HuntId + FROM {table_name} AS t + """ + + query += """ + WHERE t.ClientId = {client_id} AND t.FlowId = {flow_id} + """ + + params["client_id"] = client_id + params["flow_id"] = flow_id + + if with_tag is not None: + query += " AND t.Tag = {tag} " + params["tag"] = with_tag + param_type["tag"] = param_types.STRING + + if with_type is not None: + query += " AND t.RdfType = {type}" + params["type"] = with_type + param_type["type"] = param_types.STRING + + if with_substring is not None: + query += """ + AND STRPOS(SAFE_CONVERT_BYTES_TO_STRING(t.Payload.value), {substring}) != 0 + """ + params["substring"] = with_substring + param_type["substring"] = param_types.STRING + + query += """ + ORDER BY t.CreationTime ASC LIMIT {count} OFFSET {offset} + """ + params["offset"] = offset + params["count"] = count + + return query, params, param_type + + +def _BuildCountFlowResultsErrorsConditions( + table_name: str, + client_id: str, + flow_id: str, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, +) -> tuple[str, Mapping[str, Any], Mapping[str, Any]]: + """Builds query string and params for count flow results/errors queries.""" + params = {} + param_type = {} + + query = f""" + SELECT COUNT(*) + FROM {table_name} AS t + """ + + query += """ + WHERE t.ClientId = {client_id} AND t.FlowId = {flow_id} + """ + + params["client_id"] = client_id + params["flow_id"] = flow_id + + if with_tag is not None: + query += " AND t.Tag = {tag} " + params["tag"] = with_tag + param_type["tag"] = param_types.STRING + + if with_type is not None: + query += " AND t.RdfType = {type}" + params["type"] = with_type + param_type["type"] = param_types.STRING + + return query, params, param_type + + +_READ_FLOW_OBJECT_COLS = ( + "LongFlowId", + "ParentFlowId", + "ParentHuntId", + "Creator", + "Name", + "State", + "CreationTime", + "UpdateTime", + "Crash", + "ProcessingWorker", + "ProcessingStartTime", + "ProcessingEndTime", + "NextRequestToProcess", + "Flow", +) + + +def _ParseReadFlowObjectRow( + client_id: str, + flow_id: str, + row: Mapping[str, Any], +) -> flows_pb2.Flow: + """Parses a row fetched with _READ_FLOW_OBJECT_COLS.""" + result = flows_pb2.Flow() + result.ParseFromString(row[13]) + + creation_time = rdfvalue.RDFDatetime.FromDatetime(row[6]) + update_time = rdfvalue.RDFDatetime.FromDatetime(row[7]) + + # We treat column values as the source of truth for values, not the message + # in the database itself. At least this is what the F1 implementation does. 
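+  # Positional indices below follow the column order in _READ_FLOW_OBJECT_COLS.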
+ result.client_id = client_id + result.flow_id = flow_id + result.long_flow_id = row[0] + + if row[1] is not None: + result.parent_flow_id = row[1] + if row[2] is not None: + result.parent_hunt_id = row[2] + + if row[4] is not None: + result.flow_class_name = row[4] + if row[3] is not None: + result.creator = row[3] + if row[5] not in [None, flows_pb2.Flow.FlowState.UNSET]: + result.flow_state = row[5] + if row[12]: + result.next_request_to_process = int(row[12]) + + result.create_time = int(creation_time) + result.last_update_time = int(update_time) + + if row[8] is not None: + client_crash = jobs_pb2.ClientCrash() + client_crash.ParseFromString(row[8]) + result.client_crash_info.CopyFrom(client_crash) + + result.ClearField("processing_on") + if row[9] is not None: + result.processing_on = row[9] + result.ClearField("processing_since") + if row[10] is not None: + result.processing_since = int( + rdfvalue.RDFDatetime.FromDatetime(row[10]) + ) + result.ClearField("processing_deadline") + if row[11] is not None: + result.processing_deadline = int( + rdfvalue.RDFDatetime.FromDatetime(row[11]) + ) + + return result + + +class FlowsMixin: + """A Spanner database mixin with implementation of flow methods.""" + + db: spanner_utils.Database + _write_rows_batch_size: int + + handler_thread: threading.Thread + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowObject( + self, + flow_obj: flows_pb2.Flow, + allow_update: bool = True, + ) -> None: + """Writes a flow object to the database.""" + client_id = flow_obj.client_id + flow_id = flow_obj.flow_id + + row = { + "ClientId": client_id, + "FlowId": flow_id, + "LongFlowId": flow_obj.long_flow_id, + } + + if flow_obj.parent_flow_id: + row["ParentFlowId"] = flow_obj.parent_flow_id + if flow_obj.parent_hunt_id: + row["ParentHuntId"] = flow_obj.parent_hunt_id + + row["Creator"] = flow_obj.creator + row["Name"] = flow_obj.flow_class_name + row["State"] = int(flow_obj.flow_state) + row["NextRequestToProcess"] = flow_obj.next_request_to_process + + row["CreationTime"] = spanner_lib.COMMIT_TIMESTAMP + row["UpdateTime"] = spanner_lib.COMMIT_TIMESTAMP + + if flow_obj.HasField("client_crash_info"): + row["Crash"] = flow_obj.client_crash_info + + if flow_obj.HasField("processing_on"): + row["ProcessingWorker"] = flow_obj.processing_on + if flow_obj.HasField("processing_since"): + row["ProcessingStartTime"] = ( + rdfvalue.RDFDatetime() + .FromMicrosecondsSinceEpoch(flow_obj.processing_since) + .AsDatetime() + ) + if flow_obj.HasField("processing_deadline"): + row["ProcessingEndTime"] = ( + rdfvalue.RDFDatetime() + .FromMicrosecondsSinceEpoch(flow_obj.processing_deadline) + .AsDatetime() + ) + + row["Flow"] = flow_obj + + row["ReplyCount"] = int(flow_obj.num_replies_sent) + row["NetworkBytesSent"] = int(flow_obj.network_bytes_sent) + row["UserCpuTimeUsed"] = float(flow_obj.cpu_time_used.user_cpu_time) + row["SystemCpuTimeUsed"] = float(flow_obj.cpu_time_used.system_cpu_time) + + try: + if allow_update: + self.db.InsertOrUpdate( + table="Flows", row=row, txn_tag="WriteFlowObject_IOU" + ) + else: + self.db.Insert(table="Flows", row=row, txn_tag="WriteFlowObject_I") + except AlreadyExists as error: + raise db.FlowExistsError(client_id, flow_id) from error + except Exception as error: + if "Parent row for row [" in str(error): + raise db.UnknownClientError(client_id) + else: + raise + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadFlowObject( + self, + client_id: str, + flow_id: str, + ) -> flows_pb2.Flow: + """Reads a flow object from the 
database.""" + + try: + row = self.db.Read( + table="Flows", + key=[client_id, flow_id], + cols=_READ_FLOW_OBJECT_COLS, + txn_tag="ReadFlowObject" + ) + except NotFound as error: + raise db.UnknownFlowError(client_id, flow_id, cause=error) + + flow = _ParseReadFlowObjectRow(client_id, flow_id, row) + return flow + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadAllFlowObjects( + self, + client_id: Optional[str] = None, + parent_flow_id: Optional[str] = None, + min_create_time: Optional[rdfvalue.RDFDatetime] = None, + max_create_time: Optional[rdfvalue.RDFDatetime] = None, + include_child_flows: bool = True, + not_created_by: Optional[Iterable[str]] = None, + ) -> Sequence[flows_pb2.Flow]: + """Returns all flow objects that meet the specified conditions.""" + result = [] + + query = """ + SELECT f.ClientId, f.FlowId, f.LongFlowId, + f.ParentFlowId, f.ParentHuntId, + f.Creator, f.Name, f.State, + f.CreationTime, f.UpdateTime, + f.Crash, f.NextRequestToProcess, + f.Flow + FROM Flows AS f + """ + params = {} + + conds = [] + + if client_id is not None: + params["client_id"] = client_id + conds.append("f.ClientId = {client_id}") + if parent_flow_id is not None: + params["parent_flow_id"] = parent_flow_id + conds.append("f.ParentFlowId = {parent_flow_id}") + if min_create_time is not None: + params["min_creation_time"] = min_create_time.AsDatetime() + conds.append("f.CreationTime >= {min_creation_time}") + if max_create_time is not None: + params["max_creation_time"] = max_create_time.AsDatetime() + conds.append("f.CreationTime <= {max_creation_time}") + if not include_child_flows: + conds.append("f.ParentFlowId IS NULL") + if not_created_by is not None: + params["not_created_by"] = list(not_created_by) + conds.append("f.Creator NOT IN UNNEST({not_created_by})") + + if conds: + query += f" WHERE {' AND '.join(conds)}" + + for row in self.db.ParamQuery(query, params, txn_tag="ReadAllFlowObjects"): + client_id, flow_id, long_flow_id, *row = row + parent_flow_id, parent_hunt_id, *row = row + creator, name, state, *row = row + creation_time, update_time, *row = row + crash_bytes, next_request_to_process, flow_bytes = row + + flow = flows_pb2.Flow() + flow.ParseFromString(flow_bytes) + flow.client_id = client_id + flow.flow_id = flow_id + flow.long_flow_id = long_flow_id + flow.next_request_to_process = int(next_request_to_process) + + if parent_flow_id is not None: + flow.parent_flow_id = parent_flow_id + if parent_hunt_id is not None: + flow.parent_hunt_id = parent_hunt_id + + flow.creator = creator + flow.flow_state = state + flow.flow_class_name = name + + flow.create_time = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + flow.last_update_time = rdfvalue.RDFDatetime.FromDatetime( + update_time + ).AsMicrosecondsSinceEpoch() + + if crash_bytes is not None: + flow.client_crash_info.ParseFromString(crash_bytes) + + result.append(flow) + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def UpdateFlow( + self, + client_id: str, + flow_id: str, + flow_obj: Union[flows_pb2.Flow, _UNCHANGED_TYPE] = _UNCHANGED, + flow_state: Union[ + flows_pb2.Flow.FlowState.ValueType, _UNCHANGED_TYPE + ] = _UNCHANGED, + client_crash_info: Union[ + jobs_pb2.ClientCrash, _UNCHANGED_TYPE + ] = _UNCHANGED, + processing_on: Optional[Union[str, _UNCHANGED_TYPE]] = _UNCHANGED, + processing_since: Optional[ + Union[rdfvalue.RDFDatetime, _UNCHANGED_TYPE] + ] = _UNCHANGED, + processing_deadline: Optional[ + Union[rdfvalue.RDFDatetime, _UNCHANGED_TYPE] + ] = 
_UNCHANGED, + ) -> None: + """Updates flow objects in the database.""" + + row = { + "ClientId": client_id, + "FlowId": flow_id, + "UpdateTime": spanner_lib.COMMIT_TIMESTAMP, + } + + if isinstance(flow_obj, flows_pb2.Flow): + row["Flow"] = flow_obj + row["State"] = int(flow_obj.flow_state) + row["ReplyCount"] = int(flow_obj.num_replies_sent) + row["NetworkBytesSent"] = int(flow_obj.network_bytes_sent) + row["UserCpuTimeUsed"] = float(flow_obj.cpu_time_used.user_cpu_time) + row["SystemCpuTimeUsed"] = float(flow_obj.cpu_time_used.system_cpu_time) + if isinstance(flow_state, flows_pb2.Flow.FlowState.ValueType): + row["State"] = int(flow_state) + if isinstance(client_crash_info, jobs_pb2.ClientCrash): + row["Crash"] = client_crash_info + if ( + isinstance(processing_on, str) and processing_on is not db.UNCHANGED + ) or processing_on is None: + row["ProcessingWorker"] = processing_on + if isinstance(processing_since, rdfvalue.RDFDatetime): + row["ProcessingStartTime"] = processing_since.AsDatetime() + if processing_since is None: + row["ProcessingStartTime"] = None + if isinstance(processing_deadline, rdfvalue.RDFDatetime): + row["ProcessingEndTime"] = processing_deadline.AsDatetime() + if processing_deadline is None: + row["ProcessingEndTime"] = None + + try: + self.db.Update(table="Flows", row=row, txn_tag="UpdateFlow") + except NotFound as error: + raise db.UnknownFlowError(client_id, flow_id, cause=error) + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowResults(self, results: Sequence[flows_pb2.FlowResult]) -> None: + """Writes flow results for a given flow.""" + for batch in collection.Batch(results, self._write_rows_batch_size): + self._WriteFlowResults(batch) + + def _WriteFlowResults(self, results: Sequence[flows_pb2.FlowResult]) -> None: + """Writes flow errors for a given flow.""" + + def Mutation(mut) -> None: + rows = [] + columns = ["ClientId", "FlowId", "HuntId", "CreationTime", + "Tag", "RdfType", "Payload"] + for r in results: + rows.append([ + r.client_id, + r.flow_id, + r.hunt_id if r.hunt_id else "0", + rdfvalue.RDFDatetime.Now().AsDatetime(), + r.tag, + db_utils.TypeURLToRDFTypeName(r.payload.type_url), + r.payload, + ]) + mut.insert(table="FlowResults", columns=columns, values=rows) + + self.db.Mutate(Mutation, txn_tag="WriteFlowResults") + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowErrors(self, errors: Sequence[flows_pb2.FlowError]) -> None: + """Writes flow errors for a given flow.""" + for batch in collection.Batch(errors, self._write_rows_batch_size): + self._WriteFlowErrors(batch) + + def _WriteFlowErrors(self, errors: Sequence[flows_pb2.FlowError]) -> None: + """Writes flow errors for a given flow.""" + + def Mutation(mut) -> None: + rows = [] + columns = ["ClientId", "FlowId", "HuntId", + "CreationTime", "Payload", "RdfType", "Tag"] + for r in errors: + rows.append([r.client_id, + r.flow_id, + r.hunt_id if r.hunt_id else "0", + rdfvalue.RDFDatetime.Now().AsDatetime(), + r.payload, + db_utils.TypeURLToRDFTypeName(r.payload.type_url), + r.tag, + ]) + mut.insert(table="FlowErrors", columns=columns, values=rows) + + self.db.Mutate(Mutation, txn_tag="WriteFlowErrors") + + def ReadFlowResults( + self, + client_id: str, + flow_id: str, + offset: int, + count: int, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + with_substring: Optional[str] = None, + ) -> Sequence[flows_pb2.FlowResult]: + """Reads flow results of a given flow using given query options.""" + query, params, param_type = 
_BuildReadFlowResultsErrorsConditions( + "FlowResults", + client_id, + flow_id, + offset, + count, + with_tag, + with_type, + with_substring, + ) + + results = [] + for ( + payload_bytes, + _, + creation_time, + tag, + hunt_id, + ) in self.db.ParamQuery(query, params, param_type=param_type, + txn_tag="ReadFlowResults"): + result = flows_pb2.FlowResult( + client_id=client_id, + flow_id=flow_id, + timestamp=rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch(), + ) + result.payload.ParseFromString(payload_bytes) + + if hunt_id is not None: + result.hunt_id = hunt_id + + if tag is not None: + result.tag = tag + + results.append(result) + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadFlowErrors( + self, + client_id: str, + flow_id: str, + offset: int, + count: int, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + ) -> Sequence[flows_pb2.FlowError]: + """Reads flow errors of a given flow using given query options.""" + query, params, param_type = _BuildReadFlowResultsErrorsConditions( + "FlowErrors", + client_id, + flow_id, + offset, + count, + with_tag, + with_type, + None, + ) + + errors = [] + for ( + payload_bytes, + payload_type, + creation_time, + tag, + hunt_id, + ) in self.db.ParamQuery(query, params, param_type=param_type, + txn_tag="ReadFlowErrors"): + error = flows_pb2.FlowError( + client_id=client_id, + flow_id=flow_id, + timestamp=rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch(), + ) + + # for separation of concerns reasons, + # ReadFlowResults/ReadFlowErrors shouldn't do the payload type validation, + # they should be completely agnostic to what payloads get written/read + # to/from the database. Keeping this logic here temporarily + # to narrow the scope of the RDFProtoStruct->protos migration. 
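+ # Payload type names (RdfType column) that are still registered RDF classes
+ # are parsed directly below; anything else is preserved wrapped in a
+ # SerializedValueOfUnrecognizedType so the raw bytes are not lost.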
+ if payload_type in rdfvalue.RDFValue.classes: + error.payload.ParseFromString(payload_bytes) + else: + unrecognized = objects_pb2.SerializedValueOfUnrecognizedType( + type_name=payload_type, value=payload_bytes + ) + error.payload.Pack(unrecognized) + + if hunt_id is not None: + error.hunt_id = hunt_id + + if tag is not None: + error.tag = tag + + errors.append(error) + + return errors + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountFlowResults( + self, + client_id: str, + flow_id: str, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + ) -> int: + """Counts flow results of a given flow using given query options.""" + + query, params, param_type = _BuildCountFlowResultsErrorsConditions( + "FlowResults", client_id, flow_id, with_tag, with_type + ) + (count,) = self.db.ParamQuerySingle( + query, params, param_type=param_type, txn_tag="CountFlowResults" + ) + return count + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountFlowErrors( + self, + client_id: str, + flow_id: str, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + ) -> int: + """Counts flow errors of a given flow using given query options.""" + + query, params, param_type = _BuildCountFlowResultsErrorsConditions( + "FlowErrors", client_id, flow_id, with_tag, with_type + ) + (count,) = self.db.ParamQuerySingle( + query, params, param_type=param_type, txn_tag="CountFlowErrors" + ) + return count + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountFlowResultsByType( + self, client_id: str, flow_id: str + ) -> Mapping[str, int]: + """Returns counts of flow results grouped by result type.""" + + query = """ + SELECT r.RdfType, COUNT(*) + FROM FlowResults AS r + WHERE r.ClientId = {client_id} AND r.FlowId = {flow_id} + GROUP BY RdfType + """ + + params = { + "client_id": client_id, + "flow_id": flow_id, + } + + result = {} + for type_name, count in self.db.ParamQuery( + query, params, txn_tag="CountFlowResultsByType" + ): + result[type_name] = count + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountFlowErrorsByType( + self, client_id: str, flow_id: str + ) -> Mapping[str, int]: + """Returns counts of flow errors grouped by error type.""" + + query = """ + SELECT e.RdfType, COUNT(*) + FROM FlowErrors AS e + WHERE e.ClientId = {client_id} AND e.FlowId = {flow_id} + GROUP BY RdfType + """ + + params = { + "client_id": client_id, + "flow_id": flow_id, + } + + result = {} + for type_name, count in self.db.ParamQuery( + query, params, txn_tag="CountFlowErrorsByType" + ): + result[type_name] = count + + return result + + def _WriteFlowProcessingRequests( + self, + requests: Iterable[flows_pb2.FlowProcessingRequest], + txn + ) -> None: + """Writes a list of FlowProcessingRequests.""" + + columns = [ + "RequestId", + "ClientId", + "FlowId", + "CreationTime", + "Payload", + "DeliveryTime" + ] + rows = [] + for request in requests: + row = [ + str(uuid.uuid4()), + request.client_id, + request.flow_id, + spanner_lib.COMMIT_TIMESTAMP, + request, + rdfvalue.RDFDatetime( + request.delivery_time + ).AsDatetime() + ] + rows.append(row) + + txn.insert(table="FlowProcessingRequests", columns=columns, values=rows) + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowProcessingRequests( + self, + requests: Sequence[flows_pb2.FlowProcessingRequest], + ) -> None: + """Writes a list of flow processing requests to the database.""" + + def Txn(txn) -> None: + self._WriteFlowProcessingRequests(requests, txn) + + self.db.Transact(Txn, 
txn_tag="WriteFlowProcessingRequests") + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadFlowProcessingRequests( + self, + ) -> Sequence[flows_pb2.FlowProcessingRequest]: + """Reads all flow processing requests from the database.""" + query = """ + SELECT fpr.Payload, fpr.CreationTime FROM FlowProcessingRequests AS fpr + """ + results = [] + for payload, creation_time in self.db.ParamQuery(query, {}): + req = flows_pb2.FlowProcessingRequest() + req.ParseFromString(payload) + req.creation_time = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + results.append(req) + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def AckFlowProcessingRequests( + self, requests: Iterable[flows_pb2.FlowProcessingRequest] + ) -> None: + """Deletes a list of flow processing requests from the database.""" + def Txn(txn) -> None: + keys = [] + for request in requests: + creation_time = rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + request.creation_time + ).AsDatetime() + keys.append([request.client_id, request.flow_id, creation_time]) + keyset = spanner_lib.KeySet(keys=keys) + txn.delete(table="FlowProcessingRequests", keyset=keyset) + + self.db.Transact(Txn, txn_tag="AckFlowProcessingRequests") + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteAllFlowProcessingRequests(self) -> None: + """Deletes all flow processing requests from the database.""" + + def Txn(txn) -> None: + keyset = spanner_lib.KeySet(all_=True) + txn.delete(table="FlowProcessingRequests", keyset=keyset) + + self.db.Transact(Txn, txn_tag="DeleteAllFlowProcessingRequests") + + @db_utils.CallLogged + @db_utils.CallAccounted + def _LeaseFlowProcessingRequests( + self, limit: int + ) -> Sequence[flows_pb2.FlowProcessingRequest]: + """Leases a number of flow processing requests.""" + now = rdfvalue.RDFDatetime.Now() + expiry = now + rdfvalue.Duration.From(10, rdfvalue.MINUTES) + + def Txn(txn) -> None: + params = { + "limit": limit, + "now": now.AsDatetime() + } + param_type = { + "limit": param_types.INT64, + "now": param_types.TIMESTAMP + } + requests = txn.execute_sql( + "SELECT RequestId, CreationTime, Payload " + "FROM FlowProcessingRequests " + "WHERE " + " (DeliveryTime IS NULL OR DeliveryTime <= @now) AND " + " (LeasedUntil IS NULL OR LeasedUntil < @now) " + "LIMIT @limit", + params=params, + param_types=param_type) + + res = [] + request_ids = [] + for request_id, creation_time, request in requests: + req = flows_pb2.FlowProcessingRequest() + req.ParseFromString(request) + req.creation_time = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + res.append(req) + request_ids.append(request_id) + + query = ( + "UPDATE FlowProcessingRequests " + "SET LeasedUntil=@leased_until, LeasedBy=@leased_by " + "WHERE RequestId IN UNNEST(@request_ids)" + ) + params = { + "request_ids": request_ids, + "leased_by": utils.ProcessIdString(), + "leased_until": expiry.AsDatetime() + } + param_type = { + "request_ids": param_types.Array(param_types.STRING), + "leased_by": param_types.STRING, + "leased_until": param_types.TIMESTAMP + } + txn.execute_update(query, params, param_type) + + return res + + return self.db.Transact(Txn, txn_tag="_LeaseFlowProcessingRequests") + + _FLOW_REQUEST_POLL_TIME_SECS = 3 + + def _FlowProcessingRequestHandlerLoop( + self, handler: Callable[[flows_pb2.FlowProcessingRequest], None] + ) -> None: + """The main loop for the flow processing request queue.""" + self.flow_processing_request_handler_pool.Start() + + while 
not self.flow_processing_request_handler_stop: + thread_pool = self.flow_processing_request_handler_pool + free_threads = thread_pool.max_threads - thread_pool.busy_threads + if free_threads == 0: + time.sleep(self._FLOW_REQUEST_POLL_TIME_SECS) + continue + try: + msgs = self._LeaseFlowProcessingRequests(free_threads) + if msgs: + for m in msgs: + self.flow_processing_request_handler_pool.AddTask( + target=handler, args=(m,) + ) + else: + time.sleep(self._FLOW_REQUEST_POLL_TIME_SECS) + + except Exception as e: # pylint: disable=broad-except + logging.exception("_FlowProcessingRequestHandlerLoop raised %s.", e) + time.sleep(self._FLOW_REQUEST_POLL_TIME_SECS) + + self.flow_processing_request_handler_pool.Stop() + + def RegisterFlowProcessingHandler( + self, handler: Callable[[flows_pb2.FlowProcessingRequest], None] + ): + """Registers a handler to receive flow processing messages.""" + self.UnregisterFlowProcessingHandler() + + if handler: + self.flow_processing_request_handler_stop = False + self.flow_processing_request_handler_thread = threading.Thread( + name="flow_processing_request_handler", + target=self._FlowProcessingRequestHandlerLoop, + args=(handler,), + ) + self.flow_processing_request_handler_thread.daemon = True + self.flow_processing_request_handler_thread.start() + + def UnregisterFlowProcessingHandler( + self, timeout: Optional[rdfvalue.Duration] = None + ) -> None: + """Unregisters any registered flow processing handler.""" + if self.flow_processing_request_handler_thread: + self.flow_processing_request_handler_stop = True + self.flow_processing_request_handler_thread.join(timeout) + if self.flow_processing_request_handler_thread.is_alive(): + raise RuntimeError("Flow processing handler did not join in time.") + self.flow_processing_request_handler_thread = None + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowRequests( + self, + requests: Collection[flows_pb2.FlowRequest], + ) -> None: + for batch in collection.Batch(requests, self._write_rows_batch_size): + self._WriteFlowRequests(batch) + + def _WriteFlowRequests( + self, + requests: Collection[flows_pb2.FlowRequest], + ) -> None: + """Writes a list of flow requests to the database.""" + + flow_keys = [(r.client_id, r.flow_id) for r in requests] + + def Txn(txn) -> None: + needs_processing = {} + columns = ["ClientId", + "FlowId", + "RequestId", + "NeedsProcessing", + "NextResponseId", + "CallbackState", + "Payload", + "CreationTime", + "StartTime"] + rows = [] + for r in requests: + if r.needs_processing: + needs_processing.setdefault((r.client_id, r.flow_id), []).append(r) + + if r.start_time: + start_time = rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch(r.start_time).AsDatetime() + else: + start_time = rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch(0).AsDatetime() + + rows.append([r.client_id, r.flow_id, str(r.request_id), r.needs_processing, str(r.next_response_id), + r.callback_state, r, spanner_lib.COMMIT_TIMESTAMP, start_time]) + txn.insert_or_update(table="FlowRequests", columns=columns, values=rows) + + if needs_processing: + flow_processing_requests = [] + + keys = [] + # Note on linting: adding .keys() triggers a warning that + # .keys() should be omitted. Omitting keys leads to a + # mistaken warning that .items() was not called. 
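+ # Batch-read NextRequestToProcess for every flow that has requests marked
+ # as needing processing, so FlowProcessingRequests are only written for
+ # requests that are actually next in line (or have an explicit start time).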
+ for client_id, flow_id in needs_processing: # pylint: disable=dict-iter-missing-items + keys.append([client_id, flow_id]) + + columns = ( + "ClientId", + "FlowId", + "NextRequestToProcess", + ) + for row in txn.read(table="Flows", keyset=spanner_lib.KeySet(keys=keys), columns=columns): + client_id = row[0] + flow_id = row[1] + next_request_to_process = int(row[2]) + + candidate_requests = needs_processing.get((client_id, flow_id), []) + for r in candidate_requests: + if next_request_to_process == r.request_id or r.start_time: + req = flows_pb2.FlowProcessingRequest( + client_id=client_id, flow_id=flow_id + ) + if r.start_time: + req.delivery_time = r.start_time + flow_processing_requests.append(req) + + if flow_processing_requests: + self._WriteFlowProcessingRequests(flow_processing_requests, txn) + + try: + self.db.Transact(Txn, txn_tag="WriteFlowRequests") + except NotFound as error: + if "Parent row for row [" in str(error): + raise db.AtLeastOneUnknownFlowError(flow_keys, cause=error) + else: + raise + + def _ReadRequestsInfo( + self, + responses: Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + txn, + ) -> tuple[dict[_RequestKey, int], dict[_RequestKey, str], set[_RequestKey]]: + """For given responses returns data about corresponding requests. + + Args: + responses: an iterable with responses. + txn: transaction to use. + + Returns: + A tuple of 3 dictionaries: ( + responses_expected_by_request, + callback_state_by_request, + currently_available_requests). + + responses_expected_by_request: for requests that already received + a Status response, maps each request id to the number of responses + expected for it. + + callback_state_by_request: for incremental requests, maps each request + id to the name of a flow callback state that has to be called on + every incoming response. + + currently_available_requests: a set with all the request ids corresponding + to given responses. + """ + + # Number of responses each affected request is waiting for (if available). + responses_expected_by_request = {} + + # We also store all requests we have in the db so we can discard responses + # for unknown requests right away. + currently_available_requests = set() + + # Callback states by request. + callback_state_by_request = {} + + keys = [] + for r in responses: + keys.append([r.client_id, r.flow_id, str(r.request_id)]) + + for row in txn.read( + table="FlowRequests", + keyset=spanner_lib.KeySet(keys=keys), + columns=[ + "ClientID", + "FlowID", + "RequestID", + "CallbackState", + "ExpectedResponseCount", + ] + ): + + request_key = _RequestKey( + row[0], + row[1], + int(row[2]), + ) + currently_available_requests.add(request_key) + + callback_state: str = row[3] + if callback_state: + callback_state_by_request[request_key] = callback_state + + responses_expected: int = row[4] + if responses_expected: + responses_expected_by_request[request_key] = responses_expected + + return ( + responses_expected_by_request, + callback_state_by_request, + currently_available_requests, + ) + + def _BuildResponseWrites( + self, + responses: Collection[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + txn, + ) -> None: + """Builds the writes to store given responses in the db. + + Args: + responses: iterable with flow responses to write. + txn: transaction to use for the writes. + + Raises: + TypeError: if responses have objects other than FlowResponse, FlowStatus + or FlowIterator. 
+ """ + columns = ["ClientId", + "FlowId", + "RequestId", + "ResponseId", + "Response", + "Status", + "Iterator", + "CreationTime"] + rows = [] + for r in responses: + response = None + status = None + iterator = None + if isinstance(r, flows_pb2.FlowResponse): + response = r + elif isinstance(r, flows_pb2.FlowStatus): + status = r + elif isinstance(r, flows_pb2.FlowIterator): + iterator = r + else: + # This can't really happen due to DB validator type checking. + raise TypeError(f"Got unexpected response type: {type(r)} {r}") + rows.append([r.client_id, r.flow_id, str(r.request_id), str(r.response_id), + response,status,iterator,spanner_lib.COMMIT_TIMESTAMP]) + + txn.insert_or_update(table="FlowResponses", columns=columns, values=rows) + + def _BuildExpectedUpdates( + self, updates: dict[_RequestKey, int], txn + ) -> None: + """Builds updates for requests with known number of expected responses. + + Args: + updates: dict mapping requests to the number of expected responses. + txn: transaction to use for the writes. + """ + rows = [] + columns = ["ClientId", "FlowId", "RequestId", "ExpectedResponseCount"] + for r_key, num_responses_expected in updates.items(): + rows.append([r_key.client_id, + r_key.flow_id, + str(r_key.request_id), + num_responses_expected, + ]) + txn.update(table="FlowRequests", columns=columns, values=rows) + + def _WriteFlowResponsesAndExpectedUpdates( + self, + responses: Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ) -> tuple[dict[_RequestKey, int], dict[_RequestKey, str]]: + """Writes a flow responses and updates flow requests expected counts. + + Args: + responses: responses to write. + + Returns: + A tuple of (expected_responses_by_request, callback_state_by_request). + + expected_responses_by_request: number of expected responses by + request id. These numbers are collected from Status responses + discovered in the `responses` sequence. This data is later + passed to _BuildExpectedUpdates. + + callback_state_by_request: callback states by request. If incremental + requests are discovered during processing, their callback states end + up in this dictionary. This information is used later to make a + decision whether a flow should be notified about new responses: + incremental flows have to be notified even if Status responses were + not received. + """ + + if not responses: + return ({}, {}) + + def Txn(txn) -> tuple[dict[_RequestKey, int], dict[_RequestKey, str]]: + ( + responses_expected_by_request, + callback_state_by_request, + currently_available_requests, + ) = self._ReadRequestsInfo(responses, txn) + + # For some requests we will need to update the number of expected + # responses. + needs_expected_update = {} + + for r in responses: + req_key = _RequestKey(r.client_id, r.flow_id, int(r.request_id)) + + # If the response is not a FlowStatus, we have nothing to do: it will be + # simply written to the DB. If it's a FlowStatus, we have to update + # the FlowRequest with the number of expected messages. + if not isinstance(r, flows_pb2.FlowStatus): + continue + + if req_key not in currently_available_requests: + logging.info("Dropping status for unknown request %s", req_key) + continue + + current = responses_expected_by_request.get(req_key) + if current: + logging.warning( + "Got duplicate status message for request %s", req_key + ) + + # If there is already responses_expected information, we need to make + # sure the current status doesn't disagree. 
+ if current != r.response_id: + logging.error( + "Got conflicting status information for request %s: %s", + req_key, + r, + ) + else: + needs_expected_update[req_key] = r.response_id + + responses_expected_by_request[req_key] = r.response_id + + responses_to_write = {} + for r in responses: + req_key = _RequestKey(r.client_id, r.flow_id, int(r.request_id)) + full_key = _ResponseKey( + r.client_id, r.flow_id, int(r.request_id), int(r.response_id) + ) + + if req_key not in currently_available_requests: + continue + + if full_key in responses_to_write: + # Don't write a response if it was already written as part of the + # same batch. + prev = responses_to_write[full_key] + if r != prev: + logging.warning( + "WriteFlowResponses attempted to write two different " + "entries with identical key %s. First is %s and " + "second is %s.", + full_key, + prev, + r, + ) + continue + + responses_to_write[full_key] = r + + if responses_to_write or needs_expected_update: + self._BuildResponseWrites(responses_to_write.values(), txn) + if needs_expected_update: + self._BuildExpectedUpdates(needs_expected_update, txn) + + return responses_expected_by_request, callback_state_by_request + + return tuple(self.db.Transact( + Txn, txn_tag="WriteFlowResponsesAndExpectedUpdates" + )) + + def _GetFlowResponsesPerRequestCounts( + self, + request_keys: Iterable[_RequestKey] + ) -> dict[_RequestKey, int]: + """Gets counts of already received responses for given requests. + + Args: + request_keys: iterable with request keys. + txn: transaction to use. + + Returns: + A dictionary mapping request keys to the number of existing flow + responses. + """ + + if not request_keys: + return {} + + conditions = [] + params = {} + for i, req_key in enumerate(request_keys): + if i > 0: + conditions.append("OR") + + conditions.append(f""" + (fr.ClientId = {{client_id_{i}}} AND + fr.FlowId = {{flow_id_{i}}} AND + fr.RequestId = {{request_id_{i}}}) + """) + + params[f"client_id_{i}"] = req_key.client_id + params[f"flow_id_{i}"] = req_key.flow_id + params[f"request_id_{i}"] = str(req_key.request_id) + + query = f""" + SELECT fr.ClientId, fr.FlowId, fr.RequestId, COUNT(*) AS ResponseCount + FROM FlowResponses as fr + WHERE {" ".join(conditions)} + GROUP BY fr.ClientID, fr.FlowID, fr.RequestID + """ + + result = {} + for row in self.db.ParamQuery(query, params): + client_id, flow_id, request_id, count = row + + req_key = _RequestKey( + client_id, + flow_id, + int(request_id), + ) + result[req_key] = count + + return result + + def _ReadFlowRequestsNotYetMarkedForProcessing( + self, + requests: set[_RequestKey], + callback_states: dict[_RequestKey, str], + txn, + ) -> tuple[ + set[_RequestKey], set[tuple[_FlowKey, Optional[rdfvalue.RDFDatetime]]] + ]: + """Reads given requests and returns only ones not marked for processing. + + Args: + requests: request keys for requests to be read. + callback_states: dict containing incremental flow requests from the set. + For each such request the request key will be mapped to the callback + state of the flow. + txn: transaction to use. + + Returns: + A tuple of (requests_to_mark, flows_to_notify). + + requests_to_mark is a set of request keys for requests that have to be + marked as needing processing. + + flows_to_notify is a set of tuples (flow_key, start_time) for flows that + have to be notified of incoming responses. start_time in the tuple + corresponds to the intended notification delivery time. 
+ """ + flow_keys = [] + req_keys = [] + + unique_flow_keys = set() + + for req_key in set(requests) | set(callback_states): + req_keys.append([req_key.client_id, req_key.flow_id, str(req_key.request_id)]) + unique_flow_keys.add((req_key.client_id, req_key.flow_id)) + + for client_id, flow_id in unique_flow_keys: + flow_keys.append([client_id, flow_id]) + + next_request_to_process_by_flow = {} + flow_cols = [ + "ClientId", + "FlowId", + "NextRequestToProcess", + ] + for row in txn.read(table="Flows", + keyset=spanner_lib.KeySet(keys=flow_keys), + columns=flow_cols): + client_id: int = row[0] + flow_id: int = row[1] + next_request_id: int = int(row[2]) + next_request_to_process_by_flow[(client_id, flow_id)] = ( + next_request_id + ) + + requests_to_mark = set() + requests_to_notify = set() + req_cols = [ + "ClientId", + "FlowId", + "RequestId", + "NeedsProcessing", + "StartTime", + ] + for row in txn.read(table="FlowRequests", + keyset=spanner_lib.KeySet(keys=req_keys), + columns=req_cols): + client_id: str = row[0] + flow_id: str = row[1] + request_id: int = int(row[2]) + np: bool = row[3] + start_time: Optional[rdfvalue.RDFDatetime] = None + if row[4] is not None: + start_time = rdfvalue.RDFDatetime.FromDatetime(row[4]) + + if not np: + + req_key = _RequestKey(client_id, flow_id, request_id) + if req_key in requests: + requests_to_mark.add(req_key) + + if ( + next_request_to_process_by_flow[(client_id, flow_id)] == request_id + ): + requests_to_notify.add((_FlowKey(client_id, flow_id), start_time)) + + return requests_to_mark, requests_to_notify + + def _BuildNeedsProcessingUpdates( + self, requests: set[_RequestKey], txn + ) -> None: + """Builds updates for requests that have their NeedsProcessing flag set. + + Args: + requests: keys of requests to be updated. + txn: transaction to use. + """ + rows = [] + columns = ["ClientId", "FlowId", "RequestId", "NeedsProcessing"] + for req_key in requests: + rows.append([req_key.client_id, + req_key.flow_id, + str(req_key.request_id), + True, + ]) + txn.update(table="FlowRequests", columns=columns, values=rows) + + def _UpdateNeedsProcessingAndWriteFlowProcessingRequests( + self, + requests_ready_for_processing: set[_RequestKey], + callback_state_by_request: dict[_RequestKey, str], + txn, + ) -> None: + """Updates requests needs-processing flags, writes processing requests. + + Args: + requests_ready_for_processing: request keys for requests that have to be + updated. + callback_state_by_request: for incremental requests from the set - mapping + from request ids to callback states that are incrementally processing + incoming responses. + txn: transaction to use. 
+ """ + + if not requests_ready_for_processing and not callback_state_by_request: + return + + (requests_to_mark, flows_to_notify) = ( + self._ReadFlowRequestsNotYetMarkedForProcessing( + requests_ready_for_processing, callback_state_by_request, txn + ) + ) + + if requests_to_mark: + self._BuildNeedsProcessingUpdates(requests_to_mark, txn) + + if flows_to_notify: + flow_processing_requests = [] + for flow_key, start_time in flows_to_notify: + fpr = flows_pb2.FlowProcessingRequest( + client_id=flow_key.client_id, + flow_id=flow_key.flow_id, + ) + if start_time is not None: + fpr.delivery_time = int(start_time) + flow_processing_requests.append(fpr) + + self._WriteFlowProcessingRequests(flow_processing_requests, txn) + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowResponses( + self, + responses: Sequence[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ) -> None: + """Writes Flow messages and updates corresponding requests.""" + responses_expected_by_request = {} + callback_state_by_request = {} + for batch in collection.Batch(responses, self._write_rows_batch_size): + res_exp_by_req_iter, callback_state_by_req_iter = ( + self._WriteFlowResponsesAndExpectedUpdates(batch) + ) + + responses_expected_by_request.update(res_exp_by_req_iter) + callback_state_by_request.update(callback_state_by_req_iter) + + # If we didn't get any status messages, then there's nothing to process. + if not responses_expected_by_request and not callback_state_by_request: + return + + # Get actual per-request responses counts using a separate transaction. + counts = self._GetFlowResponsesPerRequestCounts( + responses_expected_by_request + ) + + requests_ready_for_processing = set() + for req_key, responses_expected in responses_expected_by_request.items(): + if counts.get(req_key) == responses_expected: + requests_ready_for_processing.add(req_key) + + # requests_to_notify is a subset of requests_ready_for_processing, so no + # need to check if it's empty or not. 
+ if requests_ready_for_processing or callback_state_by_request: + + def Txn(txn) -> None: + self._UpdateNeedsProcessingAndWriteFlowProcessingRequests( + requests_ready_for_processing, callback_state_by_request, txn + ) + + self.db.Transact(Txn, txn_tag="WriteFlowResponses") + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteAllFlowRequestsAndResponses( + self, + client_id: str, + flow_id: str, + ) -> None: + """Deletes all requests and responses for a given flow from the database.""" + self.db.DeleteWithPrefix( + "FlowRequests", + (client_id, flow_id), + txn_tag="DeleteAllFlowRequestsAndResponses" + ) + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadAllFlowRequestsAndResponses( + self, + client_id: str, + flow_id: str, + ) -> Iterable[ + tuple[ + flows_pb2.FlowRequest, + dict[ + int, + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ] + ]: + """Reads all requests and responses for a given flow from the database.""" + rowrange = spanner_lib.KeyRange(start_closed=[client_id, flow_id], end_closed=[client_id, flow_id]) + rows = spanner_lib.KeySet(ranges=[rowrange]) + req_cols = [ + "Payload", + "NeedsProcessing", + "ExpectedResponseCount", + "CallbackState", + "NextResponseId", + "CreationTime", + ] + requests = [] + for row in self.db.ReadSet(table="FlowRequests", + rows=rows, + cols=req_cols, + txn_tag="ReadAllFlowRequestsAndResponses:FlowRequests"): + request = flows_pb2.FlowRequest() + request.ParseFromString(row[0]) + request.needs_processing = row[1] + if row[2] is not None: + request.nr_responses_expected = row[2] + request.callback_state = row[3] + request.next_response_id = int(row[4]) + request.timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(row[5]) + ) + requests.append(request) + + resp_cols = [ + "Response", + "Status", + "Iterator", + "CreationTime", + ] + responses = {} + for row in self.db.ReadSet( + table="FlowResponses", + rows=rows, + cols=resp_cols, + txn_tag="ReadAllFlowRequestsAndResponses:FlowResponses" + ): + if row[1] is not None: + response = flows_pb2.FlowStatus() + response.ParseFromString(row[1]) + elif row[2] is not None: + response = flows_pb2.FlowIterator() + response.ParseFromString(row[2]) + else: + response = flows_pb2.FlowResponse() + response.ParseFromString(row[0]) + response.timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(row[3]) + ) + responses.setdefault(response.request_id, {})[ + response.response_id + ] = response + + ret = [] + for req in sorted(requests, key=lambda r: r.request_id): + ret.append((req, responses.get(req.request_id, {}))) + return ret + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteFlowRequests( + self, + requests: Sequence[flows_pb2.FlowRequest], + ) -> None: + """Deletes a list of flow requests from the database.""" + if not requests: + return + + def Mutation(mut) -> None: + for request in requests: + keyset = spanner_lib.KeySet([[ + request.client_id, + request.flow_id, + str(request.request_id) + ]]) + mut.delete(table="FlowRequests", keyset=keyset) + + try: + self.db.Mutate(Mutation, txn_tag="DeleteFlowRequests") + except Exception: + if len(requests) == 1: + # If there is only one request and we still hit Spanner limits it means + # that the requests has a lot of responses. It should be extremely rare + # to end up in such situation, so we just leave the request in the + # database. Eventually, these rows will be deleted automatically due to + # our row retention policies [1] for this table. 
+ # + # [1]: go/spanner-row-deletion-policies. + SPANNER_DELETE_FLOW_REQUESTS_FAILURES.Increment() + logging.error( + "Transaction too big to delete flow request '%s'", requests[0] + ) + else: + # If there is more than one request, we attempt to divide the data into + # smaller parts and delete these. + # + # Note that dividing in two does not mean that the number of deleted + # rows will spread evenly as it might be the case that one request in + # one part has significantly more responses than requests in the other + # part. However, as a cheap and reasonable approximation, this should do + # just fine. + # + # Notice that both this `DeleteFlowRequests` calls happen in separate + # transactions. Since we are just deleting rows "obsolete" rows we do + # not really care about atomicity. If one of them succeeds and the other + # one fails, rows are going to be deleted eventually anyway (see the + # comment for a single request case). + self.DeleteFlowRequests(requests[: len(requests) // 2]) + self.DeleteFlowRequests(requests[len(requests) // 2 :]) + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadFlowRequests( + self, + client_id: str, + flow_id: str, + ) -> dict[ + int, + tuple[ + flows_pb2.FlowRequest, + list[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ], + ]: + """Reads all requests for a flow that can be processed by the worker.""" + rowrange = spanner_lib.KeyRange(start_closed=[client_id, flow_id], end_closed=[client_id, flow_id]) + rows = spanner_lib.KeySet(ranges=[rowrange]) + + responses: dict[ + int, + list[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ] + ], + ] = {} + resp_cols = [ + "Response", + "Status", + "Iterator", + "CreationTime", + ] + for row in self.db.ReadSet(table="FlowResponses", + rows=rows, + cols=resp_cols, + txn_tag="ReadFlowRequests:FlowResponses"): + if row[1]: + response = flows_pb2.FlowStatus() + response.ParseFromString(row[1]) + elif row[2]: + response = flows_pb2.FlowIterator() + response.ParseFromString(row[2]) + else: + response = flows_pb2.FlowResponse() + response.ParseFromString(row[0]) + response.timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(row[3]) + ) + responses.setdefault(response.request_id, []).append(response) + + requests: dict[ + int, + tuple[ + flows_pb2.FlowRequest, + list[ + Union[ + flows_pb2.FlowResponse, + flows_pb2.FlowStatus, + flows_pb2.FlowIterator, + ], + ], + ], + ] = {} + req_cols = [ + "Payload", + "NeedsProcessing", + "ExpectedResponseCount", + "NextResponseId", + "CallbackState", + "CreationTime", + ] + for row in self.db.ReadSet(table="FlowRequests", + rows=rows, + cols=req_cols, + txn_tag="ReadFlowRequests:FlowRequests"): + request = flows_pb2.FlowRequest() + request.ParseFromString(row[0]) + request.needs_processing = row[1] + if row[2] is not None: + request.nr_responses_expected = row[2] + request.callback_state = row[4] + request.next_response_id = int(row[3]) + request.timestamp = int( + rdfvalue.RDFDatetime.FromDatetime(row[5]) + ) + requests[request.request_id] = ( + request, + sorted( + responses.get(request.request_id, []), key=lambda r: r.response_id + ), + ) + return requests + + @db_utils.CallLogged + @db_utils.CallAccounted + def UpdateIncrementalFlowRequests( + self, + client_id: str, + flow_id: str, + next_response_id_updates: Mapping[int, int], + ) -> None: + """Updates next response ids of given requests.""" + + def Txn(txn) -> None: + rows = [] + columns = ["ClientId", "FlowId", "RequestId", 
"NextResponseId"] + for request_id, response_id in next_response_id_updates.items(): + rows.append([client_id, flow_id, str(request_id), str(response_id)]) + txn.update( + table="FlowRequests", + columns=columns, + values=rows + ) + + self.db.Transact(Txn, txn_tag="UpdateIncrementalFlowRequests") + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowLogEntry(self, entry: flows_pb2.FlowLogEntry) -> None: + """Writes a single flow log entry to the database.""" + row = { + "ClientId": entry.client_id, + "FlowId": entry.flow_id, + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, + "Message": entry.message, + } + + if entry.hunt_id: + row["HuntId"] = entry.hunt_id + + try: + self.db.Insert(table="FlowLogEntries", row=row, txn_tag="WriteFlowLogEntry") + except NotFound as error: + raise db.UnknownFlowError(entry.client_id, entry.flow_id) from error + + def ReadFlowLogEntries( + self, + client_id: str, + flow_id: str, + offset: int, + count: int, + with_substring: Optional[str] = None, + ) -> Sequence[flows_pb2.FlowLogEntry]: + """Reads flow log entries of a given flow using given query options.""" + results = [] + + query = """ + SELECT l.HuntId, + l.CreationTime, + l.Message + FROM FlowLogEntries AS l + WHERE l.ClientId = {client_id} + AND l.FlowId = {flow_id} + """ + params = { + "client_id": client_id, + "flow_id": flow_id, + } + + if with_substring is not None: + query += " AND STRPOS(l.Message, {substring}) != 0" + params["substring"] = with_substring + + query += """ + LIMIT {count} + OFFSET {offset} + """ + params["offset"] = offset + params["count"] = count + + for row in self.db.ParamQuery(query, params): + hunt_id, creation_time, message = row + + result = flows_pb2.FlowLogEntry() + result.client_id = client_id + result.flow_id = flow_id + + if hunt_id is not None: + result.hunt_id = hunt_id + + result.timestamp = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + result.message = message + + results.append(result) + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountFlowLogEntries(self, client_id: str, flow_id: str) -> int: + """Returns number of flow log entries of a given flow.""" + query = """ + SELECT COUNT(*) + FROM FlowLogEntries AS l + WHERE l.ClientId = {client_id} + AND l.FlowId = {flow_id} + """ + params = { + "client_id": client_id, + "flow_id": flow_id, + } + + (count,) = self.db.ParamQuerySingle(query, params) + return count + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowRRGLogs( + self, + client_id: str, + flow_id: str, + request_id: int, + logs: Mapping[int, rrg_pb2.Log], + ) -> None: + """Writes new log entries for a particular action request.""" + # Mutations cannot be empty, so we exit early to avoid that if needed. 
+ if not logs: + return + + def Mutation(mut) -> None: + rows = [] + columns = ["ClientId", "FlowId", "RequestId", "ResponseId", + "LogLevel", "LogTime", "LogMessage", "CreationTime"] + for response_id, log in logs.items(): + rows.append([client_id, + flow_id, + str(request_id), + str(response_id), + log.level, + log.timestamp.ToDatetime(), + log.message, + spanner_lib.COMMIT_TIMESTAMP + ]) + mut.insert(table="FlowRRGLogs", columns=columns, values=rows) + + try: + self.db.Mutate(Mutation) + except NotFound as error: + raise db.UnknownFlowError(client_id, flow_id, cause=error) from error + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadFlowRRGLogs( + self, + client_id: str, + flow_id: str, + offset: int, + count: int, + ) -> Sequence[rrg_pb2.Log]: + """Reads log entries logged by actions issued by a particular flow.""" + query = """ + SELECT + l.LogLevel, l.LogTime, l.LogMessage + FROM + FlowRRGLogs AS l + WHERE + l.ClientId = {client_id} AND l.FlowId = {flow_id} + ORDER BY + l.RequestId, l.ResponseId + LIMIT + {count} + OFFSET + {offset} + """ + params = { + "client_id": client_id, + "flow_id": flow_id, + "offset": offset, + "count": count, + } + + results: list[rrg_pb2.Log] = [] + + for row in self.db.ParamQuery(query, params, txn_tag="ReadFlowRRGLogs"): + log_level, log_time, log_message = row + + log = rrg_pb2.Log() + log.level = log_level + log.timestamp.FromDatetime(log_time) + log.message = log_message + + results.append(log) + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteFlowOutputPluginLogEntry( + self, + entry: flows_pb2.FlowOutputPluginLogEntry, + ) -> None: + """Writes a single output plugin log entry to the database. + + Args: + entry: An output plugin flow entry to write. + """ + row = { + "ClientId": entry.client_id, + "FlowId": entry.flow_id, + "OutputPluginId": entry.output_plugin_id, + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, + "Type": int(entry.log_entry_type), + "Message": entry.message, + } + + if entry.hunt_id: + row["HuntId"] = entry.hunt_id + + try: + self.db.Insert(table="FlowOutputPluginLogEntries", row=row, txn_tag="WriteFlowOutputPluginLogEntry") + except NotFound as error: + raise db.UnknownFlowError(entry.client_id, entry.flow_id) from error + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadFlowOutputPluginLogEntries( + self, + client_id: str, + flow_id: str, + output_plugin_id: str, + offset: int, + count: int, + with_type: Optional[ + flows_pb2.FlowOutputPluginLogEntry.LogEntryType.ValueType + ] = None, + ) -> Sequence[flows_pb2.FlowOutputPluginLogEntry]: + """Reads flow output plugin log entries.""" + results = [] + + query = """ + SELECT l.HuntId, + l.CreationTime, + l.Type, l.Message + FROM FlowOutputPluginLogEntries AS l + WHERE l.ClientId = {client_id} + AND l.FlowId = {flow_id} + AND l.OutputPluginId = {output_plugin_id} + """ + params = { + "client_id": client_id, + "flow_id": flow_id, + "output_plugin_id": output_plugin_id, + } + + if with_type is not None: + query += " AND l.Type = {type}" + params["type"] = int(with_type) + + query += """ + LIMIT {count} + OFFSET {offset} + """ + params["offset"] = offset + params["count"] = count + + for row in self.db.ParamQuery(query, params): + hunt_id, creation_time, int_type, message = row + + result = flows_pb2.FlowOutputPluginLogEntry() + result.client_id = client_id + result.flow_id = flow_id + result.output_plugin_id = output_plugin_id + + if hunt_id is not None: + result.hunt_id = hunt_id + + result.timestamp = 
rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + result.log_entry_type = int_type + result.message = message + + results.append(result) + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountFlowOutputPluginLogEntries( + self, + client_id: str, + flow_id: str, + output_plugin_id: str, + with_type: Optional[ + flows_pb2.FlowOutputPluginLogEntry.LogEntryType.ValueType + ] = None, + ) -> int: + """Returns the number of flow output plugin log entries of a given flow.""" + query = """ + SELECT COUNT(*) + FROM FlowOutputPluginLogEntries AS l + WHERE l.ClientId = {client_id} + AND l.FlowId = {flow_id} + AND l.OutputPluginId = {output_plugin_id} + """ + params = { + "client_id": client_id, + "flow_id": flow_id, + "output_plugin_id": output_plugin_id, + } + param_type = {} + + if with_type is not None: + query += " AND l.Type = {type}" + params["type"] = int(with_type) + param_type["type"] = param_types.INT64 + + (count,) = self.db.ParamQuerySingle(query, params, param_type=param_type) + return count + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteScheduledFlow( + self, + scheduled_flow: flows_pb2.ScheduledFlow, + ) -> None: + """Inserts or updates the ScheduledFlow in the database.""" + row = { + "ClientId": scheduled_flow.client_id, + "Creator": scheduled_flow.creator, + "ScheduledFlowId": scheduled_flow.scheduled_flow_id, + "FlowName": scheduled_flow.flow_name, + "FlowArgs": scheduled_flow.flow_args, + "RunnerArgs": scheduled_flow.runner_args, + "CreationTime": rdfvalue.RDFDatetime( + scheduled_flow.create_time + ).AsDatetime(), + "Error": scheduled_flow.error, + } + + try: + self.db.InsertOrUpdate(table="ScheduledFlows", row=row, txn_tag="WriteScheduledFlow") + except Exception as error: + if "Parent row for row [" in str(error): + raise db.UnknownClientError(scheduled_flow.client_id) from error + elif "fk_creator_users_username" in str(error): + raise db.UnknownGRRUserError(scheduled_flow.creator) from error + else: + raise + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteScheduledFlow( + self, + client_id: str, + creator: str, + scheduled_flow_id: str, + ) -> None: + """Deletes the ScheduledFlow from the database.""" + keyset = spanner_lib.KeySet(keys=[[client_id, creator, scheduled_flow_id]]) + + def Transaction(txn) -> None: + try: + txn.read(table="ScheduledFlows", columns=["ScheduledFlowId"], keyset=keyset).one() + except NotFound as e: + raise db.UnknownScheduledFlowError( + client_id=client_id, + creator=creator, + scheduled_flow_id=scheduled_flow_id, + ) from e + + txn.delete(table="ScheduledFlows", keyset=keyset) + + self.db.Transact(Transaction, txn_tag="DeleteScheduledFlow") + + @db_utils.CallLogged + @db_utils.CallAccounted + def ListScheduledFlows( + self, + client_id: str, + creator: str, + ) -> Sequence[flows_pb2.ScheduledFlow]: + """Lists all ScheduledFlows for the client and creator.""" + range = spanner_lib.KeyRange(start_closed=[client_id, creator], end_closed=[client_id, creator]) + rows = spanner_lib.KeySet(ranges=[range]) + + cols = [ + "ClientId", + "Creator", + "ScheduledFlowId", + "FlowName", + "FlowArgs", + "RunnerArgs", + "CreationTime", + "Error", + ] + results = [] + + for row in self.db.ReadSet("ScheduledFlows", rows, cols, + txn_tag="ListScheduledFlows"): + sf = flows_pb2.ScheduledFlow() + sf.client_id = row[0] + sf.creator = row[1] + sf.scheduled_flow_id = row[2] + sf.flow_name = row[3] + sf.flow_args.ParseFromString(row[4]) + sf.runner_args.ParseFromString(row[5]) + 
sf.create_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[6]) + ) + sf.error = row[7] + + results.append(sf) + + return results + + def RegisterMessageHandler( + self, + handler: Callable[[Sequence[objects_pb2.MessageHandlerRequest]], None], + lease_time: rdfvalue.Duration, + limit: int = 1000, + ) -> None: + """Leases a number of message handler requests up to the indicated limit.""" + self.UnregisterMessageHandler() + + if handler: + self.handler_stop = False + self.handler_thread = threading.Thread( + name="message_handler", + target=self._MessageHandlerLoop, + args=(handler, lease_time, limit), + ) + self.handler_thread.daemon = True + self.handler_thread.start() + + def UnregisterMessageHandler( + self, timeout: Optional[rdfvalue.Duration] = None + ) -> None: + """Unregisters any registered message handler.""" + if self.handler_thread: + self.handler_stop = True + self.handler_thread.join(timeout) + if self.handler_thread.is_alive(): + raise RuntimeError("Message handler thread did not join in time.") + self.handler_thread = None + + _MESSAGE_HANDLER_POLL_TIME_SECS = 5 + + def _MessageHandlerLoop( + self, + handler: Callable[[Iterable[objects_pb2.MessageHandlerRequest]], None], + lease_time: rdfvalue.Duration, + limit: int = 1000, + ) -> None: + """Loop to handle outstanding requests.""" + while not self.handler_stop: + try: + msgs = self._LeaseMessageHandlerRequests(lease_time, limit) + if msgs: + handler(msgs) + else: + time.sleep(self._MESSAGE_HANDLER_POLL_TIME_SECS) + except Exception as e: # pylint: disable=broad-except + logging.exception("_LeaseMessageHandlerRequests raised %s.", e) + + def _LeaseMessageHandlerRequests( + self, + lease_time: rdfvalue.Duration, + limit: int = 1000, + ) -> Iterable[objects_pb2.MessageHandlerRequest]: + """Leases a number of message handler requests up to the indicated limit.""" + now = rdfvalue.RDFDatetime.Now() + delivery_time = now + lease_time + + leased_until = delivery_time.AsMicrosecondsSinceEpoch() + leased_by = utils.ProcessIdString() + + def Txn(txn) -> None: + # Read the message handler requests waiting for leases + params = { + "limit": limit, + "now": now.AsDatetime() + } + param_type = { + "limit": param_types.INT64, + "now": param_types.TIMESTAMP + } + requests = txn.execute_sql( + "SELECT RequestId, CreationTime, Payload " + "FROM MessageHandlerRequests " + "WHERE LeasedUntil IS NULL OR LeasedUntil < @now " + "LIMIT @limit", + params=params, + param_types=param_type) + res = [] + request_ids = [] + for request_id, creation_time, request in requests: + req = objects_pb2.MessageHandlerRequest() + req.ParseFromString(request) + req.timestamp = req.leased_until = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + req.leased_until = leased_until + req.leased_by = leased_by + res.append(req) + request_ids.append(request_id) + + query = ( + "UPDATE MessageHandlerRequests " + "SET LeasedUntil=@leased_until, LeasedBy=@leased_by " + "WHERE RequestId IN UNNEST(@request_ids)" + ) + params = { + "request_ids": request_ids, + "leased_by": leased_by, + "leased_until": delivery_time.AsDatetime() + } + param_type = { + "request_ids": param_types.Array(param_types.STRING), + "leased_by": param_types.STRING, + "leased_until": param_types.TIMESTAMP + } + txn.execute_update(query, params, param_type) + + return res + + return self.db.Transact(Txn, txn_tag="_LeaseMessageHandlerRequests") + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteMessageHandlerRequests( + self, requests: 
Iterable[objects_pb2.MessageHandlerRequest] + ) -> None: + """Writes a list of message handler requests to the queue.""" + def Mutation(mut) -> None: + creation_timestamp = spanner_lib.COMMIT_TIMESTAMP + rows = [] + columns = ["RequestId", "HandlerName", "CreationTime", "Payload"] + for request in requests: + rows.append([ + str(request.request_id), + request.handler_name, + creation_timestamp, + request, + ]) + mut.insert(table="MessageHandlerRequests", columns=columns, values=rows) + + self.db.Mutate(Mutation, txn_tag="WriteMessageHandlerRequests") + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadMessageHandlerRequests( + self, + ) -> Sequence[objects_pb2.MessageHandlerRequest]: + """Reads all message handler requests from the queue.""" + query = """ + SELECT t.Payload, t.CreationTime, t.LeasedBy, t.LeasedUntil FROM MessageHandlerRequests AS t + """ + results = [] + for payload, creation_time, leased_by, leased_until in self.db.ParamQuery(query, {}): + req = objects_pb2.MessageHandlerRequest() + req.ParseFromString(payload) + req.timestamp = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + if leased_by is not None: + req.leased_by = leased_by + if leased_until is not None: + req.leased_until = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + results.append(req) + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteMessageHandlerRequests( + self, + requests: Iterable[objects_pb2.MessageHandlerRequest], + ) -> None: + """Deletes a list of message handler requests from the database.""" + + query = "DELETE FROM MessageHandlerRequests WHERE RequestId IN UNNEST(@request_ids)" + request_ids = [] + for r in requests: + request_ids.append(str(r.request_id)) + params={"request_ids": request_ids} + + self.db.ParamExecute(query, params, txn_tag="DeleteMessageHandlerRequests") + + def _ReadHuntState( + self, txn, hunt_id: str + ) -> Optional[int]: + try: + row = txn.read(table="Hunts", keyset=spanner_lib.KeySet(keys=[[hunt_id]]), columns=["State",]).one() + return row[0] + except NotFound: + return None + + @db_utils.CallLogged + @db_utils.CallAccounted + def LeaseFlowForProcessing( + self, + client_id: str, + flow_id: str, + processing_time: rdfvalue.Duration, + ) -> flows_pb2.Flow: + """Marks a flow as being processed on this worker and returns it.""" + + def Txn(txn) -> flows_pb2.Flow: + try: + row = txn.read( + table="Flows", + keyset=spanner_lib.KeySet(keys=[[client_id, flow_id]]), + columns=_READ_FLOW_OBJECT_COLS + ).one() + except NotFound as error: + raise db.UnknownFlowError(client_id, flow_id, cause=error) + + flow = _ParseReadFlowObjectRow(client_id, flow_id, row) + now = rdfvalue.RDFDatetime.Now() + if flow.processing_on and flow.processing_deadline > int(now): + raise ValueError( + "Flow {}/{} is already being processed on {} since {} " + "with deadline {} (now: {})).".format( + client_id, + flow_id, + flow.processing_on, + rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + flow.processing_since + ), + rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + flow.processing_deadline + ), + now, + ) + ) + + if flow.parent_hunt_id is not None: + hunt_state = self._ReadHuntState(txn, flow.parent_hunt_id) + if ( + hunt_state is not None + and not models_hunts.IsHuntSuitableForFlowProcessing(hunt_state) + ): + raise db.ParentHuntIsNotRunningError( + client_id, flow_id, flow.parent_hunt_id, hunt_state + ) + + flow.processing_on = utils.ProcessIdString() + flow.processing_deadline = int(now + 
processing_time) + + txn.update( + table="Flows", + columns = ["ClientId", "FlowId", "ProcessingWorker", + "ProcessingEndTime","ProcessingStartTime"], + values=[[client_id, flow_id, flow.processing_on, + rdfvalue.RDFDatetime.FromMicrosecondsSinceEpoch( + flow.processing_deadline + ).AsDatetime(), + spanner_lib.COMMIT_TIMESTAMP, + ]] + ) + + return flow + + def Txn2(txn) -> flows_pb2.Flow: + try: + row = txn.read( + table="Flows", + keyset=spanner_lib.KeySet(keys=[[client_id, flow_id]]), + columns=_READ_FLOW_OBJECT_COLS + ).one() + flow = _ParseReadFlowObjectRow(client_id, flow_id, row) + except NotFound as error: + raise db.UnknownFlowError(client_id, flow_id, cause=error) + return flow + + leased_flow = self.db.Transact(Txn, txn_tag="LeaseFlowForProcessing") + flow = self.db.Transact(Txn2, txn_tag="LeaseFlowForProcessing2") + leased_flow.processing_since = flow.processing_since + return leased_flow + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReleaseProcessedFlow(self, flow_obj: flows_pb2.Flow) -> bool: + """Releases a flow that the worker was processing to the database.""" + + def Txn(txn) -> bool: + try: + row = txn.read( + table="FlowRequests", + keyset=spanner_lib.KeySet(keys=[[flow_obj.client_id, flow_obj.flow_id, flow_obj.next_request_to_process]]), + columns=["NeedsProcessing", "StartTime"] + ).one() + if row[0]: + start_time = row[1] + if start_time is None: + return False + elif ( + rdfvalue.RDFDatetime.FromDatetime(start_time) + < rdfvalue.RDFDatetime.Now() + ): + return False + except NotFound: + pass + txn.update( + table="Flows", + columns=["ClientId", "FlowId", "Flow", "State", "UserCpuTimeUsed", + "SystemCpuTimeUsed", "NetworkBytesSent", "ProcessingWorker", + "ProcessingStartTime", "ProcessingEndTime", "NextRequestToProcess", + "UpdateTime", "ReplyCount"], + values=[[flow_obj.client_id, flow_obj.flow_id,flow_obj, + int(flow_obj.flow_state), float(flow_obj.cpu_time_used.user_cpu_time), + float(flow_obj.cpu_time_used.system_cpu_time), + int(flow_obj.network_bytes_sent), None, None, None, + flow_obj.next_request_to_process, + spanner_lib.COMMIT_TIMESTAMP, + flow_obj.num_replies_sent, + ]] + ) + return True + + return self.db.Transact(Txn, txn_tag="ReleaseProcessedFlow") diff --git a/grr/server/grr_response_server/databases/spanner_flows_large_test.py b/grr/server/grr_response_server/databases/spanner_flows_large_test.py new file mode 100644 index 000000000..b021bbef3 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_flows_large_test.py @@ -0,0 +1,24 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_flows_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseFlowsTest( + db_flows_test.DatabaseLargeTestFlowMixin, + spanner_test_lib.TestCase +): + # Test methods are defined in the base mixin class. + + pass # Test methods are defined in the base mixin class. 
+ +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/databases/spanner_flows_test.py b/grr/server/grr_response_server/databases/spanner_flows_test.py new file mode 100644 index 000000000..1b7e1c764 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_flows_test.py @@ -0,0 +1,22 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_flows_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseFlowsTest( + db_flows_test.DatabaseTestFlowMixin, spanner_test_lib.TestCase +): + """Spanner flow tests.""" + pass # Test methods are defined in the base mixin class. + +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/databases/spanner_foreman_rules.py b/grr/server/grr_response_server/databases/spanner_foreman_rules.py new file mode 100644 index 000000000..9b6950fe3 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_foreman_rules.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python +"""A module with foreman rules methods of the Spanner backend.""" + +from typing import Sequence + +from grr_response_core.lib import rdfvalue +from grr_response_proto import jobs_pb2 +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils + + +class ForemanRulesMixin: + """A Spanner database mixin with implementation of foreman rules.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteForemanRule(self, rule: jobs_pb2.ForemanCondition) -> None: + """Writes a foreman rule to the database.""" + row = { + "HuntId": rule.hunt_id, + "ExpirationTime": ( + rdfvalue.RDFDatetime() + .FromMicrosecondsSinceEpoch(rule.expiration_time) + .AsDatetime() + ), + "Payload": rule, + } + self.db.InsertOrUpdate( + table="ForemanRules", row=row, txn_tag="WriteForemanRule" + ) + + @db_utils.CallLogged + @db_utils.CallAccounted + def RemoveForemanRule(self, hunt_id: str) -> None: + """Removes a foreman rule from the database.""" + self.db.Delete( + table="ForemanRules", key=[hunt_id], txn_tag="RemoveForemanRule" + ) + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadAllForemanRules(self) -> Sequence[jobs_pb2.ForemanCondition]: + """Reads all foreman rules from the database.""" + result = [] + + query = """ + SELECT fr.Payload + FROM ForemanRules AS fr + """ + for [payload] in self.db.Query(query, txn_tag="ReadAllForemanRules"): + rule = jobs_pb2.ForemanCondition() + rule.ParseFromString(payload) + result.append(rule) + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def RemoveExpiredForemanRules(self) -> None: + """Removes all expired foreman rules from the database.""" + query = """ + DELETE + FROM ForemanRules@{{FORCE_INDEX=ForemanRulesByExpirationTime}} AS fr + WHERE fr.ExpirationTime < CURRENT_TIMESTAMP() + """ + self.db.ParamExecute(query, {}, txn_tag="RemoveExpiredForemanRules") + diff --git a/grr/server/grr_response_server/databases/spanner_foreman_rules_test.py b/grr/server/grr_response_server/databases/spanner_foreman_rules_test.py new file mode 100644 index 000000000..8c6153fc0 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_foreman_rules_test.py @@ -0,0 +1,23 @@ +from absl.testing import absltest + +from grr_response_server.databases import 
db_foreman_rules_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseForemanRulesTest( + db_foreman_rules_test.DatabaseTestForemanRulesMixin, + spanner_test_lib.TestCase, +): + pass # Test methods are defined in the base mixin class. + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_hunts.py b/grr/server/grr_response_server/databases/spanner_hunts.py new file mode 100644 index 000000000..00817a160 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_hunts.py @@ -0,0 +1,1401 @@ +#!/usr/bin/env python +"""A module with hunt methods of the Spanner database implementation.""" +import base64 +from collections.abc import Callable, Collection, Iterable, Mapping, Sequence, Set +from typing import Optional + +from google.api_core.exceptions import AlreadyExists, NotFound + +from google.cloud import spanner as spanner_lib +from google.cloud.spanner_v1 import param_types + +from google.protobuf import any_pb2 +from grr_response_core.lib import rdfvalue +from grr_response_core.lib.rdfvalues import client as rdf_client +from grr_response_proto import flows_pb2 +from grr_response_proto import hunts_pb2 +from grr_response_proto import jobs_pb2 +from grr_response_proto import output_plugin_pb2 +from grr_response_server.databases import db as abstract_db +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils +from grr_response_server.models import hunts as models_hunts + +def _HuntOutputPluginStateFromRow( + plugin_name: str, + plugin_args: Optional[bytes], + plugin_state: bytes, +) -> output_plugin_pb2.OutputPluginState: + """Creates OutputPluginState from the corresponding table's row data.""" + plugin_args_any = None + if plugin_args is not None: + plugin_args_any = any_pb2.Any() + plugin_args_any.ParseFromString(plugin_args) + + plugin_descriptor = output_plugin_pb2.OutputPluginDescriptor( + plugin_name=plugin_name, args=plugin_args_any + ) + + plugin_state_any = any_pb2.Any() + plugin_state_any.ParseFromString(plugin_state) + # currently AttributedDict is used to store output + # plugins' states. This is suboptimal and unsafe and should be refactored + # towards using per-plugin protos. + attributed_dict = jobs_pb2.AttributedDict() + plugin_state_any.Unpack(attributed_dict) + + return output_plugin_pb2.OutputPluginState( + plugin_descriptor=plugin_descriptor, plugin_state=attributed_dict + ) + + +def _BinsToQuery(bins: list[int], column_name: str) -> str: + """Builds an SQL query part to fetch counts corresponding to given bins.""" + result = [] + # With the current StatsHistogram implementation the last bin simply + # takes all the values that are greater than range_max_value of + # the one-before-the-last bin. range_max_value of the last bin + # is thus effectively ignored. 
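To make the bin-to-SQL mapping implemented just below concrete: the helper yields one COUNT(CASE ...) expression per bin, with the last bin left open-ended. A sketch with arbitrary bin values, assuming this module is importable as grr_response_server.databases.spanner_hunts:

from grr_response_server.databases import spanner_hunts

# Illustrative bins only; real callers pass models_hunts.CPU_STATS_BINS etc.
assert spanner_hunts._BinsToQuery([1, 2, 3], "f.UserCpuTimeUsed") == (
    "COUNT(CASE WHEN f.UserCpuTimeUsed >= 0 AND f.UserCpuTimeUsed < 1 THEN 1 END), "
    "COUNT(CASE WHEN f.UserCpuTimeUsed >= 1 AND f.UserCpuTimeUsed < 2 THEN 1 END), "
    "COUNT(CASE WHEN f.UserCpuTimeUsed >= 2 THEN 1 END)"
)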
+ for prev_b, next_b in zip([0] + bins[:-1], bins[:-1] + [None]): + query = f"COUNT(CASE WHEN {column_name} >= {prev_b}" + if next_b is not None: + query += f" AND {column_name} < {next_b}" + + query += " THEN 1 END)" + + result.append(query) + + return ", ".join(result) + + +_HUNT_FLOW_CONDITION_TO_FLOW_STATE_MAPPING = { + abstract_db.HuntFlowsCondition.FAILED_FLOWS_ONLY: ( + flows_pb2.Flow.FlowState.ERROR, + ), + abstract_db.HuntFlowsCondition.SUCCEEDED_FLOWS_ONLY: ( + flows_pb2.Flow.FlowState.FINISHED, + ), + abstract_db.HuntFlowsCondition.COMPLETED_FLOWS_ONLY: ( + flows_pb2.Flow.FlowState.ERROR, + flows_pb2.Flow.FlowState.FINISHED, + ), + abstract_db.HuntFlowsCondition.FLOWS_IN_PROGRESS_ONLY: ( + flows_pb2.Flow.FlowState.RUNNING, + ), + abstract_db.HuntFlowsCondition.CRASHED_FLOWS_ONLY: ( + flows_pb2.Flow.FlowState.CRASHED, + ), +} + +class HuntsMixin: + """A Spanner database mixin with implementation of flow methods.""" + + db: spanner_utils.Database + _write_rows_batch_size: int + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteHuntObject(self, hunt_obj: hunts_pb2.Hunt): + """Writes a hunt object to the database.""" + row = { + "HuntId": hunt_obj.hunt_id, + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, + "LastUpdateTime": spanner_lib.COMMIT_TIMESTAMP, + "Creator": hunt_obj.creator, + "DurationMicros": hunt_obj.duration * 10**6, + "Description": hunt_obj.description, + "ClientRate": float(hunt_obj.client_rate), + "ClientLimit": hunt_obj.client_limit, + "State": int(hunt_obj.hunt_state), + "StateReason": int(hunt_obj.hunt_state_reason), + "StateComment": hunt_obj.hunt_state_comment, + "InitStartTime": None, + "LastStartTime": None, + "ClientCountAtStartTime": hunt_obj.num_clients_at_start_time, + "Hunt": hunt_obj, + } + + try: + self.db.Insert(table="Hunts", row=row, txn_tag="WriteHuntObject") + except AlreadyExists as error: + raise abstract_db.DuplicatedHuntError( + hunt_id=hunt_obj.hunt_id, cause=error + ) + + + def _UpdateHuntObject( + self, + txn, + hunt_id: str, + duration: Optional[rdfvalue.Duration] = None, + client_rate: Optional[float] = None, + client_limit: Optional[int] = None, + hunt_state: Optional[hunts_pb2.Hunt.HuntState.ValueType] = None, + hunt_state_reason: Optional[ + hunts_pb2.Hunt.HuntStateReason.ValueType + ] = None, + hunt_state_comment: Optional[str] = None, + start_time: Optional[rdfvalue.RDFDatetime] = None, + num_clients_at_start_time: Optional[int] = None, + ): + """Updates the hunt object within a given transaction.""" + params = { + "hunt_id": hunt_id, + } + params_types = { + "hunt_id": param_types.STRING + } + assignments = ["h.LastUpdateTime = PENDING_COMMIT_TIMESTAMP()"] + + if duration is not None: + assignments.append("h.DurationMicros = @duration_micros") + params["duration_micros"] = int(duration.microseconds) + params_types["duration_micros"] = param_types.INT64 + + if client_rate is not None: + assignments.append("h.ClientRate = @client_rate") + params["client_rate"] = float(client_rate) + params_types["client_rate"] = param_types.FLOAT64 + + if client_limit is not None: + assignments.append("h.ClientLimit = @client_limit") + params["client_limit"] = int(client_limit) + params_types["client_limit"] = param_types.INT64 + + if hunt_state is not None: + assignments.append("h.State = @hunt_state") + params["hunt_state"] = int(hunt_state) + params_types["hunt_state"] = param_types.INT64 + + if hunt_state_reason is not None: + assignments.append("h.StateReason = @hunt_state_reason") + params["hunt_state_reason"] = int(hunt_state_reason) + 
params_types["hunt_state_reason"] = param_types.INT64 + + if hunt_state_comment is not None: + assignments.append("h.StateComment = @hunt_state_comment") + params["hunt_state_comment"] = hunt_state_comment + params_types["hunt_state_comment"] = param_types.STRING + + if start_time is not None: + assignments.append( + "h.InitStartTime = IFNULL(h.InitStartTime, @start_time)" + ) + assignments.append("h.LastStartTime = @start_time") + params["start_time"] = start_time.AsDatetime() + params_types["start_time"] = param_types.TIMESTAMP + + if num_clients_at_start_time is not None: + assignments.append( + "h.ClientCountAtStartTime = @client_count_at_start_time" + ) + params["client_count_at_start_time"] = int( + num_clients_at_start_time + ) + params_types["client_count_at_start_time"] = param_types.INT64 + + query = f""" + UPDATE Hunts AS h + SET {", ".join(assignments)} + WHERE h.HuntId = @hunt_id + """ + + txn.execute_update(query, params=params, param_types=params_types, + request_options={"request_tag": "_UpdateHuntObject:Hunts:execute_update"}) + + @db_utils.CallLogged + @db_utils.CallAccounted + def UpdateHuntObject( + self, + hunt_id: str, + duration: Optional[rdfvalue.Duration] = None, + client_rate: Optional[int] = None, + client_limit: Optional[int] = None, + hunt_state: Optional[hunts_pb2.Hunt.HuntState.ValueType] = None, + hunt_state_reason: Optional[ + hunts_pb2.Hunt.HuntStateReason.ValueType + ] = None, + hunt_state_comment: Optional[str] = None, + start_time: Optional[rdfvalue.RDFDatetime] = None, + num_clients_at_start_time: Optional[int] = None, + ): + """Updates the hunt object.""" + + def Txn(txn) -> None: + # Make sure the hunt is there. + try: + keyset = spanner_lib.KeySet(keys=[[hunt_id]]) + txn.read(table="Hunts", keyset=keyset, columns=["HuntId",]).one() + except NotFound as e: + raise abstract_db.UnknownHuntError(hunt_id) from e + + # Then update the hunt. 
+ self._UpdateHuntObject( + txn, + hunt_id, + duration=duration, + client_rate=client_rate, + client_limit=client_limit, + hunt_state=hunt_state, + hunt_state_reason=hunt_state_reason, + hunt_state_comment=hunt_state_comment, + start_time=start_time, + num_clients_at_start_time=num_clients_at_start_time, + ) + + self.db.Transact(Txn, txn_tag="UpdateHuntObject") + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteHuntObject(self, hunt_id: str) -> None: + """Deletes a hunt object with a given id.""" + self.db.Delete( + table="Hunts", key=[hunt_id], txn_tag="DeleteHuntObject" + ) + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntObject(self, hunt_id: str) -> hunts_pb2.Hunt: + """Reads a hunt object from the database.""" + + cols = [ + "Hunt", + "CreationTime", + "LastUpdateTime", + "DurationMicros", + "Description", + "Creator", + "ClientRate", + "ClientLimit", + "State", + "StateReason", + "StateComment", + "InitStartTime", + "LastStartTime", + "ClientCountAtStartTime", + ] + + try: + row = self.db.Read(table="Hunts", + key=[hunt_id], + cols=cols, + txn_tag="ReadHuntObject") + except NotFound as e: + raise abstract_db.UnknownHuntError(hunt_id) from e + + hunt_obj = hunts_pb2.Hunt() + hunt_obj.ParseFromString(row[0]) + hunt_obj.create_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[1]) + ) + hunt_obj.last_update_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[2]) + ) + hunt_obj.duration = rdfvalue.DurationSeconds.From( + row[3], rdfvalue.MICROSECONDS + ).ToInt(rdfvalue.SECONDS) + hunt_obj.description = row[4] + hunt_obj.creator = row[5] + hunt_obj.client_rate = row[6] + hunt_obj.client_limit = row[7] + hunt_obj.hunt_state = row[8] + hunt_obj.hunt_state_reason = row[9] + hunt_obj.hunt_state_comment = row[10] + if row[11] is not None: + hunt_obj.init_start_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[11]) + ) + if row[12] is not None: + hunt_obj.last_start_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[12]) + ) + hunt_obj.num_clients_at_start_time = row[13] + + return hunt_obj + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntObjects( + self, + offset: int, + count: int, + with_creator: Optional[str] = None, + created_after: Optional[rdfvalue.RDFDatetime] = None, + with_description_match: Optional[str] = None, + created_by: Optional[Set[str]] = None, + not_created_by: Optional[Set[str]] = None, + with_states: Optional[ + Iterable[hunts_pb2.Hunt.HuntState.ValueType] + ] = None + ) -> list[hunts_pb2.Hunt]: + """Reads hunt objects from the database.""" + + conditions = [] + params = { + "limit": count, + "offset": offset, + } + param_type = {} + + if with_creator is not None: + conditions.append("h.Creator = {creator}") + params["creator"] = with_creator + + if created_by is not None: + conditions.append("h.Creator IN UNNEST({created_by})") + params["created_by"] = list(created_by) + param_type["created_by"] = param_types.Array(param_types.STRING) + + if not_created_by is not None: + conditions.append("h.Creator NOT IN UNNEST({not_created_by})") + params["not_created_by"] = list(not_created_by) + param_type["not_created_by"] = param_types.Array(param_types.STRING) + + if created_after is not None: + conditions.append("h.CreationTime > {creation_time}") + params["creation_time"] = created_after.AsDatetime() + + if with_description_match is not None: + conditions.append("h.Description LIKE {description}") + params["description"] = f"%{with_description_match}%" + + if with_states is not None: + if not with_states: + return [] + ors = [] + 
for i, state in enumerate(with_states): + ors.append(f"h.State = {{state_{i}}}") + params[f"state_{i}"] = int(state) + conditions.append("(" + " OR ".join(ors) + ")") + + query = """ + SELECT + h.HuntId, + h.CreationTime, + h.LastUpdateTime, + h.Creator, + h.DurationMicros, + h.Description, + h.ClientRate, + h.ClientLimit, + h.State, + h.StateReason, + h.StateComment, + h.InitStartTime, + h.LastStartTime, + h.ClientCountAtStartTime, + h.Hunt + FROM Hunts AS h + """ + + if conditions: + query += " WHERE " + " AND ".join(conditions) + + query += """ + ORDER BY h.CreationTime DESC + LIMIT {limit} OFFSET {offset} + """ + + result = [] + for row in self.db.ParamQuery(query, params, param_type=param_type, txn_tag="ReadHuntObjects"): + hunt_obj = hunts_pb2.Hunt() + hunt_obj.ParseFromString(row[14]) + + hunt_obj.create_time = int(rdfvalue.RDFDatetime.FromDatetime(row[1])) + hunt_obj.last_update_time = int(rdfvalue.RDFDatetime.FromDatetime(row[2])) + hunt_obj.creator = row[3] + hunt_obj.duration = rdfvalue.DurationSeconds.From( + row[4], rdfvalue.MICROSECONDS + ).ToInt(rdfvalue.SECONDS) + + hunt_obj.description = row[5] + hunt_obj.client_rate = row[6] + hunt_obj.client_limit = row[7] + hunt_obj.hunt_state = row[8] + hunt_obj.hunt_state_reason = row[9] + hunt_obj.hunt_state_comment = row[10] + + if row[11] is not None: + hunt_obj.init_start_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[11]) + ) + + if row[12] is not None: + hunt_obj.last_start_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[12]) + ) + + hunt_obj.num_clients_at_start_time = row[13] + + result.append(hunt_obj) + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def ListHuntObjects( + self, + offset: int, + count: int, + with_creator: Optional[str] = None, + created_after: Optional[rdfvalue.RDFDatetime] = None, + with_description_match: Optional[str] = None, + created_by: Optional[Iterable[str]] = None, + not_created_by: Optional[Iterable[str]] = None, + with_states: Optional[ + Iterable[hunts_pb2.Hunt.HuntState.ValueType] + ] = None, + ) -> Iterable[hunts_pb2.HuntMetadata]: + """Reads metadata for hunt objects from the database.""" + + conditions = [] + params = { + "limit": count, + "offset": offset, + } + param_type = {} + + if with_creator is not None: + conditions.append("h.Creator = {creator}") + params["creator"] = with_creator + + if created_by is not None: + conditions.append("h.Creator IN UNNEST({created_by})") + params["created_by"] = list(created_by) + param_type["created_by"] = param_types.Array(param_types.STRING) + + if not_created_by is not None: + conditions.append("h.Creator NOT IN UNNEST({not_created_by})") + params["not_created_by"] = list(not_created_by) + param_type["not_created_by"] = param_types.Array(param_types.STRING) + + if created_after is not None: + conditions.append("h.CreationTime > {creation_time}") + params["creation_time"] = created_after.AsDatetime() + + if with_description_match is not None: + conditions.append("h.Description LIKE {description}") + params["description"] = f"%{with_description_match}%" + + if with_states is not None: + if not with_states: + return [] + ors = [] + for i, state in enumerate(with_states): + ors.append(f"h.State = {{state_{i}}}") + params[f"state_{i}"] = int(state) + conditions.append("(" + " OR ".join(ors) + ")") + + query = """ + SELECT + h.HuntId, + h.CreationTime, + h.LastUpdateTime, + h.Creator, + h.DurationMicros, + h.Description, + h.ClientRate, + h.ClientLimit, + h.State, + h.StateComment, + h.InitStartTime, + h.LastStartTime, + FROM 
Hunts AS h + """ + + if conditions: + query += " WHERE " + " AND ".join(conditions) + + query += """ + ORDER BY h.CreationTime DESC + LIMIT {limit} OFFSET {offset} + """ + + result = [] + for row in self.db.ParamQuery(query, params, + param_type=param_type, txn_tag="ListHuntObjects"): + hunt_mdata = hunts_pb2.HuntMetadata() + hunt_mdata.hunt_id = row[0] + + hunt_mdata.create_time = int(rdfvalue.RDFDatetime.FromDatetime(row[1])) + hunt_mdata.last_update_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[2]) + ) + hunt_mdata.creator = row[3] + hunt_mdata.duration = int( + rdfvalue.Duration.From(row[4], rdfvalue.MICROSECONDS).ToInt( + rdfvalue.SECONDS + ) + ) + if row[5]: + hunt_mdata.description = row[5] + hunt_mdata.client_rate = row[6] + hunt_mdata.client_limit = row[7] + hunt_mdata.hunt_state = row[8] + if row[9]: + hunt_mdata.hunt_state_comment = row[9] + + if row[10] is not None: + hunt_mdata.init_start_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[10]) + ) + + if row[11] is not None: + hunt_mdata.last_start_time = int( + rdfvalue.RDFDatetime.FromDatetime(row[11]) + ) + + result.append(hunt_mdata) + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntResults( + self, + hunt_id: str, + offset: int, + count: int, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + with_substring: Optional[str] = None, + with_timestamp: Optional[rdfvalue.RDFDatetime] = None, + ) -> Iterable[flows_pb2.FlowResult]: + """Reads hunt results of a given hunt using given query options.""" + params = { + "hunt_id": hunt_id, + "offset": offset, + "count": count, + } + param_type = {} + + query = """ + SELECT t.Payload, t.CreationTime, t.Tag, t.ClientId, t.FlowId + FROM FlowResults AS t + WHERE t.HuntId = {hunt_id} + AND t.FlowId = {hunt_id} + """ + + if with_tag is not None: + query += " AND t.Tag = {tag} " + params["tag"] = with_tag + param_type["tag"] = param_types.STRING + + if with_type is not None: + query += " AND t.RdfType = {type}" + params["type"] = with_type + param_type["type"] = param_types.STRING + + if with_substring is not None: + query += ( + " AND STRPOS(SAFE_CONVERT_BYTES_TO_STRING(t.Payload.value), " + "{substring}) != 0" + ) + params["substring"] = with_substring + param_type["substring"] = param_types.STRING + + if with_timestamp is not None: + query += " AND t.CreationTime = {creation_time}" + params["creation_time"] = with_timestamp.AsDatetime() + param_type["creation_time"] = param_types.TIMESTAMP + + query += """ + ORDER BY t.CreationTime ASC LIMIT {count} OFFSET {offset} + """ + + results = [] + for ( + payload_bytes, + creation_time, + tag, + client_id, + flow_id, + ) in self.db.ParamQuery(query, params, param_type=param_type, + txn_tag="ReadHuntResults"): + result = flows_pb2.FlowResult() + result.hunt_id = hunt_id + result.client_id = client_id + result.flow_id = flow_id + result.timestamp = int(rdfvalue.RDFDatetime.FromDatetime(creation_time)) + result.payload.ParseFromString(payload_bytes) + + if tag is not None: + result.tag = tag + + results.append(result) + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountHuntResults( + self, + hunt_id: str, + with_tag: Optional[str] = None, + with_type: Optional[str] = None, + ) -> int: + """Counts hunt results of a given hunt using given query options.""" + params = { + "hunt_id": hunt_id, + } + param_type = {} + + query = """ + SELECT COUNT(*) + FROM FlowResults AS t + WHERE t.HuntId = {hunt_id} + """ + + if with_tag is not None: + query += " AND t.Tag = {tag} " 
+ params["tag"] = with_tag + param_type["tag"] = param_types.STRING + + if with_type is not None: + query += " AND t.RdfType = {type}" + params["type"] = with_type + param_type["type"] = param_types.STRING + + (count,) = self.db.ParamQuerySingle( + query, params, + param_type=param_type, + txn_tag="CountHuntResults" + ) + return count + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountHuntResultsByType(self, hunt_id: str) -> Mapping[str, int]: + """Returns counts of items in hunt results grouped by type.""" + params = { + "hunt_id": hunt_id, + } + + query = """ + SELECT t.RdfType, COUNT(*) + FROM FlowResults AS t + WHERE t.HuntId = {hunt_id} + GROUP BY t.RdfType + """ + + result = {} + for type_name, count in self.db.ParamQuery( + query, params, txn_tag="CountHuntResultsByType" + ): + result[type_name] = count + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntLogEntries( + self, + hunt_id: str, + offset: int, + count: int, + with_substring: Optional[str] = None, + ) -> Sequence[flows_pb2.FlowLogEntry]: + """Reads hunt log entries of a given hunt using given query options.""" + params = { + "hunt_id": hunt_id, + "offset": offset, + "count": count, + } + + query = """ + SELECT l.ClientId, + l.FlowId, + l.CreationTime, + l.Message + FROM FlowLogEntries AS l + WHERE l.HuntId = {hunt_id} + AND l.FlowId = {hunt_id} + """ + + if with_substring is not None: + query += " AND STRPOS(l.Message, {substring}) != 0" + params["substring"] = with_substring + + query += """ + LIMIT {count} + OFFSET {offset} + """ + params["offset"] = offset + params["count"] = count + + results = [] + for row in self.db.ParamQuery(query, params, txn_tag="ReadHuntLogEntries"): + client_id, flow_id, creation_time, message = row + + result = flows_pb2.FlowLogEntry() + result.hunt_id = hunt_id + result.client_id = client_id + result.flow_id = flow_id + result.timestamp = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + result.message = message + + results.append(result) + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountHuntLogEntries(self, hunt_id: str) -> int: + """Returns number of hunt log entries of a given hunt.""" + params = { + "hunt_id": hunt_id, + } + + query = """ + SELECT COUNT(*) + FROM FlowLogEntries AS l + WHERE l.HuntId = {hunt_id} + AND l.FlowId = {hunt_id} + """ + + (count,) = self.db.ParamQuerySingle( + query, params, txn_tag="CountHuntLogEntries" + ) + return count + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntOutputPluginLogEntries( + self, + hunt_id: str, + output_plugin_id: str, + offset: int, + count: int, + with_type: Optional[str] = None, + ) -> Sequence[flows_pb2.FlowOutputPluginLogEntry]: + """Reads hunt output plugin log entries.""" + + query = """ + SELECT l.ClientId, + l.FlowId, + l.CreationTime, + l.Type, + l.Message + FROM FlowOutputPluginLogEntries AS l + WHERE l.HuntId = {hunt_id} + AND l.OutputPluginId = {output_plugin_id} + """ + params = { + "hunt_id": hunt_id, + "output_plugin_id": output_plugin_id, + } + param_type = {} + + if with_type is not None: + query += " AND l.Type = {type}" + params["type"] = int(with_type) + param_type["type"] = param_types.INT64 + + query += """ + LIMIT {count} + OFFSET {offset} + """ + params["offset"] = offset + params["count"] = count + + results = [] + for row in self.db.ParamQuery( + query, params, param_type=param_type, + txn_tag="ReadHuntOutputPluginLogEntries" + ): + client_id, flow_id, creation_time, int_type, message = row + 
+ result = flows_pb2.FlowOutputPluginLogEntry() + result.hunt_id = hunt_id + result.client_id = client_id + result.flow_id = flow_id + result.output_plugin_id = output_plugin_id + result.timestamp = rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch() + result.log_entry_type = int_type + result.message = message + + results.append(result) + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountHuntOutputPluginLogEntries( + self, + hunt_id: str, + output_plugin_id: str, + with_type: Optional[str] = None, + ) -> int: + """Returns number of hunt output plugin log entries of a given hunt.""" + query = """ + SELECT COUNT(*) + FROM FlowOutputPluginLogEntries AS l + WHERE l.HuntId = {hunt_id} + AND l.OutputPluginId = {output_plugin_id} + """ + params = { + "hunt_id": hunt_id, + "output_plugin_id": output_plugin_id, + } + + if with_type is not None: + query += " AND l.Type = {type}" + params["type"] = int(with_type) + + (count,) = self.db.ParamQuerySingle( + query, params, txn_tag="CountHuntOutputPluginLogEntries" + ) + return count + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntOutputPluginsStates( + self, hunt_id: str + ) -> list[output_plugin_pb2.OutputPluginState]: + """Reads all hunt output plugins states of a given hunt.""" + # Make sure the hunt is there. + try: + self.db.Read(table="Hunts", + key=[hunt_id,], + cols=("HuntId",), + txn_tag="ReadHuntOutputPluginsStates") + except NotFound as e: + raise abstract_db.UnknownHuntError(hunt_id) from e + + params = { + "hunt_id": hunt_id, + } + + query = """ + SELECT s.Name, + s.Args, + s.State + FROM HuntOutputPlugins AS s + WHERE s.HuntId = {hunt_id} + """ + + results = [] + for row in self.db.ParamQuery( + query, params, txn_tag="ReadHuntOutputPluginsStates" + ): + name, args, state = row + results.append(_HuntOutputPluginStateFromRow(name, args, state)) + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteHuntOutputPluginsStates( + self, + hunt_id: str, + states: Collection[output_plugin_pb2.OutputPluginState], + ) -> None: + """Writes hunt output plugin states for a given hunt.""" + + def Mutation(mut) -> None: + for index, state in enumerate(states): + state_any = any_pb2.Any() + state_any.Pack(state.plugin_state) + columns = ["HuntId", "OutputPluginId", + "Name", + "State" + ] + row = [hunt_id, index, + state.plugin_descriptor.plugin_name, + base64.b64encode(state_any.SerializeToString()), + ] + + if state.plugin_descriptor.HasField("args"): + columns.append("Args") + row.append(base64.b64encode(state.plugin_descriptor.args.SerializeToString())) + + mut.insert_or_update(table="HuntOutputPlugins", columns=columns, values=[row]) + + try: + self.db.Mutate(Mutation, txn_tag="WriteHuntOutputPluginsStates") + except NotFound as e: + raise abstract_db.UnknownHuntError(hunt_id) from e + + @db_utils.CallLogged + @db_utils.CallAccounted + def UpdateHuntOutputPluginState( + self, + hunt_id: str, + state_index: int, + update_fn: Callable[[jobs_pb2.AttributedDict], jobs_pb2.AttributedDict], + ) -> None: + """Updates hunt output plugin state for a given output plugin.""" + + def Txn(txn) -> None: + row = txn.read( + table="HuntOutputPlugins", + keyset=spanner_lib.KeySet(keys=[[hunt_id, state_index]]), + columns=["Name", "Args", "State"], + request_options={"request_tag": "UpdateHuntOutputPluginState:HuntOutputPlugins:read"} + ).one() + state = _HuntOutputPluginStateFromRow( + row[0], row[1], row[2] + ) + + modified_plugin_state = 
update_fn(state.plugin_state) + modified_plugin_state_any = any_pb2.Any() + modified_plugin_state_any.Pack(modified_plugin_state) + columns = ["HuntId", "OutputPluginId", "State"] + row = [hunt_id,state_index, + base64.b64encode(modified_plugin_state_any.SerializeToString())] + txn.update("HuntOutputPlugins", columns=columns, values=[row]) + + try: + self.db.Transact(Txn, txn_tag="UpdateHuntOutputPluginState") + except NotFound as e: + raise abstract_db.UnknownHuntError(hunt_id) from e + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntFlows( # pytype: disable=annotation-type-mismatch + self, + hunt_id: str, + offset: int, + count: int, + filter_condition: abstract_db.HuntFlowsCondition = abstract_db.HuntFlowsCondition.UNSET, + ) -> Sequence[flows_pb2.Flow]: + """Reads hunt flows matching given conditions.""" + + params = { + "hunt_id": hunt_id, + } + + query = """ + SELECT + f.ClientId, + f.Creator, + f.Name, + f.State, + f.CreationTime, + f.UpdateTime, + f.Crash, + f.ProcessingWorker, + f.ProcessingStartTime, + f.ProcessingEndTime, + f.NextRequestToProcess, + f.Flow + FROM Flows AS f + WHERE f.ParentHuntId = {hunt_id} + AND f.FlowId = {hunt_id} + """ + + if filter_condition != abstract_db.HuntFlowsCondition.UNSET: + states = _HUNT_FLOW_CONDITION_TO_FLOW_STATE_MAPPING[filter_condition] # pytype: disable=unsupported-operands + conditions = (f"f.State = {{state_{i}}}" for i in range(len(states))) + query += "AND (" + " OR ".join(conditions) + ")" + for i, state in enumerate(states): + params[f"state_{i}"] = int(state) + + query += """ + ORDER BY f.UpdateTime ASC LIMIT {count} OFFSET {offset} + """ + params["offset"] = offset + params["count"] = count + + results = [] + for row in self.db.ParamQuery(query, params, txn_tag="ReadHuntFlows"): + ( + client_id, + creator, + name, + state, + creation_time, + update_time, + crash, + processing_worker, + processing_start_time, + processing_end_time, + next_request_to_process, + flow_payload_bytes, + ) = row + + result = flows_pb2.Flow() + result.ParseFromString(flow_payload_bytes) + result.client_id = client_id + result.parent_hunt_id = hunt_id + result.creator = creator + result.flow_class_name = name + result.flow_state = state + result.next_request_to_process = int(next_request_to_process) + result.create_time = int(rdfvalue.RDFDatetime.FromDatetime(creation_time)) + result.last_update_time = int( + rdfvalue.RDFDatetime.FromDatetime(update_time) + ) + + if crash is not None: + client_crash = jobs_pb2.ClientCrash() + client_crash.ParseFromString(crash) + result.client_crash_info.CopyFrom(client_crash) + + if processing_worker: + result.processing_on = processing_worker + + if processing_start_time is not None: + result.processing_since = int( + rdfvalue.RDFDatetime.FromDatetime(processing_start_time) + ) + if processing_end_time is not None: + result.processing_deadline = int( + rdfvalue.RDFDatetime.FromDatetime(processing_end_time) + ) + + results.append(result) + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntFlowErrors( + self, + hunt_id: str, + offset: int, + count: int, + ) -> Mapping[str, abstract_db.FlowErrorInfo]: + """Returns errors for flows of the given hunt.""" + results = {} + + query = """ + SELECT f.ClientId, + f.UpdateTime, + f.Flow.error_message, + f.Flow.backtrace, + FROM Flows@{{FORCE_INDEX=FlowsByParentHuntIdFlowIdState}} AS f + WHERE f.ParentHuntId = {hunt_id} + AND f.FlowId = {hunt_id} + AND f.State = 'ERROR' + ORDER BY f.UpdateTime ASC LIMIT {count} + OFFSET {offset} + """ + + 
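For the filter_condition handling in ReadHuntFlows above, each HuntFlowsCondition expands into one or more State predicates through _HUNT_FLOW_CONDITION_TO_FLOW_STATE_MAPPING; COMPLETED_FLOWS_ONLY, for instance, adds AND (f.State = {state_0} OR f.State = {state_1}) with the ERROR and FINISHED states bound as state_0 and state_1. A usage sketch, where d stands for a database instance built on these mixins and hunt_id for an existing hunt:

failed = d.ReadHuntFlows(
    hunt_id, offset=0, count=100,
    filter_condition=abstract_db.HuntFlowsCondition.FAILED_FLOWS_ONLY)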
params = { + "hunt_id": hunt_id, + "offset": offset, + "count": count, + } + + for row in self.db.ParamQuery(query, params, txn_tag="ReadHuntFlowErrors"): + (client_id, time, message, backtrace) = row + + info = abstract_db.FlowErrorInfo( + message=message, + time=rdfvalue.RDFDatetime.FromDate(time), + ) + if backtrace: + info.backtrace = backtrace + + results[client_id] = info + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountHuntFlows( # pytype: disable=annotation-type-mismatch + self, + hunt_id: str, + filter_condition: Optional[ + abstract_db.HuntFlowsCondition + ] = abstract_db.HuntFlowsCondition.UNSET, + ) -> int: + """Counts hunt flows matching given conditions.""" + + params = { + "hunt_id": hunt_id, + } + + query = """ + SELECT COUNT(*) + FROM Flows@{{FORCE_INDEX=FlowsByParentHuntIdFlowIdState}} AS f + WHERE f.ParentHuntId = {hunt_id} + AND f.FlowId = {hunt_id} + """ + + if filter_condition != abstract_db.HuntFlowsCondition.UNSET: + states = _HUNT_FLOW_CONDITION_TO_FLOW_STATE_MAPPING[filter_condition] # pytype: disable=unsupported-operands + query += f""" AND({ + " OR ".join(f"f.State = {{state_{i}}}" for i in range(len(states))) + })""" + for i, state in enumerate(states): + params[f"state_{i}"] = int(state) + + (count,) = self.db.ParamQuerySingle(query, params, txn_tag="CountHuntFlows") + return count + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntFlowsStatesAndTimestamps( + self, + hunt_id: str, + ) -> Sequence[abstract_db.FlowStateAndTimestamps]: + """Reads hunt flows states and timestamps.""" + + params = { + "hunt_id": hunt_id, + } + + query = """ + SELECT f.State, f.CreationTime, f.UpdateTime + FROM Flows AS f + WHERE f.ParentHuntId = {hunt_id} + AND f.FlowId = {hunt_id} + """ + + results = [] + for row in self.db.ParamQuery( + query, params, txn_tag="ReadHuntFlowsStatesAndTimestamps" + ): + int_state, creation_time, update_time = row + + results.append( + abstract_db.FlowStateAndTimestamps( + flow_state=int_state, + create_time=rdfvalue.RDFDatetime.FromDatetime(creation_time), + last_update_time=rdfvalue.RDFDatetime.FromDatetime(update_time), + ) + ) + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntsCounters( + self, + hunt_ids: Collection[str], + ) -> Mapping[str, abstract_db.HuntCounters]: + """Reads hunt counters for several of hunt ids.""" + + params = { + "hunt_ids": hunt_ids, + } + param_type = { + "hunt_ids": param_types.Array(param_types.STRING) + } + + states_query = """ + SELECT f.ParentHuntID, f.State, COUNT(*) + FROM Flows@{{FORCE_INDEX=FlowsByParentHuntIdFlowIdState}} AS f + WHERE f.ParentHuntID IN UNNEST({hunt_ids}) + AND f.FlowId IN UNNEST({hunt_ids}) + AND f.FlowId = f.ParentHuntID + GROUP BY f.ParentHuntID, f.State + """ + counts_by_state_per_hunt = dict.fromkeys(hunt_ids, {}) + for hunt_id, state, count in self.db.ParamQuery( + states_query, params, param_type=param_type, txn_tag="ReadHuntCounters_1" + ): + counts_by_state_per_hunt[hunt_id][state] = count + + hunt_counters = dict.fromkeys( + hunt_ids, + abstract_db.HuntCounters( + num_clients=0, + num_successful_clients=0, + num_failed_clients=0, + num_clients_with_results=0, + num_crashed_clients=0, + num_running_clients=0, + num_results=0, + total_cpu_seconds=0, + total_network_bytes_sent=0, + ), + ) + + resources_results_query = """ + SELECT + f.ParentHuntID, + SUM(f.UserCpuTimeUsed + f.SystemCpuTimeUsed), + SUM(f.NetworkBytesSent), + SUM(f.ReplyCount), + COUNT(IF(f.ReplyCount > 0, f.ClientId, NULL)) + FROM 
Flows@{{FORCE_INDEX=FlowsByParentHuntIdFlowIdState}} AS f + WHERE f.ParentHuntID IN UNNEST({hunt_ids}) + AND f.FlowId IN UNNEST({hunt_ids}) + AND f.ParentHuntID = f.FlowID + GROUP BY f.ParentHuntID + """ + for ( + hunt_id, + total_cpu_seconds, + total_network_bytes_sent, + num_results, + num_clients_with_results, + ) in self.db.ParamQuery( + resources_results_query, params, + param_type=param_type, txn_tag="ReadHuntCounters_2" + ): + counts_by_state = counts_by_state_per_hunt[hunt_id] + num_successful_clients = counts_by_state.get( + flows_pb2.Flow.FlowState.FINISHED, 0 + ) + num_failed_clients = counts_by_state.get( + flows_pb2.Flow.FlowState.ERROR, 0 + ) + num_crashed_clients = counts_by_state.get( + flows_pb2.Flow.FlowState.CRASHED, 0 + ) + num_running_clients = counts_by_state.get( + flows_pb2.Flow.FlowState.RUNNING, 0 + ) + num_clients = sum(counts_by_state.values()) + + hunt_counters[hunt_id] = abstract_db.HuntCounters( + num_clients=num_clients, + num_successful_clients=num_successful_clients, + num_failed_clients=num_failed_clients, + num_clients_with_results=num_clients_with_results, + num_crashed_clients=num_crashed_clients, + num_running_clients=num_running_clients, + # Spanner's SUM on no elements returns NULL - accounting for + # this here. + num_results=num_results or 0, + total_cpu_seconds=total_cpu_seconds or 0, + total_network_bytes_sent=total_network_bytes_sent or 0, + ) + return hunt_counters + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadHuntClientResourcesStats( + self, + hunt_id: str, + ) -> jobs_pb2.ClientResourcesStats: + """Read hunt client resources stats.""" + + params = { + "hunt_id": hunt_id, + } + + # For some reason Spanner SQL doesn't have STDDEV_POP aggregate + # function. Thus, we have to effectively reimplement it ourselves + # using other aggregate functions. 
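Concretely, the expressions below compute the population standard deviation through the identity stddev(X) = sqrt(E[X^2] - E[X]^2), i.e. SQRT(SUM(x*x) / COUNT(*) - POW(AVG(x), 2)). A quick pure-Python check of the same identity:

import math
import statistics

xs = [1.0, 2.0, 4.0, 8.0]
n = len(xs)
stddev = math.sqrt(sum(x * x for x in xs) / n - (sum(xs) / n) ** 2)
assert abs(stddev - statistics.pstdev(xs)) < 1e-9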
+ # For the reference, see: + # http://google3/ops/security/grr/core/grr_response_core/lib/rdfvalues/stats.py;l=121;rcl=379683316 + query = """ + SELECT + COUNT(*), + SUM(f.UserCpuTimeUsed), + SQRT(SUM(POW(f.UserCpuTimeUsed, 2)) / COUNT(*) - POW(AVG(f.UserCpuTimeUsed), 2)), + SUM(f.SystemCpuTimeUsed), + SQRT(SUM(POW(f.SystemCpuTimeUsed, 2)) / COUNT(*) - POW(AVG(f.SystemCpuTimeUsed), 2)), + SUM(f.NetworkBytesSent), + SQRT(SUM(POW(f.NetworkBytesSent, 2)) / COUNT(*) - POW(AVG(f.NetworkBytesSent), 2)), + """ + + query += ", ".join([ + _BinsToQuery(models_hunts.CPU_STATS_BINS, "f.UserCpuTimeUsed"), + _BinsToQuery(models_hunts.CPU_STATS_BINS, "f.SystemCpuTimeUsed"), + _BinsToQuery(models_hunts.NETWORK_STATS_BINS, "f.NetworkBytesSent"), + ]) + + query += """ + FROM Flows@{{FORCE_INDEX=FlowsByParentHuntIdFlowIdState}} AS f + WHERE f.ParentHuntID = {hunt_id} AND f.FlowId = {hunt_id} + """ + + response = self.db.ParamQuerySingle( + query, params, txn_tag="ReadHuntClientResourcesStats_1" + ) + + ( + count, + user_sum, + user_stddev, + system_sum, + system_stddev, + network_sum, + network_stddev, + ) = response[:7] + + stats = jobs_pb2.ClientResourcesStats( + user_cpu_stats=jobs_pb2.RunningStats( + num=count, + sum=user_sum, + stddev=user_stddev, + ), + system_cpu_stats=jobs_pb2.RunningStats( + num=count, + sum=system_sum, + stddev=system_stddev, + ), + network_bytes_sent_stats=jobs_pb2.RunningStats( + num=count, + sum=network_sum, + stddev=network_stddev, + ), + ) + + offset = 7 + user_cpu_stats_histogram = jobs_pb2.StatsHistogram() + for b_num, b_max_value in zip( + response[offset:], models_hunts.CPU_STATS_BINS + ): + user_cpu_stats_histogram.bins.append( + jobs_pb2.StatsHistogramBin(range_max_value=b_max_value, num=b_num) + ) + stats.user_cpu_stats.histogram.CopyFrom(user_cpu_stats_histogram) + + offset += len(models_hunts.CPU_STATS_BINS) + system_cpu_stats_histogram = jobs_pb2.StatsHistogram() + for b_num, b_max_value in zip( + response[offset:], models_hunts.CPU_STATS_BINS + ): + system_cpu_stats_histogram.bins.append( + jobs_pb2.StatsHistogramBin(range_max_value=b_max_value, num=b_num) + ) + stats.system_cpu_stats.histogram.CopyFrom(system_cpu_stats_histogram) + + offset += len(models_hunts.CPU_STATS_BINS) + network_bytes_histogram = jobs_pb2.StatsHistogram() + for b_num, b_max_value in zip( + response[offset:], models_hunts.NETWORK_STATS_BINS + ): + network_bytes_histogram.bins.append( + jobs_pb2.StatsHistogramBin(range_max_value=b_max_value, num=b_num) + ) + stats.network_bytes_sent_stats.histogram.CopyFrom(network_bytes_histogram) + + clients_query = """ + SELECT + f.ClientID, + f.FlowID, + f.UserCPUTimeUsed, + f.SystemCPUTimeUsed, + f.NetworkBytesSent + FROM Flows@{{FORCE_INDEX=FlowsByParentHuntIdFlowIdState}} AS f + WHERE + f.ParentHuntID = {hunt_id} AND + f.FlowId = {hunt_id} AND + (f.UserCpuTimeUsed > 0 OR + f.SystemCpuTimeUsed > 0 OR + f.NetworkBytesSent > 0) + ORDER BY (f.UserCPUTimeUsed + f.SystemCPUTimeUsed) DESC + LIMIT 10 + """ + + responses = self.db.ParamQuery( + clients_query, params, txn_tag="ReadHuntClientResourcesStats_2" + ) + for cid, fid, ucpu, scpu, nbs in responses: + client_id = cid + flow_id = fid + stats.worst_performers.append( + jobs_pb2.ClientResources( + client_id=str(rdf_client.ClientURN.FromHumanReadable(client_id)), + session_id=str(rdfvalue.RDFURN(client_id).Add(flow_id)), + cpu_usage=jobs_pb2.CpuSeconds( + user_cpu_time=ucpu, + system_cpu_time=scpu, + ), + network_bytes_sent=nbs, + ) + ) + + return stats diff --git 
a/grr/server/grr_response_server/databases/spanner_hunts_test.py b/grr/server/grr_response_server/databases/spanner_hunts_test.py new file mode 100644 index 000000000..fc7a56afa --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_hunts_test.py @@ -0,0 +1,25 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_hunts_test +from grr_response_server.databases import db_test_utils +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseHuntsTest( + db_hunts_test.DatabaseTestHuntMixin, + db_test_utils.QueryTestHelpersMixin, + spanner_test_lib.TestCase, +): + pass + + +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/databases/spanner_message_handler_test.py b/grr/server/grr_response_server/databases/spanner_message_handler_test.py new file mode 100644 index 000000000..636b39f15 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_message_handler_test.py @@ -0,0 +1,23 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_message_handler_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseHandlerTest( + db_message_handler_test.DatabaseTestHandlerMixin, + spanner_test_lib.TestCase +): + pass # Test methods are defined in the base mixin class. + + +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/databases/spanner_paths.py b/grr/server/grr_response_server/databases/spanner_paths.py new file mode 100644 index 000000000..c821fb021 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_paths.py @@ -0,0 +1,635 @@ +#!/usr/bin/env python +"""A module with path methods of the Spanner database implementation.""" +import base64 +from collections.abc import Collection, Iterable, Sequence +from typing import Optional + +from google.api_core.exceptions import NotFound +from google.cloud import spanner as spanner_lib + +from grr_response_core.lib import rdfvalue +from grr_response_proto import objects_pb2 +from grr_response_server.databases import db as abstract_db +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils +from grr_response_server.models import paths as models_paths + +class PathsMixin: + """A Spanner database mixin with implementation of path methods.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def WritePathInfos( + self, + client_id: str, + path_infos: Iterable[objects_pb2.PathInfo], + ) -> None: + """Writes a collection of path records for a client.""" + # Special case for empty list of paths because Spanner does not like empty + # mutations. We still have to validate the client id. 
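Every table in this module keys file system paths by the canonical byte encoding produced by EncodePathComponents, defined at the end of the file: the components are joined with the path separator, prefixed with it, and base64-encoded. A minimal round-trip sketch, assuming _PATH_SEP is "/":

import base64

components = ("usr", "share", "man")
encoded = base64.b64encode(("/" + "/".join(components)).encode("utf-8"))
assert encoded == b"L3Vzci9zaGFyZS9tYW4="  # base64 of b"/usr/share/man"
assert base64.b64decode(encoded).decode("utf-8") == "/usr/share/man"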
+ if not path_infos: + try: + self.db.Read(table="Clients", + key=[client_id], + cols=(["ClientId"]), + txn_tag="WritePathInfos") + except NotFound as error: + raise abstract_db.UnknownClientError(client_id) from error + return + + def Mutation(mut) -> None: + ancestors = set() + file_hash_columns = ["ClientId", "Type", "Path", "CreationTime", "FileHash"] + file_stat_columns = ["ClientId", "Type", "Path", "CreationTime", "Stat"] + for path_info in path_infos: + path_columns = ["ClientId", "Type", "Path", "CreationTime", "IsDir", "Depth"] + int_path_type = int(path_info.path_type) + path = EncodePathComponents(path_info.components) + + path_row = [client_id, int_path_type, path, + spanner_lib.COMMIT_TIMESTAMP, + path_info.directory, + len(path_info.components) + ] + + if path_info.HasField("stat_entry"): + path_columns.append("LastFileStatTime") + path_row.append(spanner_lib.COMMIT_TIMESTAMP) + + file_stat_row = [client_id, int_path_type, path, + spanner_lib.COMMIT_TIMESTAMP, + path_info.stat_entry + ] + else: + file_stat_row = None + + if path_info.HasField("hash_entry"): + path_columns.append("LastFileHashTime") + path_row.append(spanner_lib.COMMIT_TIMESTAMP) + + file_hash_row = [client_id, int_path_type, path, + spanner_lib.COMMIT_TIMESTAMP, + path_info.hash_entry, + ] + else: + file_hash_row = None + + mut.insert_or_update(table="Paths", columns=path_columns, values=[path_row]) + if file_stat_row is not None: + mut.insert(table="PathFileStats", columns=file_stat_columns, values=[file_stat_row]) + if file_hash_row is not None: + mut.insert(table="PathFileHashes", columns=file_hash_columns, values=[file_hash_row]) + + for path_info_ancestor in models_paths.GetAncestorPathInfos(path_info): + components = tuple(path_info_ancestor.components) + ancestors.add((path_info.path_type, components)) + + path_columns = ["ClientId", "Type", "Path", "CreationTime", "IsDir", "Depth"] + for path_type, components in ancestors: + path_row = [client_id, int(path_type), EncodePathComponents(components), + spanner_lib.COMMIT_TIMESTAMP, True, len(components), + ] + mut.insert_or_update(table="Paths", columns=path_columns, values=[path_row]) + + try: + self.db.Mutate(Mutation, txn_tag="WritePathInfos") + except NotFound as error: + if "Parent row for row [" in str(error): + raise abstract_db.UnknownClientError(client_id) from error + else: + raise + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadPathInfos( + self, + client_id: str, + path_type: objects_pb2.PathInfo.PathType, + components_list: Collection[Sequence[str]], + ) -> dict[tuple[str, ...], Optional[objects_pb2.PathInfo]]: + """Retrieves path info records for given paths.""" + # Early return to avoid issues with unnesting empty path list in the query. 
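WritePathInfos above also writes an IsDir=True row for every ancestor of each written path, so that directory listings can be answered from the Paths table alone. A small pure-Python sketch of the ancestor prefixes involved, assuming models_paths.GetAncestorPathInfos yields exactly the proper prefixes of the path:

components = ("usr", "share", "man")
ancestors = [components[:i] for i in range(1, len(components))]
assert ancestors == [("usr",), ("usr", "share")]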
+ if not components_list: + return {} + + results = {tuple(components): None for components in components_list} + + query = """ + SELECT p.Path, p.CreationTime, p.IsDir, + ps.CreationTime, ps.Stat, + ph.CreationTime, ph.FileHash + FROM Paths AS p + LEFT JOIN PathFileStats AS ps + ON p.ClientId = ps.ClientId + AND p.Type = ps.Type + AND p.Path = ps.Path + AND p.LastFileStatTime = ps.CreationTime + LEFT JOIN PathFileHashes AS ph + ON p.ClientId = ph.ClientId + AND p.Type = ph.Type + AND p.Path = ph.Path + AND p.LastFileHashTime = ph.CreationTime + WHERE p.ClientId = {client_id} + AND p.Type = {type} + AND p.Path IN UNNEST({paths}) + """ + params = { + "client_id": client_id, + "type": int(path_type), + "paths": list(map(EncodePathComponents, components_list)), + } + + for row in self.db.ParamQuery(query, params, txn_tag="ReadPathInfos"): + path, creation_time, is_dir, *row = row + stat_creation_time, stat_bytes, *row = row + hash_creation_time, hash_bytes, *row = row + () = row + + path_info = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.Name(int(path_type)), + components=DecodePathComponents(path), + timestamp=rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch(), + directory=is_dir, + ) + + if stat_bytes is not None: + path_info.stat_entry.ParseFromString(stat_bytes) + path_info.last_stat_entry_timestamp = rdfvalue.RDFDatetime.FromDatetime( + stat_creation_time + ).AsMicrosecondsSinceEpoch() + + if hash_bytes is not None: + path_info.hash_entry.ParseFromString(hash_bytes) + path_info.last_hash_entry_timestamp = rdfvalue.RDFDatetime.FromDatetime( + hash_creation_time + ).AsMicrosecondsSinceEpoch() + + results[tuple(path_info.components)] = path_info + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadPathInfo( + self, + client_id: str, + path_type: objects_pb2.PathInfo.PathType, + components: Sequence[str], + timestamp: Optional[rdfvalue.RDFDatetime] = None, + ) -> objects_pb2.PathInfo: + """Retrieves a path info record for a given path.""" + query = """ + SELECT p.CreationTime, + p.IsDir, + -- File stat information. + p.LastFileStatTime, + (SELECT s.Stat + FROM PathFileStats AS s + WHERE s.ClientId = {client_id} + AND s.Type = {type} + AND s.Path = {path} + AND s.CreationTime <= {timestamp} + ORDER BY s.CreationTime DESC + LIMIT 1), + -- File hash information. 
+ p.LastFileHashTime, + (SELECT h.FileHash + FROM PathFileHashes AS h + WHERE h.ClientId = {client_id} + AND h.Type = {type} + AND h.Path = {path} + AND h.CreationTime <= {timestamp} + ORDER BY h.CreationTime DESC + LIMIT 1) + FROM Paths AS p + WHERE p.ClientId = {client_id} + AND p.Type = {type} + AND p.Path = {path} + """ + params = { + "client_id": client_id, + "type": int(path_type), + "path": EncodePathComponents(components), + } + + if timestamp is not None: + params["timestamp"] = timestamp.AsDatetime() + else: + params["timestamp"] = rdfvalue.RDFDatetime.Now().AsDatetime() + + try: + row = self.db.ParamQuerySingle(query, params, txn_tag="ReadPathInfo") + except NotFound: + raise abstract_db.UnknownPathError(client_id, path_type, components) # pylint: disable=raise-missing-from + + creation_time, is_dir, *row = row + last_file_stat_time, stat_bytes, *row = row + last_file_hash_time, hash_bytes, *row = row + () = row + + result = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.Name(int(path_type)), + components=components, + directory=is_dir, + timestamp=rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch(), + ) + + if last_file_stat_time is not None: + result.last_stat_entry_timestamp = rdfvalue.RDFDatetime.FromDatetime( + last_file_stat_time + ).AsMicrosecondsSinceEpoch() + if last_file_hash_time is not None: + result.last_hash_entry_timestamp = rdfvalue.RDFDatetime.FromDatetime( + last_file_hash_time + ).AsMicrosecondsSinceEpoch() + + if stat_bytes is not None: + result.stat_entry.ParseFromString(stat_bytes) + if hash_bytes is not None: + result.hash_entry.ParseFromString(hash_bytes) + + return result + + @db_utils.CallLogged + @db_utils.CallAccounted + def ListDescendantPathInfos( + self, + client_id: str, + path_type: objects_pb2.PathInfo.PathType, + components: Sequence[str], + timestamp: Optional[rdfvalue.RDFDatetime] = None, + max_depth: Optional[int] = None, + ) -> Sequence[objects_pb2.PathInfo]: + """Lists path info records that correspond to descendants of given path.""" + results = [] + + # The query should include not only descendants of the path but the listed + # path path itself as well. We need to do that to ensure the path actually + # exists and raise if it does not (or if it is not a directory). + query = """ + SELECT p.Path, + p.CreationTime, + p.IsDir, + -- File stat information. + p.LastFileStatTime, + (SELECT s.Stat + FROM PathFileStats AS s + WHERE s.ClientId = p.ClientId + AND s.Type = p.Type + AND s.path = p.Path + AND s.CreationTime <= {timestamp} + ORDER BY s.CreationTime DESC + LIMIT 1), + -- File hash information. + p.LastFileHashTime, + (SELECT h.FileHash + FROM PathFileHashes AS h + WHERE h.ClientId = p.ClientId + AND h.Type = p.Type + AND h.Path = p.Path + AND h.CreationTime <= {timestamp} + ORDER BY h.CreationTime DESC + LIMIT 1) + FROM Paths AS p + WHERE p.ClientId = {client_id} + AND p.Type = {type} + """ + params = { + "client_id": client_id, + "type": int(path_type), + } + + # We add a constraint on path only if path components are non-empty. Empty + # path components indicate root path and it means that everything should be + # listed anyway. Treating the root path the same as other paths leads to + # issues with trailing slashes. 
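ReadPathInfo above resolves, for a single path, the stat and hash entries that were newest at the requested timestamp (defaulting to now) via the two correlated subqueries. A usage sketch, with the client id and components purely illustrative and d standing for a database instance built on these mixins:

path_info = d.ReadPathInfo(
    client_id="C.0123456789abcdef",
    path_type=objects_pb2.PathInfo.PathType.OS,
    components=("usr", "bin", "ls"),
    timestamp=rdfvalue.RDFDatetime.Now(),
)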
+ if components: + query += """ + AND (p.Path = {path} OR STARTS_WITH(p.Path, CONCAT({path}, b'/'))) + """ + params["path"] = EncodePathComponents(components) + + if timestamp is not None: + params["timestamp"] = timestamp.AsDatetime() + else: + params["timestamp"] = rdfvalue.RDFDatetime.Now().AsDatetime() + + if max_depth is not None: + query += " AND p.Depth <= {depth}" + params["depth"] = len(components) + max_depth + + for row in self.db.ParamQuery( + query, params, txn_tag="ListDescendantPathInfos" + ): + path, creation_time, is_dir, *row = row + last_file_stat_time, stat_bytes, *row = row + last_file_hash_time, hash_bytes, *row = row + () = row + + result = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.Name(int(path_type)), + components=DecodePathComponents(path), + directory=is_dir, + timestamp=rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch(), + ) + + if last_file_stat_time is not None: + result.last_stat_entry_timestamp = ( + rdfvalue.RDFDatetime.FromDatetime(last_file_stat_time) + ).AsMicrosecondsSinceEpoch() + if last_file_hash_time is not None: + result.last_hash_entry_timestamp = ( + rdfvalue.RDFDatetime.FromDatetime(last_file_hash_time) + ).AsMicrosecondsSinceEpoch() + + if stat_bytes is not None: + result.stat_entry.ParseFromString(stat_bytes) + if hash_bytes is not None: + result.hash_entry.ParseFromString(hash_bytes) + + results.append(result) + + results.sort(key=lambda result: tuple(result.components)) + + # Special case: we are being asked to list everything under the root path + # (represented by an empty list of components) but we do not have any path + # information available. We assume that the root path always exists (even if + # we did not collect any data yet) so we have to return something but checks + # that follow would cause us to raise instead. + if not components and not results: + return [] + + # The first element of the results should be the requested path itself. If + # it is not, it means that the requested path does not exist and we should + # raise. We also need to verify that the path is a directory. + if not results or tuple(results[0].components) != tuple(components): + raise abstract_db.UnknownPathError(client_id, path_type, components) + if not results[0].directory: + raise abstract_db.NotDirectoryPathError(client_id, path_type, components) + + # Once we verified that the path exists and is a directory, we should not + # include it in the results (since the method is ought to return only real + # descendants). + del results[0] + + # If timestamp is not specified we return collected paths as they are. But + # if the timestamp is given we are only interested in paths that are expli- + # cit (see below for the definition of "explicitness"). + if timestamp is None: + return results + + # A path is considered to be explicit if it has an associated stat or hash + # information or has an ancestor that is explicit. Thus, we traverse results + # in reverse order so that ancestors are checked for explicitness first. + explicit_results = list() + explicit_ancestors = set() + + for result in reversed(results): + components = tuple(result.components) + if ( + result.HasField("stat_entry") + or result.HasField("hash_entry") + or components in explicit_ancestors + ): + explicit_ancestors.add(components[:-1]) + explicit_results.append(result) + + # Since we have been traversing results in the reverse order, explicit re- + # sults are also reversed. Thus we have to reverse them back. 
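To make the explicitness rule concrete, here is a pure-Python rendering of the same reverse pass over illustrative (components, has_data) pairs, where has_data marks an entry with stat or hash information:

listed = [(("usr",), False), (("usr", "bin"), False),
          (("usr", "bin", "ls"), True), (("usr", "lib"), False)]
explicit, ancestors = [], set()
for components, has_data in reversed(listed):
  if has_data or components in ancestors:
    ancestors.add(components[:-1])
    explicit.append(components)
assert list(reversed(explicit)) == [("usr",), ("usr", "bin"), ("usr", "bin", "ls")]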
+ return list(reversed(explicit_results)) + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadPathInfosHistories( + self, + client_id: str, + path_type: objects_pb2.PathInfo.PathType, + components_list: Collection[Sequence[str]], + cutoff: Optional[rdfvalue.RDFDatetime] = None, + ) -> dict[tuple[str, ...], Sequence[objects_pb2.PathInfo]]: + """Reads a collection of hash and stat entries for given paths.""" + # Early return in case of empty components list to avoid awkward issues with + # unnesting an empty array. + if not components_list: + return {} + + results = {tuple(components): [] for components in components_list} + + stat_query = """ + SELECT s.Path, s.CreationTime, s.Stat + FROM PathFileStats AS s + WHERE s.ClientId = {client_id} + AND s.Type = {type} + AND s.Path IN UNNEST({paths}) + """ + hash_query = """ + SELECT h.Path, h.CreationTime, h.FileHash + FROM PathFileHashes AS h + WHERE h.ClientId = {client_id} + AND h.Type = {type} + AND h.Path IN UNNEST({paths}) + """ + params = { + "client_id": client_id, + "type": int(path_type), + "paths": list(map(EncodePathComponents, components_list)), + } + + if cutoff is not None: + stat_query += " AND s.CreationTime <= {cutoff}" + hash_query += " AND h.CreationTime <= {cutoff}" + params["cutoff"] = cutoff.AsDatetime() + + query = f""" + WITH s AS ({stat_query}), + h AS ({hash_query}) + SELECT s.Path, s.CreationTime, s.Stat, + h.Path, h.CreationTime, h.FileHash + FROM s FULL JOIN h ON s.Path = h.Path + """ + + for row in self.db.ParamQuery( + query, params, txn_tag="ReadPathInfosHistories" + ): + stat_path, stat_creation_time, stat_bytes, *row = row + hash_path, hash_creation_time, hash_bytes, *row = row + () = row + + # At least one of the two paths is going to be not null. In case both are + # not null, they are guaranteed to be the same value because of the way + # full join works. + components = DecodePathComponents(stat_path or hash_path) + + result = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.Name(int(path_type)), + components=components, + ) + + # Either stat or hash or both have to be available, so at least one of the + # branches below is going to trigger and thus set the timestamp. Not that + # if both are available, they are guaranteed to have the same timestamp so + # overriding the value does no harm. + if stat_bytes is not None: + result.timestamp = rdfvalue.RDFDatetime.FromDatetime( + stat_creation_time + ).AsMicrosecondsSinceEpoch() + result.stat_entry.ParseFromString(stat_bytes) + if hash_bytes is not None: + result.timestamp = rdfvalue.RDFDatetime.FromDatetime( + hash_creation_time + ).AsMicrosecondsSinceEpoch() + result.hash_entry.ParseFromString(hash_bytes) + + results[components].append(result) + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadLatestPathInfosWithHashBlobReferences( + self, + client_paths: Collection[abstract_db.ClientPath], + max_timestamp: Optional[rdfvalue.RDFDatetime] = None, + ) -> dict[abstract_db.ClientPath, Optional[objects_pb2.PathInfo]]: + """Returns path info with corresponding hash blob references.""" + # Early return in case of empty client paths to avoid issues with syntax er- + # rors due to empty clause list. 
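+    # For instance, a call with an empty client_paths collection must return
+    # {} here: joining zero key clauses with " OR " below would otherwise
+    # leave an empty parenthesised condition in the WHERE clause.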
+ if not client_paths: + return {} + + params = {} + + key_clauses = [] + for idx, client_path in enumerate(client_paths): + client_id = client_path.client_id + + key_clauses.append(f"""( + h.ClientId = {{client_id_{idx}}} + AND h.Type = {{type_{idx}}} + AND h.Path = {{path_{idx}}} + )""") + params[f"client_id_{idx}"] = client_id + params[f"type_{idx}"] = int(client_path.path_type) + params[f"path_{idx}"] = EncodePathComponents(client_path.components) + + if max_timestamp is not None: + params["cutoff"] = max_timestamp.AsDatetime() + cutoff_clause = " h.CreationTime <= {cutoff}" + else: + cutoff_clause = " TRUE" + + query = f""" + WITH l AS (SELECT h.ClientId, h.Type, h.Path, + MAX(h.CreationTime) AS LastCreationTime + FROM PathFileHashes AS h + INNER JOIN HashBlobReferences AS b + ON h.FileHash.sha256 = b.HashId + WHERE ({" OR ".join(key_clauses)}) + AND {cutoff_clause} + GROUP BY h.ClientId, h.Type, h.Path) + SELECT l.ClientId, l.Type, l.Path, l.LastCreationTime, + s.Stat, h.FileHash + FROM l + LEFT JOIN @{{{{JOIN_METHOD=APPLY_JOIN}}}} PathFileStats AS s + ON s.ClientId = l.ClientId + AND s.Type = l.Type + AND s.Path = l.Path + AND s.CreationTime = l.LastCreationTime + LEFT JOIN @{{{{JOIN_METHOD=APPLY_JOIN}}}} PathFileHashes AS h + ON h.ClientId = l.ClientId + AND h.Type = l.Type + AND h.Path = l.Path + AND h.CreationTime = l.LastCreationTime + """ + + results = {client_path: None for client_path in client_paths} + + for row in self.db.ParamQuery( + query, params, txn_tag="ReadLatestPathInfosWithHashBlobReferences" + ): + client_id, int_type, path, creation_time, *row = row + stat_bytes, hash_bytes = row + + components = DecodePathComponents(path) + + result = objects_pb2.PathInfo( + path_type=objects_pb2.PathInfo.PathType.Name(int_type), + components=components, + timestamp=rdfvalue.RDFDatetime.FromDatetime( + creation_time + ).AsMicrosecondsSinceEpoch(), + ) + + if stat_bytes is not None: + result.stat_entry.ParseFromString(stat_bytes) + + # Hash is guaranteed to be non-null (because of the query construction). + result.hash_entry.ParseFromString(hash_bytes) + + client_path = abstract_db.ClientPath( + client_id=client_id, + path_type=int_type, + components=components, + ) + + results[client_path] = result + + return results + + +def EncodePathComponents(components: Sequence[str]) -> bytes: + """Converts path components into canonical database representation of a path. + + Individual components are required to be non-empty and not contain any slash + characters ('/'). + + Args: + components: A sequence of path components. + + Returns: + A canonical database representation of the given path. + """ + for component in components: + if not component: + raise ValueError("Empty path component") + if _PATH_SEP in component: + raise ValueError(f"Path component with a '{_PATH_SEP}' character") + + return base64.b64encode(f"{_PATH_SEP}{_PATH_SEP.join(components)}".encode("utf-8")) + + +def DecodePathComponents(path: bytes) -> tuple[str, ...]: + """Converts a path in canonical database representation into path components. + + Args: + path: A path in its canonical database representation. + + Returns: + A sequence of path components. + """ + path = base64.b64decode(path).decode("utf-8") + + if not path.startswith(_PATH_SEP): + raise ValueError(f"Non-absolute path {path!r}") + + if path == _PATH_SEP: + # A special case for root path, since `str.split` for it gives use two empty + # strings. 
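+    # For example, "/".split("/") == ["", ""], so without this branch the
+    # root path would decode to ("",) rather than ().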
+ return () + else: + return tuple(path.split(_PATH_SEP)[1:]) + + +# We use a forward slash as separator as this is the separator accepted by all +# supported platforms (including Windows) and is disallowed to appear in path +# components. Another viable choice would be a null byte character but that is +# very inconvenient to look at when browsing the database. +_PATH_SEP = "/" diff --git a/grr/server/grr_response_server/databases/spanner_paths_test.py b/grr/server/grr_response_server/databases/spanner_paths_test.py new file mode 100644 index 000000000..e5bb60b91 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_paths_test.py @@ -0,0 +1,58 @@ +from collections.abc import Sequence + +from absl.testing import absltest + +from grr_response_server.databases import db_paths_test +from grr_response_server.databases import spanner_paths +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabasePathsTest( + db_paths_test.DatabaseTestPathsMixin, spanner_test_lib.TestCase +): + # Test methods are defined in the base mixin class. + pass + + +class EncodePathComponentsTest(absltest.TestCase): + + def testEmptyComponent(self): + with self.assertRaises(ValueError): + spanner_paths.EncodePathComponents(("foo", "", "bar")) + + def testSlashComponent(self): + with self.assertRaises(ValueError): + spanner_paths.EncodePathComponents(("foo", "bar/baz", "quux")) + + +class EncodeDecodePathComponentsTest(absltest.TestCase): + + def testEmpty(self): + self._testComponents(()) + + def testSingle(self): + self._testComponents(("foo",)) + + def testMultiple(self): + self._testComponents(("foo", "bar", "baz", "quux")) + + def testUnicode(self): + self._testComponents(("zażółć", "gęślą", "jaźń")) + + def _testComponents(self, components: Sequence[str]): # pylint: disable=invalid-name + encoded = spanner_paths.EncodePathComponents(components) + decoded = spanner_paths.DecodePathComponents(encoded) + + self.assertSequenceEqual(components, decoded) + + +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/databases/spanner_setup.sh b/grr/server/grr_response_server/databases/spanner_setup.sh new file mode 100755 index 000000000..cbd27d9fd --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_setup.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# Script to prepare the GRR protobufs and database on Spanner + +echo "1/3 : Bundling GRR protos in spanner_grr.pb..." +if [ ! 
-f ./spanner_grr.pb ]; then + protoc -I=../../../proto -I=/usr/include --include_imports --descriptor_set_out=./spanner_grr.pb \ + ../../../proto/grr_response_proto/analysis.proto \ + ../../../proto/grr_response_proto/artifact.proto \ + ../../../proto/grr_response_proto/flows.proto \ + ../../../proto/grr_response_proto/hunts.proto \ + ../../../proto/grr_response_proto/jobs.proto \ + ../../../proto/grr_response_proto/knowledge_base.proto \ + ../../../proto/grr_response_proto/objects.proto \ + ../../../proto/grr_response_proto/output_plugin.proto \ + ../../../proto/grr_response_proto/signed_commands.proto \ + ../../../proto/grr_response_proto/sysinfo.proto \ + ../../../proto/grr_response_proto/user.proto \ + ../../../proto/grr_response_proto/rrg.proto \ + ../../../proto/grr_response_proto/rrg/fs.proto \ + ../../../proto/grr_response_proto/rrg/startup.proto \ + ../../../proto/grr_response_proto/rrg/action/execute_signed_command.proto +fi + +echo "2/3 : Creating GRR database on Spanner..." +gcloud spanner databases create ${SPANNER_DATABASE} --instance ${SPANNER_INSTANCE} + +echo "3/3 : Creating tables ..." +gcloud spanner databases ddl update ${SPANNER_DATABASE} --instance=${SPANNER_INSTANCE} --ddl-file=spanner.sdl --proto-descriptors-file=spanner_grr.pb \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_signed_binaries.py b/grr/server/grr_response_server/databases/spanner_signed_binaries.py new file mode 100644 index 000000000..31baa51c0 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_signed_binaries.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python +"""A module with signed binaries methods of the Spanner backend.""" + +from typing import Sequence, Tuple + +from google.api_core.exceptions import NotFound +from google.cloud import spanner as spanner_lib + +from grr_response_core.lib import rdfvalue +from grr_response_proto import objects_pb2 +from grr_response_server.databases import db +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils + + +class SignedBinariesMixin: + """A Spanner database mixin with implementation of signed binaries.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteSignedBinaryReferences( + self, + binary_id: objects_pb2.SignedBinaryID, + references: objects_pb2.BlobReferences, + ) -> None: + """Writes blob references for a signed binary to the DB. + + Args: + binary_id: Signed binary id for the binary. + references: Blob references for the given binary. + """ + row = { + "Type": int(binary_id.binary_type), + "Path": binary_id.path, + "BlobReferences": references, + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, + } + self.db.InsertOrUpdate( + table="SignedBinaries", row=row, txn_tag="WriteSignedBinaryReferences" + ) + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadSignedBinaryReferences( + self, binary_id: objects_pb2.SignedBinaryID + ) -> Tuple[objects_pb2.BlobReferences, rdfvalue.RDFDatetime]: + """Reads blob references for the signed binary with the given id. + + Args: + binary_id: Signed binary id for the binary. + + Returns: + A tuple of the signed binary's rdf_objects.BlobReferences and an + RDFDatetime representing the time when the references were written to the + DB. 
+ """ + binary_type = int(binary_id.binary_type) + + try: + row = self.db.Read( + table="SignedBinaries", + key=(binary_type, binary_id.path), + cols=("BlobReferences", "CreationTime"), + txn_tag="ReadSignedBinaryReferences" + ) + except NotFound as error: + raise db.UnknownSignedBinaryError(binary_id) from error + + raw_references = row[0] + creation_time = row[1] + + references = objects_pb2.BlobReferences() + references.ParseFromString(raw_references) + return references, rdfvalue.RDFDatetime.FromDatetime(creation_time) + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadIDsForAllSignedBinaries(self) -> Sequence[objects_pb2.SignedBinaryID]: + """Returns ids for all signed binaries in the DB.""" + results = [] + + query = """ + SELECT sb.Type, sb.Path + FROM SignedBinaries as sb + """ + + for [binary_type, binary_path] in self.db.Query( + query, txn_tag="ReadIDsForAllSignedBinaries" + ): + binary_id = objects_pb2.SignedBinaryID( + binary_type=binary_type, path=binary_path + ) + results.append(binary_id) + + return results + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteSignedBinaryReferences( + self, + binary_id: objects_pb2.SignedBinaryID, + ) -> None: + """Deletes blob references for the given signed binary from the DB. + + Does nothing if no entry with the given id exists in the DB. + + Args: + binary_id: An id of the signed binary to delete. + """ + def Mutation(mut: spanner_utils.Mutation) -> None: + mut.delete("SignedBinaries", spanner_lib.KeySet(keys=[[binary_id.binary_type, binary_id.path]]) + ) + + try: + self.db.Mutate(Mutation, txn_tag="DeleteSignedBinaryReferences") + except NotFound as error: + raise db.UnknownSignedBinaryError(binary_id) from error diff --git a/grr/server/grr_response_server/databases/spanner_signed_binaries_test.py b/grr/server/grr_response_server/databases/spanner_signed_binaries_test.py new file mode 100644 index 000000000..edba473ac --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_signed_binaries_test.py @@ -0,0 +1,23 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_signed_binaries_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseSignedBinariesTest( + db_signed_binaries_test.DatabaseTestSignedBinariesMixin, + spanner_test_lib.TestCase, +): + pass # Test methods are defined in the base mixin class. 
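+
+# A minimal sketch of how this suite might be run against the Cloud Spanner
+# emulator (illustrative; assumes the emulator is already serving and that the
+# instance and schema from spanner_setup.sh exist; SPANNER_EMULATOR_HOST is
+# the standard variable honoured by the google-cloud-spanner client):
+#
+#   export SPANNER_EMULATOR_HOST=localhost:9010
+#   export SPANNER_PROJECT_ID=<project> PROJECT_ID=<project>
+#   export SPANNER_INSTANCE=<instance> SPANNER_DATABASE=<database>
+#   python -m grr_response_server.databases.spanner_signed_binaries_test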
+ + +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/databases/spanner_signed_commands.py b/grr/server/grr_response_server/databases/spanner_signed_commands.py new file mode 100644 index 000000000..bfbdd7bba --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_signed_commands.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python +"""A module with signed command methods of the Spanner database implementation.""" +import base64 + +from typing import Sequence + +from google.api_core.exceptions import AlreadyExists, InvalidArgument, NotFound +from google.cloud import spanner as spanner_lib + +from grr_response_proto import signed_commands_pb2 +from grr_response_server.databases import db +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils + + +class SignedCommandsMixin: + """A Spanner database mixin with implementation of signed command methods.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteSignedCommands( + self, + signed_commands: Sequence[signed_commands_pb2.SignedCommand], + ) -> None: + """Writes a signed command to the database.""" + def Mutation(mut) -> None: + for signed_command in signed_commands: + mut.insert( + table="SignedCommands", + columns=("Id", "OperatingSystem", "Ed25519Signature", "Command"), + values=[( + signed_command.id, + signed_command.operating_system, + base64.b64encode(bytes(signed_command.ed25519_signature)), + base64.b64encode(bytes(signed_command.command)) + )] + ) + + try: + self.db.Mutate(Mutation, txn_tag="WriteSignedCommand") + except AlreadyExists as e: + raise db.AtLeastOneDuplicatedSignedCommandError(signed_commands) from e + except InvalidArgument as e: + raise db.AtLeastOneDuplicatedSignedCommandError(signed_commands) from e + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadSignedCommand( + self, + id_: str, + operating_system: signed_commands_pb2.SignedCommand.OS, + ) -> signed_commands_pb2.SignedCommand: + """Reads signed command from the database.""" + params = {} + query = """ + SELECT + c.Id, c.Ed25519Signature, c.Command + FROM SignedCommands AS c + WHERE + c.Id = {id} + AND + c.OperatingSystem = {operating_system} + """ + params["id"] = id_ + params["operating_system"] = operating_system + + try: + ( + id_, + ed25519_signature, + command_bytes, + ) = self.db.ParamQuerySingle(query, params, txn_tag="ReadSignedCommand") + except NotFound as ex: + raise db.NotFoundError() from ex + + signed_command = signed_commands_pb2.SignedCommand() + signed_command.id = id_ + signed_command.operating_system = operating_system + signed_command.ed25519_signature = base64.b64decode(ed25519_signature) + signed_command.command = base64.b64decode(command_bytes) + + return signed_command + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadSignedCommands( + self, + ) -> Sequence[signed_commands_pb2.SignedCommand]: + """Reads signed command from the database.""" + query = """ + SELECT + c.Id, c.OperatingSystem, c.Ed25519Signature, c.Command + FROM + SignedCommands AS c + """ + signed_commands = [] + for ( + command_id, + operating_system, + signature, + command_bytes, + ) in self.db.Query(query, txn_tag="ReadSignedCommand"): + + signed_command = signed_commands_pb2.SignedCommand() + signed_command.id = command_id + signed_command.operating_system = operating_system + signed_command.ed25519_signature = base64.b64decode(signature) + signed_command.command = base64.b64decode(command_bytes) + + 
signed_commands.append(signed_command) + + return signed_commands + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteAllSignedCommands( + self, + ) -> None: + """Deletes all signed command from the database.""" + to_delete = self.ReadSignedCommands() + if not to_delete: + return + + def Mutation(mut) -> None: + for command in to_delete: + mut.delete("SignedCommands", spanner_lib.KeySet(keys=[[command.id, int(command.operating_system)]])) + + self.db.Mutate(Mutation, txn_tag="DeleteAllSignedCommands") diff --git a/grr/server/grr_response_server/databases/spanner_signed_commands_test.py b/grr/server/grr_response_server/databases/spanner_signed_commands_test.py new file mode 100644 index 000000000..647d06130 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_signed_commands_test.py @@ -0,0 +1,23 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_signed_commands_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseSignedCommandsTest( + db_signed_commands_test.DatabaseTestSignedCommandsMixin, + spanner_test_lib.TestCase, +): + """Spanner signed commands tests.""" + + +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/databases/spanner_test.sdl b/grr/server/grr_response_server/databases/spanner_test.sdl new file mode 100644 index 000000000..7737e1402 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_test.sdl @@ -0,0 +1,11 @@ +CREATE TABLE `Table`( + `Key` STRING(MAX) NOT NULL, + `Column` STRING(MAX), + `Time` TIMESTAMP OPTIONS (allow_commit_timestamp = true) +) PRIMARY KEY (`Key`); + +CREATE TABLE `Subtable`( + `Key` STRING(MAX) NOT NULL, + `Subkey` STRING(MAX), +) PRIMARY KEY (`Key`, `Subkey`), + INTERLEAVE IN PARENT `Table` ON DELETE CASCADE; \ No newline at end of file diff --git a/grr/server/grr_response_server/databases/spanner_test_lib.py b/grr/server/grr_response_server/databases/spanner_test_lib.py new file mode 100644 index 000000000..58d7b2ae9 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_test_lib.py @@ -0,0 +1,152 @@ +"""A library with utilities for testing the Spanner database implementation.""" +import os +import random +import unittest + +from typing import Optional + +from absl.testing import absltest + +from google.cloud import spanner_v1 as spanner_lib +from google.cloud import spanner_admin_database_v1 +from google.cloud.spanner import Client, KeySet +from google.cloud.spanner_admin_database_v1.types import spanner_database_admin + +from grr_response_server.databases import spanner_utils +from grr_response_server.databases import db as abstract_db +from grr_response_server.databases import spanner as spanner_db + +OPERATION_TIMEOUT_SECONDS = 240 + +PROD_SCHEMA_SDL_PATH = "grr/server/grr_response_server/databases/spanner.sdl" +TEST_SCHEMA_SDL_PATH = "grr/server/grr_response_server/databases/spanner_test.sdl" + +PROTO_DESCRIPTOR_PATH = "grr/server/grr_response_server/databases/spanner_grr.pb" + +def _GetEnvironOrSkip(key): + value = os.environ.get(key) + if value is None: + raise unittest.SkipTest("'%s' variable is not set" % key) + return value + +def _readSchemaFromFile(file_path): + """Reads DDL statements from a file.""" + with open(file_path, 'r') as f: + # Read the whole file and split by semicolon. 
+ # Filter out any empty strings resulting from split. + ddl_statements = [stmt.strip() for stmt in f.read().split(';') if stmt.strip()] + return ddl_statements + +def _readProtoDescriptorFromFile(): + """Reads DDL statements from a file.""" + with open(PROTO_DESCRIPTOR_PATH, 'rb') as f: + proto_descriptors = f.read() + return proto_descriptors + +def Init(sdl_path: str, proto_bundle: bool) -> None: + """Initializes the Spanner testing environment. + + This must be called only once per test process. A `setUpModule` method is + a perfect place for it. + + """ + global _TEST_DB + + if _TEST_DB is not None: + raise AssertionError("Spanner test library already initialized") + + project_id = _GetEnvironOrSkip("SPANNER_PROJECT_ID") + instance_id = _GetEnvironOrSkip("SPANNER_INSTANCE") + database_id = _GetEnvironOrSkip("SPANNER_DATABASE") + "-" + str(random.randint(1, 100000)) + + spanner_client = Client(project_id) + database_admin_api = spanner_client.database_admin_api + + ddl_statements = _readSchemaFromFile(sdl_path) + + proto_descriptors = bytes() + + if proto_bundle: + proto_descriptors = _readProtoDescriptorFromFile() + + request = spanner_database_admin.CreateDatabaseRequest( + parent=database_admin_api.instance_path(spanner_client.project, instance_id), + create_statement=f"CREATE DATABASE `{database_id}`", + extra_statements=ddl_statements, + proto_descriptors=proto_descriptors + ) + + operation = database_admin_api.create_database(request=request) + + print("Waiting for operation to complete...") + database = operation.result(OPERATION_TIMEOUT_SECONDS) + + print( + "Created database {} on instance {}".format( + database.name, + database_admin_api.instance_path(spanner_client.project, instance_id), + ) + ) + + instance = spanner_client.instance(instance_id) + _TEST_DB = instance.database(database_id) + +def TearDown() -> None: + """Tears down the Spanner testing environment. + + This must be called once per process after all the tests. + A `tearDownModule` is a perfect place for it. + """ + global _TEST_DB + + if _TEST_DB is not None: + # Create a client + _TEST_DB.drop() + _TEST_DB = None + + +class TestCase(absltest.TestCase): + """A base test case class for Spanner tests. + + This class takes care of setting up a clean database for every test method. It + is intended to be used with database test suite mixins. + """ + + project_id = None + + def setUp(self): + super().setUp() + + self.project_id = _GetEnvironOrSkip("PROJECT_ID") + + _clean_database() + + self.raw_db = spanner_utils.Database(_TEST_DB, self.project_id) + + spannerDB = spanner_db.SpannerDB(self.raw_db) + self.db = abstract_db.DatabaseValidationWrapper(spannerDB) + +def _get_table_names(db): + with db.snapshot() as snapshot: + query_result = snapshot.execute_sql( + "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE = 'BASE TABLE';" + ) + table_names = set() + for row in query_result: + table_names.add(row[0]) + + return table_names + + +def _clean_database() -> None: + """Creates an empty test spanner database.""" + + table_names = _get_table_names(_TEST_DB) + keyset = KeySet(all_=True) + + with _TEST_DB.batch() as batch: + # Deletes sample data from all tables in the given database. 
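+    # With the test schema in spanner_test.sdl, for example, this issues a
+    # batch delete over the full key range of both `Table` and `Subtable`,
+    # giving every test method an empty database without recreating it.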
+ for table_name in table_names: + batch.delete(table_name, keyset) + +_TEST_DB: spanner_lib.database = None diff --git a/grr/server/grr_response_server/databases/spanner_time_test.py b/grr/server/grr_response_server/databases/spanner_time_test.py new file mode 100644 index 000000000..4618c497f --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_time_test.py @@ -0,0 +1,22 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_time_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseTimeTest( + db_time_test.DatabaseTimeTestMixin, spanner_test_lib.TestCase +): + pass # Test methods are defined in the base mixin class. + + +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/databases/spanner_users.py b/grr/server/grr_response_server/databases/spanner_users.py new file mode 100644 index 000000000..ea21adbe1 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_users.py @@ -0,0 +1,535 @@ +#!/usr/bin/env python +"""A library with user methods of Spanner database implementation.""" +import base64 +import datetime +import logging +import uuid + +from typing import Optional, Sequence, Tuple + +from google.api_core.exceptions import NotFound + +from google.cloud import spanner as spanner_lib + +from grr_response_core.lib import rdfvalue +from grr_response_proto import jobs_pb2 +from grr_response_proto import objects_pb2 +from grr_response_proto import user_pb2 +from grr_response_server.databases import db as abstract_db +from grr_response_server.databases import db_utils +from grr_response_server.databases import spanner_utils + + +class UsersMixin: + """A Spanner database mixin with implementation of user methods.""" + + db: spanner_utils.Database + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteGRRUser( + self, + username: str, + email: Optional[str] = None, + password: Optional[jobs_pb2.Password] = None, + user_type: Optional["objects_pb2.GRRUser.UserType"] = None, + canary_mode: Optional[bool] = None, + ui_mode: Optional["user_pb2.GUISettings.UIMode"] = None, + ) -> None: + """Writes user object for a user with a given name.""" + row = {"Username": username} + + if email is not None: + row["Email"] = email + + if password is not None: + row["Password"] = password + + if user_type is not None: + row["Type"] = int(user_type) + + if ui_mode is not None: + row["UiMode"] = int(ui_mode) + + if canary_mode is not None: + row["CanaryMode"] = canary_mode + + self.db.InsertOrUpdate(table="Users", row=row, txn_tag="WriteGRRUser") + + + @db_utils.CallLogged + @db_utils.CallAccounted + def DeleteGRRUser(self, username: str) -> None: + """Deletes the user and all related metadata with the given username.""" + keyset = spanner_lib.KeySet(keys=[(username,)]) + + def Transaction(txn) -> None: + try: + txn.read(table="Users", columns=("Username",), keyset=keyset).one() + except NotFound: + raise abstract_db.UnknownGRRUserError(username) + + username_range = spanner_lib.KeyRange(start_closed=[username], end_closed=[username]) + txn.delete(table="ApprovalRequests", keyset=spanner_lib.KeySet(ranges=[username_range])) + txn.delete(table="Users", keyset=keyset) + + self.db.Transact(Transaction, txn_tag="DeleteGRRUser") + + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadGRRUser(self, username: 
str) -> objects_pb2.GRRUser: + """Reads a user object corresponding to a given name.""" + cols = ("Email", "Password", "Type", "CanaryMode", "UiMode") + try: + row = self.db.Read(table="Users", + key=[username], + cols=cols, + txn_tag="ReadGRRUser") + except NotFound as error: + raise abstract_db.UnknownGRRUserError(username) from error + + user = objects_pb2.GRRUser( + username=username, + email=row[0], + user_type=row[2], + canary_mode=row[3], + ui_mode=row[4], + ) + + if row[1]: + pw = jobs_pb2.Password() + pw.ParseFromString(row[1]) + user.password.CopyFrom(pw) + + return user + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadGRRUsers( + self, + offset: int = 0, + count: Optional[int] = None, + ) -> Sequence[objects_pb2.GRRUser]: + """Reads GRR users with optional pagination, sorted by username.""" + if count is None: + # TODO(b/196379916): We use the same value as F1 implementation does. But + # a better solution would be to dynamically not ignore the `LIMIT` clause + # in the query if the count parameter is not provided. This is not trivial + # as queries have to be provided as docstrings (an utility that does it + # on the fly has to be created to hack around this limitation). + count = 2147483647 + + users = [] + + query = """ + SELECT u.Username, u.Email, u.Password, u.Type, u.CanaryMode, u.UiMode + FROM Users AS u + ORDER BY u.Username + LIMIT {count} + OFFSET {offset} + """ + params = { + "offset": offset, + "count": count, + } + + for row in self.db.ParamQuery(query, params, txn_tag="ReadGRRUsers"): + username, email, password, typ, canary_mode, ui_mode = row + + user = objects_pb2.GRRUser( + username=username, + email=email, + user_type=typ, + canary_mode=canary_mode, + ui_mode=ui_mode, + ) + + if password: + user.password.ParseFromString(password) + + users.append(user) + + return users + + @db_utils.CallLogged + @db_utils.CallAccounted + def CountGRRUsers(self) -> int: + """Returns the total count of GRR users.""" + query = """ + SELECT COUNT(*) + FROM Users + """ + + (count,) = self.db.QuerySingle(query, txn_tag="CountGRRUsers") + return count + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteApprovalRequest(self, request: objects_pb2.ApprovalRequest) -> str: + """Writes an approval request object.""" + approval_id = str(uuid.uuid4()) + + row = { + "Requestor": request.requestor_username, + "ApprovalId": approval_id, + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, + "ExpirationTime": ( + rdfvalue.RDFDatetime() + .FromMicrosecondsSinceEpoch(request.expiration_time) + .AsDatetime() + ), + "Reason": request.reason, + "NotifiedUsers": list(request.notified_users), + "CcEmails": list(request.email_cc_addresses), + } + + if request.approval_type == _APPROVAL_TYPE_CLIENT: + row["SubjectClientId"] = request.subject_id + elif request.approval_type == _APPROVAL_TYPE_HUNT: + row["SubjectHuntId"] = request.subject_id + elif request.approval_type == _APPROVAL_TYPE_CRON_JOB: + row["SubjectCronJobId"] = request.subject_id + else: + raise ValueError(f"Unsupported approval type: {request.approval_type}") + + self.db.Insert( + table="ApprovalRequests", row=row, txn_tag="WriteApprovalRequest" + ) + + return approval_id + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadApprovalRequest( + self, + username: str, + approval_id: str, + ) -> objects_pb2.ApprovalRequest: + """Reads an approval request object with a given id.""" + + query = """ + SELECT r.SubjectClientId, r.SubjectHuntId, r.SubjectCronJobId, + r.Reason, + r.CreationTime, r.ExpirationTime, + 
r.NotifiedUsers, r.CcEmails, + ARRAY(SELECT AS STRUCT g.Grantor, + g.CreationTime + FROM ApprovalGrants AS g + WHERE g.Requestor = r.Requestor + AND g.ApprovalId = r.ApprovalId) AS Grants + FROM ApprovalRequests AS r + WHERE r.Requestor = {requestor} + AND r.ApprovalId = {approval_id} + """ + params = { + "requestor": username, + "approval_id": approval_id, + } + + try: + row = self.db.ParamQuerySingle( + query, params, txn_tag="ReadApprovalRequest" + ) + except NotFound: + # TODO: Improve error message of this error class. + raise abstract_db.UnknownApprovalRequestError(approval_id) + + subject_client_id, subject_hunt_id, subject_cron_job_id, *row = row + reason, *row = row + creation_time, expiration_time, *row = row + notified_users, cc_emails, grants = row + + request = objects_pb2.ApprovalRequest( + requestor_username=username, + approval_id=approval_id, + reason=reason, + timestamp=RDFDatetime(creation_time).AsMicrosecondsSinceEpoch(), + expiration_time=RDFDatetime(expiration_time).AsMicrosecondsSinceEpoch(), + notified_users=notified_users, + email_cc_addresses=cc_emails, + ) + + if subject_client_id is not None: + request.subject_id = subject_client_id + request.approval_type = _APPROVAL_TYPE_CLIENT + elif subject_hunt_id is not None: + request.subject_id = subject_hunt_id + request.approval_type = _APPROVAL_TYPE_HUNT + elif subject_cron_job_id is not None: + request.subject_id = subject_cron_job_id + request.approval_type = _APPROVAL_TYPE_CRON_JOB + else: + # This should not happen as the condition to one of these being always + # set if enforced by the database schema. + message = "No subject set for approval '%s' of user '%s'" + logging.error(message, approval_id, username) + + for grantor, creation_time in grants: + grant = objects_pb2.ApprovalGrant() + grant.grantor_username = grantor + grant.timestamp = RDFDatetime(creation_time).AsMicrosecondsSinceEpoch() + request.grants.add().CopyFrom(grant) + + return request + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadApprovalRequests( + self, + username: str, + typ: "objects_pb2.ApprovalRequest.ApprovalType", + subject_id: Optional[str] = None, + include_expired: Optional[bool] = False, + ) -> Sequence[objects_pb2.ApprovalRequest]: + """Reads approval requests of a given type for a given user.""" + requests = [] + + # We need to use double curly braces for parameters as we also parametrize + # over index that is substituted using standard Python templating. + query = """ + SELECT r.ApprovalId, + r.SubjectClientId, r.SubjectHuntId, r.SubjectCronJobId, + r.Reason, + r.CreationTime, r.ExpirationTime, + r.NotifiedUsers, r.CcEmails, + ARRAY(SELECT AS STRUCT g.Grantor, + g.CreationTime + FROM ApprovalGrants AS g + WHERE g.Requestor = r.Requestor + AND g.ApprovalId = r.ApprovalId) AS Grants + FROM ApprovalRequests@{{{{FORCE_INDEX={index}}}}} AS r + WHERE r.Requestor = {{requestor}} + """ + params = { + "requestor": username, + } + + # By default we use the "by requestor" index but in case a specific subject + # is given we can also use a more specific index (overridden below). 
+    index = "ApprovalRequestsByRequestor"
+
+    if typ == _APPROVAL_TYPE_CLIENT:
+      query += " AND r.SubjectClientId IS NOT NULL"
+      if subject_id is not None:
+        query += " AND r.SubjectClientId = {{subject_client_id}}"
+        params["subject_client_id"] = subject_id
+        index = "ApprovalRequestsByRequestorSubjectClientId"
+    elif typ == _APPROVAL_TYPE_HUNT:
+      query += " AND r.SubjectHuntId IS NOT NULL"
+      if subject_id is not None:
+        query += " AND r.SubjectHuntId = {{subject_hunt_id}}"
+        params["subject_hunt_id"] = subject_id
+        index = "ApprovalRequestsByRequestorSubjectHuntId"
+    elif typ == _APPROVAL_TYPE_CRON_JOB:
+      query += " AND r.SubjectCronJobId IS NOT NULL"
+      if subject_id is not None:
+        query += " AND r.SubjectCronJobId = {{subject_cron_job_id}}"
+        params["subject_cron_job_id"] = subject_id
+        index = "ApprovalRequestsByRequestorSubjectCronJobId"
+    else:
+      raise ValueError(f"Unsupported approval type: {typ}")
+
+    if not include_expired:
+      query += " AND r.ExpirationTime > CURRENT_TIMESTAMP()"
+
+    query = query.format(index=index)
+
+    for row in self.db.ParamQuery(
+        query, params, txn_tag="ReadApprovalRequests"
+    ):
+      approval_id, *row = row
+      subject_client_id, subject_hunt_id, subject_cron_job_id, *row = row
+      reason, *row = row
+      creation_time, expiration_time, *row = row
+      notified_users, cc_emails, grants = row
+
+      request = objects_pb2.ApprovalRequest(
+          requestor_username=username,
+          approval_id=approval_id,
+          reason=reason,
+          timestamp=RDFDatetime(creation_time).AsMicrosecondsSinceEpoch(),
+          expiration_time=RDFDatetime(
+              expiration_time
+          ).AsMicrosecondsSinceEpoch(),
+          notified_users=notified_users,
+          email_cc_addresses=cc_emails,
+      )
+
+      if subject_client_id is not None:
+        request.subject_id = subject_client_id
+        request.approval_type = _APPROVAL_TYPE_CLIENT
+      elif subject_hunt_id is not None:
+        request.subject_id = subject_hunt_id
+        request.approval_type = _APPROVAL_TYPE_HUNT
+      elif subject_cron_job_id is not None:
+        request.subject_id = subject_cron_job_id
+        request.approval_type = _APPROVAL_TYPE_CRON_JOB
+      else:
+        # This should not happen, as the constraint that one of these is
+        # always set is enforced by the database schema.
+ message = "No subject set for approval '%s' of user '%s'" + logging.error(message, approval_id, username) + + for grantor, creation_time in grants: + grant = objects_pb2.ApprovalGrant() + grant.grantor_username = grantor + grant.timestamp = RDFDatetime(creation_time).AsMicrosecondsSinceEpoch() + request.grants.add().CopyFrom(grant) + + requests.append(request) + + return requests + + @db_utils.CallLogged + @db_utils.CallAccounted + def GrantApproval( + self, + requestor_username: str, + approval_id: str, + grantor_username: str, + ) -> None: + """Grants approval for a given request using given username.""" + row = { + "Requestor": requestor_username, + "ApprovalId": approval_id, + "Grantor": grantor_username, + "GrantId": str(uuid.uuid4()), + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, + } + + self.db.Insert(table="ApprovalGrants", row=row, txn_tag="GrantApproval") + + @db_utils.CallLogged + @db_utils.CallAccounted + def WriteUserNotification( + self, + notification: objects_pb2.UserNotification, + ) -> None: + """Writes a notification for a given user.""" + row = { + "Username": notification.username, + "NotificationId": str(uuid.uuid4()), + "Type": int(notification.notification_type), + "State": int(notification.state), + "CreationTime": spanner_lib.COMMIT_TIMESTAMP, + "Message": notification.message, + } + if notification.reference: + row["Reference"] = base64.b64encode(notification.reference.SerializeToString()) + + try: + self.db.Insert( + table="UserNotifications", row=row, txn_tag="WriteUserNotification" + ) + except NotFound: + raise abstract_db.UnknownGRRUserError(notification.username) + + @db_utils.CallLogged + @db_utils.CallAccounted + def ReadUserNotifications( + self, + username: str, + state: Optional["objects_pb2.UserNotification.State"] = None, + timerange: Optional[ + Tuple[Optional[rdfvalue.RDFDatetime], Optional[rdfvalue.RDFDatetime]] + ] = None, + ) -> Sequence[objects_pb2.UserNotification]: + """Reads notifications scheduled for a user within a given timerange.""" + notifications = [] + + params = { + "username": username, + } + query = """ + SELECT n.Type, n.State, n.CreationTime, + n.Message, n.Reference + FROM UserNotifications AS n + WHERE n.Username = {username} + """ + + if state is not None: + params["state"] = int(state) + query += " AND n.state = {state}" + + if timerange is not None: + begin_time, end_time = timerange + if begin_time is not None: + params["begin_time"] = begin_time.AsDatetime() + query += " AND n.CreationTime >= {begin_time}" + if end_time is not None: + params["end_time"] = end_time.AsDatetime() + query += " AND n.CreationTime <= {end_time}" + + query += " ORDER BY n.CreationTime DESC" + + for row in self.db.ParamQuery( + query, params, txn_tag="ReadUserNotifications" + ): + typ, state, creation_time, message, reference = row + + notification = objects_pb2.UserNotification( + username=username, + notification_type=typ, + state=state, + timestamp=RDFDatetime(creation_time).AsMicrosecondsSinceEpoch(), + message=message, + ) + + if reference: + notification.reference.ParseFromString(reference) + + notifications.append(notification) + + return notifications + + @db_utils.CallLogged + @db_utils.CallAccounted + def UpdateUserNotifications( + self, + username: str, + timestamps: Sequence[rdfvalue.RDFDatetime], + state: Optional["objects_pb2.UserNotification.State"] = None, + ): + """Updates existing user notification objects.""" + # `UNNEST` used in the query does not like empty arrays, so we return early + # in such cases. 
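+    # For two timestamps, for example, the statement below ends up containing
+    # "AND n.CreationTime IN ({ts0}, {ts1})", which ParamExecute rewrites to
+    # "IN (@ts0, @ts1)" with ts0 and ts1 bound as TIMESTAMP parameters.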
+ if not timestamps: + return + + params = { + "username": username, + "state": int(state), + } + + param_placeholders = ", ".join([f"{{ts{i}}}" for i in range(len(timestamps))]) + for i, timestamp in enumerate(timestamps): + param_name = f"ts{i}" + params[param_name] = timestamp.AsDatetime() + + query = f""" + UPDATE UserNotifications n + SET n.State = {state} + WHERE n.Username = '{username}' + AND n.CreationTime IN ({param_placeholders}) + """ + + self.db.ParamExecute(query, params, txn_tag="UpdateUserNotifications") + + +def _HexApprovalID(approval_id: int) -> str: + return f"{approval_id:016x}" + + +def _UnhexApprovalID(approval_id: str) -> int: + return int(approval_id, base=16) + + + +def RDFDatetime(time: datetime.datetime) -> rdfvalue.RDFDatetime: + return rdfvalue.RDFDatetime.FromDatetime(time) + + +_APPROVAL_TYPE_CLIENT = objects_pb2.ApprovalRequest.APPROVAL_TYPE_CLIENT +_APPROVAL_TYPE_HUNT = objects_pb2.ApprovalRequest.APPROVAL_TYPE_HUNT +_APPROVAL_TYPE_CRON_JOB = objects_pb2.ApprovalRequest.APPROVAL_TYPE_CRON_JOB diff --git a/grr/server/grr_response_server/databases/spanner_users_test.py b/grr/server/grr_response_server/databases/spanner_users_test.py new file mode 100644 index 000000000..ca0692686 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_users_test.py @@ -0,0 +1,22 @@ +from absl.testing import absltest + +from grr_response_server.databases import db_users_test +from grr_response_server.databases import spanner_test_lib + + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class SpannerDatabaseUsersTest( + db_users_test.DatabaseTestUsersMixin, spanner_test_lib.TestCase +): + pass # Test methods are defined in the base mixin class. + + +if __name__ == "__main__": + absltest.main() diff --git a/grr/server/grr_response_server/databases/spanner_utils.py b/grr/server/grr_response_server/databases/spanner_utils.py new file mode 100644 index 000000000..647bf4337 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_utils.py @@ -0,0 +1,468 @@ +"""Spanner-related helpers and other utilities.""" + +from collections.abc import Callable, Iterable, Mapping, Sequence +import datetime +import decimal +import re +from typing import Any, Optional, Tuple, TypeVar + +from google.cloud.spanner_v1 import database as spanner_lib +from google.cloud.spanner_v1.transaction import Transaction as _Transaction +from google.cloud.spanner_v1.batch import _BatchBase as _Mutation + +from google.cloud.spanner import KeyRange, KeySet +from google.cloud.spanner_v1 import param_types + +from grr_response_core.lib.util import collection + + +Row = Tuple[Any, ...] +Cursor = Iterable[Row] + +_T = TypeVar("_T") +class Mutation(_Mutation): + """A wrapper around the PySpanner Mutation class.""" + pass + +class Transaction(_Transaction): + """A wrapper around the PySpanner Transaction class.""" + pass + +class Database: + """A wrapper around the PySpanner class. + + The wrapper is supposed to streamline the usage of Spanner database through + an abstraction that is much harder to misuse. The wrapper will run retryable + queries through a transaction runner handling all brittle logic for the user. 
+ """ + + _PYSPANNER_PARAM_REGEX = re.compile(r"@p\d+") + + def __init__(self, pyspanner: spanner_lib.Database, project_id: str) -> None: + super().__init__() + self._pyspanner = pyspanner + self.project_id = project_id + def _parametrize(self, query: str, names: Iterable[str]) -> str: + match = self._PYSPANNER_PARAM_REGEX.search(query) + if match is not None: + raise ValueError(f"Query contains illegal sequence: {match.group(0)}") + + kwargs = {} + for name in names: + kwargs[name] = f"@{name}" + + return query.format(**kwargs) + + def _get_param_type(self, value): + """ + Infers the Google Cloud Spanner type from a Python value. + + Args: + value: The Python value whose Spanner type is to be inferred. + + Returns: + A google.cloud.spanner_v1.types.Type object, or None if the type + cannot be reliably inferred (e.g., for a standalone None value or + an empty list). + Raises: + TypeError: Raised for any unsupported type or empty container value. + """ + if value is None: + # Cannot determine a specific Spanner type from a None value alone. + # This indicates that the type is ambiguous without further schema + # context. + return None + + py_type = type(value) + + if py_type is int: + return param_types.INT64 + elif py_type is float: + return param_types.FLOAT64 + elif py_type is str: + return param_types.STRING + elif py_type is bool: + return param_types.BOOL + elif py_type is bytes: + return param_types.BYTES + elif py_type is datetime.date: + return param_types.DATE + elif py_type is datetime.datetime: + # Note: Spanner TIMESTAMPs are stored in UTC. Ensure datetime objects + # are timezone-aware (UTC) when writing data. This function only maps the + # type. + return param_types.TIMESTAMP + elif py_type is decimal.Decimal: + return param_types.NUMERIC + elif py_type is list: + if len(value) > 0: + return param_types.Array(self._get_param_type(value[0])) + else: + raise TypeError(f"Empty value for Python type: {py_type.__name__} for Spanner type conversion.") + else: + # Potentially raise an error for unsupported types or return None + # For a generic solution, raising an error for unknown types is often safer. + raise TypeError(f"Unsupported Python type: {py_type.__name__} for Spanner type conversion.") + + def Transact( + self, + func: Callable[["Transaction"], _T], + txn_tag: Optional[str] = None, + ) -> _T: + + """Execute the given callback function in a Spanner transaction. + + Args: + func: A transaction function to execute. + txn_tag: Transaction tag to apply. + + Returns: + The result of the transaction function executed. + """ + return self._pyspanner.run_in_transaction(func, transaction_tag=txn_tag) + + def Mutate( + self, func: Callable[["Mutation"], None], txn_tag: Optional[str] = None + ) -> None: + """Execute the given callback function in a Spanner mutation. + + Args: + func: A mutation function to execute. + txn_tag: Spanner transaction tag. + """ + + self.Transact(func, txn_tag=txn_tag) + + def Query(self, query: str, txn_tag: Optional[str] = None) -> Cursor: + """Queries Spanner database using the given query string. + + Args: + query: An SQL string. + txn_tag: Spanner transaction tag. + + Returns: + A cursor over the query results. + """ + with self._pyspanner.snapshot() as snapshot: + results = snapshot.execute_sql(query, request_options={"request_tag": txn_tag}) + + return results + + def QuerySingle(self, query: str, txn_tag: Optional[str] = None) -> Row: + """Queries PySpanner for a single row using the given query string. + + Args: + query: An SQL string. 
+ txn_tag: Spanner transaction tag. + + Returns: + A single row matching the query. + + Raises: + NotFound: If the query did not return any results. + ValueError: If the query yielded more than one result. + """ + return self.Query(query, txn_tag=txn_tag).one() + + def ParamQuery( + self, query: str, params: Mapping[str, Any], + param_type: Optional[dict] = None, txn_tag: Optional[str] = None + ) -> Cursor: + """Queries PySpanner database using the given query string with params. + + The query string should specify parameters with the standard Python format + placeholder syntax [1]. Note that parameters inside string literals in the + query itself have to be escaped. + + Also, the query literal is not allowed to contain any '@p{idx}' strings + inside as that would lead to an incorrect behaviour when evaluating the + query. To prevent mistakes the function will raise an exception in such + cases. + + [1]: https://docs.python.org/3/library/stdtypes.html#str.format + + Args: + query: An SQL string with parameter placeholders. + params: A dictionary mapping parameter name to a value. + txn_tag: Spanner transaction tag. + + Returns: + A cursor over the query results. + + Raises: + ValueError: If the query contains disallowed sequences. + KeyError: If some parameter is not specified. + """ + if not param_type: + param_type = {} + names, _ = collection.Unzip(params.items()) + query = self._parametrize(query, names) + + for key, value in params.items(): + if key not in param_type: + try: + param_type[key] = self._get_param_type(value) + except TypeError as e: + print(f"Warning for key '{key}': {e}. Setting type to None.") + param_type[key] = None # Or re-raise, or handle differently + + with self._pyspanner.snapshot() as snapshot: + results = snapshot.execute_sql( + query, + params=params, + param_types=param_type, + request_options={"request_tag": txn_tag} + ) + + return results + + def ParamQuerySingle( + self, query: str, params: Mapping[str, Any], + param_type: Optional[dict] = None, txn_tag: Optional[str] = None + ) -> Row: + """Queries the database for a single row using with a query with params. + + See documentation for `ParamQuery` to learn more about the syntax of query + parameters and other caveats. + + Args: + query: An SQL string with parameter placeholders. + params: A dictionary mapping parameter name to a value. + txn_tag: Spanner transaction tag. + + Returns: + A single result of running the query. + + Raises: + NotFound: If the query did not return any results. + ValueError: If the query yielded more than one result. + ValueError: If the query contains disallowed sequences. + KeyError: If some parameter is not specified. + """ + return self.ParamQuery(query, params, param_type=param_type, txn_tag=txn_tag).one() + + def ParamExecute( + self, query: str, params: Mapping[str, Any], txn_tag: Optional[str] = None + ) -> None: + """Executes the given query with parameters against a Spanner database. + + Args: + query: An SQL string with parameter placeholders. + params: A dictionary mapping parameter name to a value. + txn_tag: Spanner transaction tag. + + Returns: + Nothing. + + Raises: + ValueError: If the query contains disallowed sequences. + KeyError: If some parameter is not specified. + """ + names, _ = collection.Unzip(params.items()) + query = self._parametrize(query, names) + + param_type = {} + for key, value in params.items(): + try: + param_type[key] = self._get_param_type(value) + except TypeError as e: + print(f"Warning for key '{key}': {e}. 
Setting type to None.") + param_type[key] = None # Or re-raise, or handle differently + + def param_execute(txn: Transaction): + txn.execute_update( + query, + params=params, + param_types=param_type, + request_options={"request_tag": txn_tag}, + ) + + self._pyspanner.run_in_transaction(param_execute) + + def ExecutePartitioned( + self, query: str, txn_tag: Optional[str] = None + ) -> None: + """Executes the given query against a Spanner database. + + This is a more efficient variant of the `Execute` method, but it does not + guarantee atomicity. See the official documentation on partitioned updates + for more information [1]. + + [1]: go/spanner-partitioned-dml + + Args: + query: An SQL query string to execute. + txn_tag: Spanner transaction tag. + + Returns: + Nothing. + """ + return self._pyspanner.execute_partitioned_dml(query, + request_options={"request_tag": txn_tag}) + + def Insert( + self, table: str, row: Mapping[str, Any], txn_tag: Optional[str] = None + ) -> None: + """Insert a row into the given table. + + Args: + table: A table into which the row is to be inserted. + row: A mapping from column names to column values of the row. + txn_tag: Spanner transaction tag. + + Returns: + Nothing. + """ + columns, values = collection.Unzip(row.items()) + + columns = list(columns) + values = list(values) + + with self._pyspanner.batch(request_options={"request_tag": txn_tag}) as batch: + batch.insert( + table=table, + columns=columns, + values=[values] + ) + + def Update( + self, table: str, row: Mapping[str, Any], txn_tag: Optional[str] = None + ) -> None: + """Updates a row in the given table. + + Args: + table: A table in which the row is to be updated. + row: A mapping from column names to column values of the row. + txn_tag: Spanner transaction tag. + + Returns: + Nothing. + """ + columns, values = collection.Unzip(row.items()) + + columns = list(columns) + values = list(values) + + with self._pyspanner.batch(request_options={"request_tag": txn_tag}) as batch: + batch.update( + table=table, + columns=columns, + values=[values] + ) + + def InsertOrUpdate( + self, table: str, row: Mapping[str, Any], txn_tag: Optional[str] = None + ) -> None: + """Insert or update a row into the given table within the transaction. + + Args: + table: A table into which the row is to be inserted. + row: A mapping from column names to column values of the row. + txn_tag: Spanner transaction tag. + + Returns: + Nothing. + """ + columns, values = collection.Unzip(row.items()) + + columns = list(columns) + values = list(values) + + with self._pyspanner.batch(request_options={"request_tag": txn_tag}) as batch: + batch.insert_or_update( + table=table, + columns=columns, + values=[values] + ) + + def Delete( + self, table: str, key: Sequence[Any], txn_tag: Optional[str] = None + ) -> None: + """Deletes a specified row from the given table. + + Args: + table: A table from which the row is to be deleted. + key: A sequence of values denoting the key of the row to delete. + txn_tag: Spanner transaction tag. + + Returns: + Nothing. + """ + keyset = KeySet(all_=True) + if key: + keyset = KeySet(keys=[key]) + with self._pyspanner.batch(request_options={"request_tag": txn_tag}) as batch: + batch.delete(table, keyset) + + def DeleteWithPrefix(self, table: str, key_prefix: Sequence[Any], + txn_tag: Optional[str] = None) -> None: + """Deletes a range of rows with common key prefix from the given table. + + Args: + table: A table from which rows are to be deleted. 
+ key_prefix: A sequence of value denoting the prefix of the key of rows to delete. + txn_tag: Spanner transaction tag. + + Returns: + Nothing. + """ + range = KeyRange(start_closed=key_prefix, end_closed=key_prefix) + keyset = KeySet(ranges=[range]) + + with self._pyspanner.batch(request_options={"request_tag": txn_tag}) as batch: + batch.delete(table, keyset) + + def Read( + self, + table: str, + key: Sequence[Any], + cols: Sequence[str], + txn_tag: Optional[str] = None + ) -> Row: + """Read a single row with the given key from the specified table. + + Args: + table: A name of the table to read from. + key: A key of the row to read. + cols: Columns of the row to read. + txn_tag: Spanner transaction tag. + + Returns: + A mapping from columns to values of the read row. + """ + keyset = KeySet(keys=[key]) + with self._pyspanner.snapshot() as snapshot: + results = snapshot.read( + table=table, + columns=cols, + keyset=keyset, + request_options={"request_tag": txn_tag} + ) + return results.one() + + def ReadSet( + self, + table: str, + rows: KeySet, + cols: Sequence[str], + txn_tag: Optional[str] = None + ) -> Cursor: + """Read a set of rows from the specified table. + + Args: + table: A name of the table to read from. + rows: A set of keys specifying which rows to read. + cols: Columns of the row to read. + txn_tag: Spanner transaction tag. + + Returns: + Mappings from columns to values of the rows read. + """ + with self._pyspanner.snapshot() as snapshot: + results = snapshot.read( + table=table, + columns=cols, + keyset=rows, + request_options={"request_tag": txn_tag} + ) + return results diff --git a/grr/server/grr_response_server/databases/spanner_utils_test.py b/grr/server/grr_response_server/databases/spanner_utils_test.py new file mode 100644 index 000000000..bf92c7bb3 --- /dev/null +++ b/grr/server/grr_response_server/databases/spanner_utils_test.py @@ -0,0 +1,311 @@ +import datetime +from typing import Any, List + +from absl.testing import absltest + +from google.cloud import spanner as spanner_lib +from google.api_core.exceptions import NotFound + + +from grr_response_server.databases import spanner_test_lib + +def setUpModule() -> None: + spanner_test_lib.Init(spanner_test_lib.TEST_SCHEMA_SDL_PATH, False) + + +def tearDownModule() -> None: + spanner_test_lib.TearDown() + + +class DatabaseTest(spanner_test_lib.TestCase): + + def setUp(self): + super().setUp() + + ####################################### + # Transact Tests + ####################################### + def testTransactionTransactional(self): + + def TransactionWrite(txn) -> None: + txn.insert( + table="Table", + columns=("Key",), + values=[("foo",), ("bar",)] + ) + + def TransactionRead(txn) -> List[Any]: + result = list(txn.execute_sql("SELECT t.Key FROM Table AS t")) + return result + + self.raw_db.Transact(TransactionWrite) + results = self.raw_db.Transact(TransactionRead) + self.assertCountEqual(results, [["foo"], ["bar"]]) + + ####################################### + # Query Tests + ####################################### + def testQuerySimple(self): + results = list(self.raw_db.Query("SELECT 'foo', 42")) + self.assertEqual(results, [["foo", 42]]) + + def testQueryWithPlaceholders(self): + results = list(self.raw_db.Query("SELECT '{}', '@p0'")) + self.assertEqual(results, [["{}", "@p0"]]) + + ####################################### + # QuerySingle Tests + ####################################### + def testQuerySingle(self): + result = self.raw_db.QuerySingle("SELECT 'foo', 42") + self.assertEqual(result, ["foo", 
diff --git a/grr/server/grr_response_server/databases/spanner_utils_test.py b/grr/server/grr_response_server/databases/spanner_utils_test.py
new file mode 100644
index 000000000..bf92c7bb3
--- /dev/null
+++ b/grr/server/grr_response_server/databases/spanner_utils_test.py
@@ -0,0 +1,311 @@
+import datetime
+from typing import Any, List
+
+from absl.testing import absltest
+
+from google.cloud import spanner as spanner_lib
+from google.api_core.exceptions import NotFound
+
+
+from grr_response_server.databases import spanner_test_lib
+
+def setUpModule() -> None:
+  spanner_test_lib.Init(spanner_test_lib.TEST_SCHEMA_SDL_PATH, False)
+
+
+def tearDownModule() -> None:
+  spanner_test_lib.TearDown()
+
+
+class DatabaseTest(spanner_test_lib.TestCase):
+
+  def setUp(self):
+    super().setUp()
+
+  #######################################
+  # Transact Tests
+  #######################################
+  def testTransactionTransactional(self):
+
+    def TransactionWrite(txn) -> None:
+      txn.insert(
+          table="Table",
+          columns=("Key",),
+          values=[("foo",), ("bar",)]
+      )
+
+    def TransactionRead(txn) -> List[Any]:
+      result = list(txn.execute_sql("SELECT t.Key FROM Table AS t"))
+      return result
+
+    self.raw_db.Transact(TransactionWrite)
+    results = self.raw_db.Transact(TransactionRead)
+    self.assertCountEqual(results, [["foo"], ["bar"]])
+
+  #######################################
+  # Query Tests
+  #######################################
+  def testQuerySimple(self):
+    results = list(self.raw_db.Query("SELECT 'foo', 42"))
+    self.assertEqual(results, [["foo", 42]])
+
+  def testQueryWithPlaceholders(self):
+    results = list(self.raw_db.Query("SELECT '{}', '@p0'"))
+    self.assertEqual(results, [["{}", "@p0"]])
+
+  #######################################
+  # QuerySingle Tests
+  #######################################
+  def testQuerySingle(self):
+    result = self.raw_db.QuerySingle("SELECT 'foo', 42")
+    self.assertEqual(result, ["foo", 42])
+
+  def testQuerySingleEmpty(self):
+    with self.assertRaises(NotFound):
+      self.raw_db.QuerySingle("SELECT 'foo', 42 FROM UNNEST([])")
+
+  def testQuerySingleMultiple(self):
+    with self.assertRaises(ValueError):
+      self.raw_db.QuerySingle("SELECT 'foo', 42 FROM UNNEST([1, 2])")
+
+  #######################################
+  # ParamQuery Tests
+  #######################################
+  def testParamQuerySingleParam(self):
+    query = "SELECT {abc}"
+    params = {"abc": 1337}
+
+    results = list(self.raw_db.ParamQuery(query, params))
+    self.assertEqual(results, [[1337,]])
+
+  def testParamQueryMultipleParams(self):
+    timestamp = datetime.datetime.now(datetime.timezone.utc)
+
+    query = "SELECT {int}, {str}, {timestamp}"
+    params = {"int": 1337, "str": "quux", "timestamp": timestamp}
+
+    results = list(self.raw_db.ParamQuery(query, params))
+    self.assertEqual(results, [[1337, "quux", timestamp]])
+
+  def testParamQueryMissingParams(self):
+    with self.assertRaisesRegex(KeyError, "bar"):
+      self.raw_db.ParamQuery("SELECT {foo}, {bar}", {"foo": 42})
+
+  def testParamQueryExtraParams(self):
+    query = "SELECT 42, {foo}"
+    params = {"foo": "foo", "bar": "bar"}
+
+    results = list(self.raw_db.ParamQuery(query, params))
+    self.assertEqual(results, [[42, "foo"]])
+
+  def testParamQueryIllegalSequence(self):
+    with self.assertRaisesRegex(ValueError, "@p1337"):
+      self.raw_db.ParamQuery("SELECT @p1337", {})
+
+  def testParamQueryLegalSequence(self):
+    results = list(self.raw_db.ParamQuery("SELECT '@p', '@q'", {}))
+    self.assertEqual(results, [["@p", "@q"]])
+
+  def testParamQueryBraceEscape(self):
+    results = list(self.raw_db.ParamQuery("SELECT '{{foo}}'", {}))
+    self.assertEqual(results, [["{foo}",]])
+
+  #######################################
+  # ParamExecute Tests
+  #######################################
+  def testParamExecuteSingleParam(self):
+    query = """
+    INSERT INTO Table(Key)
+    VALUES ({key})
+    """
+    params = {"key": "foo"}
+
+    self.raw_db.ParamExecute(query, params)
+
+  #######################################
+  # ParamQuerySingle Tests
+  #######################################
+  def testParamQuerySingle(self):
+    query = "SELECT {str}, {int}"
+    params = {"str": "foo", "int": 42}
+
+    result = self.raw_db.ParamQuerySingle(query, params)
+    self.assertEqual(result, ["foo", 42])
+
+  def testParamQuerySingleEmpty(self):
+    query = "SELECT {str}, {int} FROM UNNEST([])"
+    params = {"str": "foo", "int": 42}
+
+    with self.assertRaises(NotFound):
+      self.raw_db.ParamQuerySingle(query, params)
+
+  def testParamQuerySingleMultiple(self):
+    query = "SELECT {str}, {int} FROM UNNEST([1, 2])"
+    params = {"str": "foo", "int": 42}
+
+    with self.assertRaises(ValueError):
+      self.raw_db.ParamQuerySingle(query, params)
+
+  #######################################
+  # ExecutePartitioned Tests
+  #######################################
+  def testExecutePartitioned(self):
+    self.raw_db.Insert(table="Table", row={"Key": "foo"})
+    self.raw_db.Insert(table="Table", row={"Key": "bar"})
+    self.raw_db.Insert(table="Table", row={"Key": "baz"})
+
+    self.raw_db.ExecutePartitioned("DELETE FROM Table AS t WHERE t.Key LIKE 'ba%'")
+    results = list(self.raw_db.Query("SELECT t.Key FROM Table AS t"))
+    self.assertLen(results, 1)
+    self.assertEqual(results[0], ["foo",])
"bar@x.com"}) + + results = list(self.raw_db.Query("SELECT t.Column FROM Table AS t")) + self.assertCountEqual(results, [["foo@x.com",], ["bar@x.com",]]) + + ####################################### + # Update Tests + ####################################### + def testUpdate(self): + self.raw_db.Insert(table="Table", row={"Key": "foo", "Column": "bar@y.com"}) + self.raw_db.Update(table="Table", row={"Key": "foo", "Column": "qux@y.com"}) + + results = list(self.raw_db.Query("SELECT t.Column FROM Table AS t")) + self.assertEqual(results, [["qux@y.com",]]) + + def testUpdateNotExisting(self): + with self.assertRaises(NotFound): + self.raw_db.Update(table="Table", row={"Key": "foo", "Column": "x@y.com"}) + + ####################################### + # InsertOrUpdate Tests + ####################################### + def testInsertOrUpdate(self): + row = {"Key": "foo"} + + row["Column"] = "bar@example.com" + self.raw_db.InsertOrUpdate(table="Table", row=row) + + row["Column"] = "baz@example.com" + self.raw_db.InsertOrUpdate(table="Table", row=row) + + results = list(self.raw_db.Query("SELECT t.Column FROM Table AS t")) + self.assertEqual(results, [["baz@example.com",]]) + + ####################################### + # Delete Tests + ####################################### + def testDelete(self): + self.raw_db.InsertOrUpdate(table="Table", row={"Key": "foo"}) + self.raw_db.Delete(table="Table", key=("foo",)) + + results = list(self.raw_db.Query("SELECT t.Key FROM Table AS t")) + self.assertEmpty(results) + + def testDeleteSingle(self): + self.raw_db.Insert(table="Table", row={"Key": "foo"}) + self.raw_db.InsertOrUpdate(table="Table", row={"Key": "bar"}) + self.raw_db.Delete(table="Table", key=("foo",)) + + results = list(self.raw_db.Query("SELECT t.Key FROM Table AS t")) + self.assertEqual(results, [["bar",]]) + + def testDeleteNotExisting(self): + # Should not raise. 
diff --git a/grr/server/grr_response_server/databases/spanner_yara.py b/grr/server/grr_response_server/databases/spanner_yara.py
new file mode 100644
index 000000000..350e6ecfa
--- /dev/null
+++ b/grr/server/grr_response_server/databases/spanner_yara.py
@@ -0,0 +1,61 @@
+"""A module with YARA methods of the Spanner database implementation."""
+import base64
+
+from google.api_core.exceptions import NotFound
+
+from google.cloud import spanner as spanner_lib
+from grr_response_server.databases import db
+from grr_response_server.databases import db_utils
+from grr_response_server.databases import spanner_utils
+from grr_response_server.models import blobs as models_blobs
+
+
+class YaraMixin:
+  """A Spanner database mixin with implementation of YARA methods."""
+
+  db: spanner_utils.Database
+
+  @db_utils.CallLogged
+  @db_utils.CallAccounted
+  def WriteYaraSignatureReference(
+      self,
+      blob_id: models_blobs.BlobID,
+      username: str,
+  ) -> None:
+    """Marks the specified blob id as a YARA signature."""
+    row = {
+        "BlobId": base64.b64encode(bytes(blob_id)),
+        "Creator": username,
+        "CreationTime": spanner_lib.COMMIT_TIMESTAMP,
+    }
+
+    try:
+      self.db.InsertOrUpdate(
+          table="YaraSignatureReferences",
+          row=row,
+          txn_tag="WriteYaraSignatureReference",
+      )
+    except Exception as error:
+      if "fk_yara_signature_reference_creator_username" in str(error):
+        raise db.UnknownGRRUserError(username) from error
+      else:
+        raise
+
+  @db_utils.CallLogged
+  @db_utils.CallAccounted
+  def VerifyYaraSignatureReference(
+      self,
+      blob_id: models_blobs.BlobID,
+  ) -> bool:
+    """Verifies whether the specified blob is a YARA signature."""
+    key = (base64.b64encode(bytes(blob_id)),)
+
+    try:
+      self.db.Read(
+          table="YaraSignatureReferences",
+          key=key,
+          cols=("BlobId",),
+          txn_tag="VerifyYaraSignatureReference",
+      )
+    except NotFound:
+      return False
+
+    return True
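The mixin keys signature references on the base64-encoded blob ID and relies on the foreign key against the users table to reject unknown creators. A rough round-trip sketch (not part of this change; `spanner_db` stands for a Spanner-backed database object composed with `YaraMixin`, and the GRR user is assumed to exist already):

    import os

    from grr_response_server.models import blobs as models_blobs

    blob_id = models_blobs.BlobID(os.urandom(32))

    spanner_db.WriteYaraSignatureReference(blob_id, username="sampleuser")
    assert spanner_db.VerifyYaraSignatureReference(blob_id)

    # Writing with a user that does not exist is expected to surface as
    # db.UnknownGRRUserError via the foreign-key constraint check above.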
diff --git a/grr/server/grr_response_server/databases/spanner_yara_test.py b/grr/server/grr_response_server/databases/spanner_yara_test.py
new file mode 100644
index 000000000..616191b1c
--- /dev/null
+++ b/grr/server/grr_response_server/databases/spanner_yara_test.py
@@ -0,0 +1,23 @@
+from absl.testing import absltest
+
+from grr_response_server.databases import db_yara_test_lib
+from grr_response_server.databases import spanner_test_lib
+
+
+def setUpModule() -> None:
+  spanner_test_lib.Init(spanner_test_lib.PROD_SCHEMA_SDL_PATH, True)
+
+
+def tearDownModule() -> None:
+  spanner_test_lib.TearDown()
+
+
+class SpannerDatabaseYaraTest(
+    db_yara_test_lib.DatabaseTestYaraMixin, spanner_test_lib.TestCase
+):
+  # Test methods are defined in the base mixin class.
+  pass
+
+
+if __name__ == "__main__":
+  absltest.main()
diff --git a/grr/server/setup.py b/grr/server/setup.py
index 16a674003..207f11722 100644
--- a/grr/server/setup.py
+++ b/grr/server/setup.py
@@ -179,6 +179,7 @@ def make_release_tree(self, base_dir, files):
     install_requires=[
         "google-api-python-client==1.12.11",
         "google-auth==2.23.3",
+        "google-cloud-spanner==3.54.0",
        "google-cloud-storage==2.13.0",
        "google-cloud-pubsub==2.18.4",
        "grr-api-client==%s" % VERSION.get("Version", "packagedepends"),
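As a quick, optional sanity check that the newly pinned client is what actually ends up in the server environment (sketch; assumes the GRR server package and its dependencies are installed):

    from importlib.metadata import version

    # Expected to print 3.54.0 given the pin in grr/server/setup.py above.
    print(version("google-cloud-spanner"))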