diff --git a/alerting/src/main/kotlin/org/opensearch/alerting/AlertingPlugin.kt b/alerting/src/main/kotlin/org/opensearch/alerting/AlertingPlugin.kt index 169e718e5..0dc15ad3b 100644 --- a/alerting/src/main/kotlin/org/opensearch/alerting/AlertingPlugin.kt +++ b/alerting/src/main/kotlin/org/opensearch/alerting/AlertingPlugin.kt @@ -168,6 +168,7 @@ internal class AlertingPlugin : PainlessExtension, ActionPlugin, ScriptPlugin, R companion object { @JvmField val OPEN_SEARCH_DASHBOARDS_USER_AGENT = "OpenSearch-Dashboards" @JvmField val UI_METADATA_EXCLUDE = arrayOf("monitor.${Monitor.UI_METADATA_FIELD}") + @JvmField val TENANT_ID_HEADER = "x-tenant-id" @JvmField val MONITOR_BASE_URI = "/_plugins/_alerting/monitors" @JvmField val WORKFLOW_BASE_URI = "/_plugins/_alerting/workflows" @JvmField val REMOTE_BASE_URI = "/_plugins/_alerting/remote" diff --git a/alerting/src/main/kotlin/org/opensearch/alerting/transport/TransportGetMonitorAction.kt b/alerting/src/main/kotlin/org/opensearch/alerting/transport/TransportGetMonitorAction.kt index aabf0d2af..f57572243 100644 --- a/alerting/src/main/kotlin/org/opensearch/alerting/transport/TransportGetMonitorAction.kt +++ b/alerting/src/main/kotlin/org/opensearch/alerting/transport/TransportGetMonitorAction.kt @@ -12,12 +12,11 @@ import org.apache.logging.log4j.LogManager import org.apache.lucene.search.join.ScoreMode import org.opensearch.OpenSearchStatusException import org.opensearch.action.ActionRequest -import org.opensearch.action.get.GetRequest -import org.opensearch.action.get.GetResponse import org.opensearch.action.search.SearchRequest import org.opensearch.action.search.SearchResponse import org.opensearch.action.support.ActionFilters import org.opensearch.action.support.HandledTransportAction +import org.opensearch.alerting.AlertingPlugin import org.opensearch.alerting.opensearchapi.suspendUntil import org.opensearch.alerting.settings.AlertingSettings import org.opensearch.alerting.util.ScheduledJobUtils.Companion.WORKFLOW_DELEGATE_PATH @@ -43,10 +42,11 @@ import org.opensearch.core.rest.RestStatus import org.opensearch.core.xcontent.NamedXContentRegistry import org.opensearch.index.IndexNotFoundException import org.opensearch.index.query.QueryBuilders +import org.opensearch.remote.metadata.client.GetDataObjectRequest import org.opensearch.remote.metadata.client.SdkClient +import org.opensearch.remote.metadata.common.SdkClientUtils import org.opensearch.search.builder.SearchSourceBuilder import org.opensearch.tasks.Task -import org.opensearch.transport.RemoteTransportException import org.opensearch.transport.TransportService import org.opensearch.transport.client.Client @@ -84,102 +84,78 @@ class TransportGetMonitorAction @Inject constructor( val user = readUserFromThreadContext(client) - val getRequest = GetRequest(ScheduledJob.SCHEDULED_JOBS_INDEX, transformedRequest.monitorId) - .version(transformedRequest.version) - .fetchSourceContext(transformedRequest.srcContext) - if (!validateUserBackendRoles(user, actionListener)) { return } - /* - * Remove security context before you call elasticsearch api's. By this time, permissions required - * to call this api are validated. - * Once system-indices [https://github.com/opendistro-for-elasticsearch/security/issues/666] is done, we - * might further improve this logic. Also change try to kotlin-use for auto-closable. - */ + val tenantId = client.threadPool().threadContext.getHeader(AlertingPlugin.TENANT_ID_HEADER) + val getRequest = GetDataObjectRequest.builder() + .index(ScheduledJob.SCHEDULED_JOBS_INDEX) + .id(transformedRequest.monitorId) + .tenantId(tenantId) + .fetchSourceContext(transformedRequest.srcContext) + .build() + client.threadPool().threadContext.stashContext().use { - client.get( - getRequest, - object : ActionListener { - override fun onResponse(response: GetResponse) { - if (!response.isExists) { - actionListener.onFailure( - AlertingException.wrap(OpenSearchStatusException("Monitor not found.", RestStatus.NOT_FOUND)) + sdkClient.getDataObjectAsync(getRequest).whenComplete { response, throwable -> + if (throwable != null) { + val cause = SdkClientUtils.unwrapAndConvertToException(throwable) + if (isIndexNotFoundException(cause)) { + actionListener.onFailure( + AlertingException.wrap( + OpenSearchStatusException("Monitor not found.", RestStatus.NOT_FOUND, cause) ) - return - } - - var monitor: Monitor? = null - if (!response.isSourceEmpty) { - XContentHelper.createParser( - xContentRegistry, - LoggingDeprecationHandler.INSTANCE, - response.sourceAsBytesRef, - XContentType.JSON - ).use { xcp -> - monitor = ScheduledJob.parse(xcp, response.id, response.version) as Monitor - - // security is enabled and filterby is enabled - if (!checkUserPermissionsWithResource( - user, - monitor?.user, - actionListener, - "monitor", - transformedRequest.monitorId - ) - ) { - return - } - } - } - try { - scope.launch { - val associatedCompositeMonitors = getAssociatedWorkflows(response.id) - actionListener.onResponse( - GetMonitorResponse( - response.id, - response.version, - response.seqNo, - response.primaryTerm, - monitor, - associatedCompositeMonitors - ) - ) - } - } catch (e: Exception) { - log.error("Failed to get associate workflows in get monitor action", e) + ) + } else { + actionListener.onFailure(AlertingException.wrap(cause)) + } + return@whenComplete + } + try { + val getResponse = response.getResponse() + if (getResponse == null || !getResponse.isExists) { + actionListener.onFailure( + AlertingException.wrap(OpenSearchStatusException("Monitor not found.", RestStatus.NOT_FOUND)) + ) + return@whenComplete + } + var monitor: Monitor? = null + if (!getResponse.isSourceEmpty) { + XContentHelper.createParser( + xContentRegistry, LoggingDeprecationHandler.INSTANCE, + getResponse.sourceAsBytesRef, XContentType.JSON + ).use { xcp -> + monitor = ScheduledJob.parse(xcp, getResponse.id, getResponse.version) as Monitor } } - - override fun onFailure(ex: Exception) { - if (isIndexNotFoundException(ex)) { - log.error("Index not found while getting monitor", ex) - actionListener.onFailure( - AlertingException.wrap( - OpenSearchStatusException("Monitor not found. Backing index is missing.", RestStatus.NOT_FOUND, ex) - ) + if (!checkUserPermissionsWithResource(user, monitor?.user, actionListener, "monitor", transformedRequest.monitorId)) { + return@whenComplete + } + scope.launch { + val associatedCompositeMonitors = getAssociatedWorkflows(getResponse.id) + actionListener.onResponse( + GetMonitorResponse( + getResponse.id, getResponse.version, getResponse.seqNo, getResponse.primaryTerm, + monitor, associatedCompositeMonitors ) - } else { - log.error("Unexpected error while getting monitor", ex) - actionListener.onFailure(AlertingException.wrap(ex)) - } + ) } + } catch (e: Exception) { + log.error("Failed to parse monitor from SDK response", e) + actionListener.onFailure(AlertingException.wrap(e)) } - ) + } } } // Checks if the exception is caused by an IndexNotFoundException (directly or nested). private fun isIndexNotFoundException(e: Exception): Boolean { - if (e is IndexNotFoundException) { - return true - } - if (e is RemoteTransportException) { - val cause = e.cause + var cause: Throwable? = e + while (cause != null) { if (cause is IndexNotFoundException) { return true } + cause = cause.cause } return false } diff --git a/alerting/src/test/kotlin/org/opensearch/alerting/transport/TransportGetMonitorActionTests.kt b/alerting/src/test/kotlin/org/opensearch/alerting/transport/TransportGetMonitorActionTests.kt new file mode 100644 index 000000000..3a3e61413 --- /dev/null +++ b/alerting/src/test/kotlin/org/opensearch/alerting/transport/TransportGetMonitorActionTests.kt @@ -0,0 +1,161 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.alerting.transport + +import org.junit.Before +import org.mockito.ArgumentCaptor +import org.mockito.ArgumentMatchers.any +import org.mockito.Mockito +import org.mockito.Mockito.verify +import org.opensearch.action.support.ActionFilters +import org.opensearch.alerting.AlertingPlugin.Companion.TENANT_ID_HEADER +import org.opensearch.alerting.settings.AlertingSettings +import org.opensearch.cluster.service.ClusterService +import org.opensearch.common.settings.ClusterSettings +import org.opensearch.common.settings.Setting +import org.opensearch.common.settings.Settings +import org.opensearch.common.util.concurrent.ThreadContext +import org.opensearch.commons.alerting.action.GetMonitorRequest +import org.opensearch.commons.alerting.action.GetMonitorResponse +import org.opensearch.core.action.ActionListener +import org.opensearch.core.xcontent.NamedXContentRegistry +import org.opensearch.remote.metadata.client.GetDataObjectRequest +import org.opensearch.remote.metadata.client.GetDataObjectResponse +import org.opensearch.remote.metadata.client.SdkClient +import org.opensearch.rest.RestRequest +import org.opensearch.test.OpenSearchTestCase +import org.opensearch.threadpool.ThreadPool +import org.opensearch.transport.TransportService +import org.opensearch.transport.client.Client +import java.util.concurrent.CompletableFuture +import java.util.concurrent.CompletionStage +import org.mockito.Mockito.`when` as whenever + +class TransportGetMonitorActionTests : OpenSearchTestCase() { + + private lateinit var client: Client + private lateinit var sdkClient: SdkClient + private lateinit var transportService: TransportService + private lateinit var actionFilters: ActionFilters + private lateinit var xContentRegistry: NamedXContentRegistry + private lateinit var clusterService: ClusterService + private lateinit var threadPool: ThreadPool + private lateinit var threadContext: ThreadContext + + @Before + fun setup() { + client = Mockito.mock(Client::class.java) + sdkClient = Mockito.mock(SdkClient::class.java) + transportService = Mockito.mock(TransportService::class.java) + actionFilters = Mockito.mock(ActionFilters::class.java) + xContentRegistry = Mockito.mock(NamedXContentRegistry::class.java) + clusterService = Mockito.mock(ClusterService::class.java) + threadPool = Mockito.mock(ThreadPool::class.java) + threadContext = ThreadContext(Settings.EMPTY) + + whenever(client.threadPool()).thenReturn(threadPool) + whenever(threadPool.threadContext).thenReturn(threadContext) + + val settingSet = hashSetOf>() + settingSet.addAll(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + settingSet.add(AlertingSettings.FILTER_BY_BACKEND_ROLES) + val clusterSettings = ClusterSettings(Settings.EMPTY, settingSet) + whenever(clusterService.clusterSettings).thenReturn(clusterSettings) + } + + fun `test SDK called with correct tenantId and monitorId`() { + val expectedTenantId = "test-tenant:test-scope" + threadContext.putHeader(TENANT_ID_HEADER, expectedTenantId) + + val future: CompletionStage = CompletableFuture.completedFuture( + GetDataObjectResponse.builder() + .id("test-monitor-id") + .index(".opendistro-alerting-config") + .source(emptyMap()) + .build() + ) + whenever(sdkClient.getDataObjectAsync(any(GetDataObjectRequest::class.java))).thenReturn(future) + + val action = createAction(Settings.builder().build()) + val request = GetMonitorRequest("test-monitor-id", 0L, RestRequest.Method.GET, null) + @Suppress("UNCHECKED_CAST") + val listener = Mockito.mock(ActionListener::class.java) as ActionListener + + invokeDoExecute(action, request, listener) + + val requestCaptor = ArgumentCaptor.forClass(GetDataObjectRequest::class.java) + verify(sdkClient).getDataObjectAsync(requestCaptor.capture()) + assertEquals(expectedTenantId, requestCaptor.value.tenantId()) + assertEquals("test-monitor-id", requestCaptor.value.id()) + } + + fun `test SDK called with null tenantId when header absent`() { + val future: CompletionStage = CompletableFuture.completedFuture( + GetDataObjectResponse.builder() + .id("test-monitor-id") + .index(".opendistro-alerting-config") + .source(emptyMap()) + .build() + ) + whenever(sdkClient.getDataObjectAsync(any(GetDataObjectRequest::class.java))).thenReturn(future) + + val action = createAction(Settings.builder().build()) + val request = GetMonitorRequest("test-monitor-id", 0L, RestRequest.Method.GET, null) + @Suppress("UNCHECKED_CAST") + val listener = Mockito.mock(ActionListener::class.java) as ActionListener + + invokeDoExecute(action, request, listener) + + val requestCaptor = ArgumentCaptor.forClass(GetDataObjectRequest::class.java) + verify(sdkClient).getDataObjectAsync(requestCaptor.capture()) + assertNull(requestCaptor.value.tenantId()) + } + + fun `test SDK exception propagated to listener`() { + threadContext.putHeader(TENANT_ID_HEADER, "test-tenant:test-scope") + + val future: CompletionStage = CompletableFuture().also { + it.completeExceptionally(RuntimeException("SDK connection failed")) + } + whenever(sdkClient.getDataObjectAsync(any(GetDataObjectRequest::class.java))).thenReturn(future) + + val action = createAction(Settings.builder().build()) + val request = GetMonitorRequest("test-monitor-id", 0L, RestRequest.Method.GET, null) + @Suppress("UNCHECKED_CAST") + val listener = Mockito.mock(ActionListener::class.java) as ActionListener + + invokeDoExecute(action, request, listener) + + verify(listener).onFailure(any()) + } + + private fun invokeDoExecute( + action: TransportGetMonitorAction, + request: GetMonitorRequest, + listener: ActionListener + ) { + val method = action.javaClass.getDeclaredMethod( + "doExecute", + org.opensearch.tasks.Task::class.java, + org.opensearch.action.ActionRequest::class.java, + ActionListener::class.java + ) + method.isAccessible = true + method.invoke(action, Mockito.mock(org.opensearch.tasks.Task::class.java), request, listener) + } + + private fun createAction(settings: Settings): TransportGetMonitorAction { + val settingSet = hashSetOf>() + settingSet.addAll(ClusterSettings.BUILT_IN_CLUSTER_SETTINGS) + settingSet.add(AlertingSettings.FILTER_BY_BACKEND_ROLES) + val clusterSettings = ClusterSettings(settings, settingSet) + whenever(clusterService.clusterSettings).thenReturn(clusterSettings) + + return TransportGetMonitorAction( + transportService, client, actionFilters, xContentRegistry, clusterService, settings, sdkClient + ) + } +}