diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index adddf6b39d..a5c89814fe 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -36,7 +36,7 @@ public class CommonValue { // warm node public static String WARM_BOX_TYPE = "warm"; public static final String ML_INDEX_INSIGHT_CONFIG_INDEX = ".plugins-ml-index-insight-config"; - public static final String ML_INDEX_INSIGHT_STORAGE_INDEX = ".plugins-ml-index-insight-storage"; + public static final String ML_INDEX_INSIGHT_STORAGE_INDEX = "plugins-ml-index-insight-storage"; public static final String ML_MODEL_GROUP_INDEX = ".plugins-ml-model-group"; public static final String ML_MODEL_INDEX = ".plugins-ml-model"; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/indexInsight/MLIndexInsightGetRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/indexInsight/MLIndexInsightGetRequest.java index ee07107bdc..e2aa0b6519 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/indexInsight/MLIndexInsightGetRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/indexInsight/MLIndexInsightGetRequest.java @@ -18,10 +18,10 @@ import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; -import org.opensearch.ml.common.indexInsight.MLIndexInsightType; import lombok.Builder; import lombok.Getter; +import org.opensearch.ml.common.indexInsight.MLIndexInsightType; @Builder @Getter diff --git a/common/src/test/java/org/opensearch/ml/common/indexInsight/StatisticalDataTaskTests.java b/common/src/test/java/org/opensearch/ml/common/indexInsight/StatisticalDataTaskTests.java index 5671885e48..0b1c3d259e 100644 --- a/common/src/test/java/org/opensearch/ml/common/indexInsight/StatisticalDataTaskTests.java +++ b/common/src/test/java/org/opensearch/ml/common/indexInsight/StatisticalDataTaskTests.java @@ -26,8 +26,6 @@ import java.util.Map; import java.util.Set; -import javax.swing.*; - import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; diff --git a/plugin/src/main/java/org/opensearch/ml/action/IndexInsight/GetIndexInsightTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/IndexInsight/GetIndexInsightTransportAction.java index cafe4b3235..5b96425391 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/IndexInsight/GetIndexInsightTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/IndexInsight/GetIndexInsightTransportAction.java @@ -5,7 +5,13 @@ package org.opensearch.ml.action.IndexInsight; +import static org.opensearch.ml.common.CommonValue.ML_INDEX_INSIGHT_STORAGE_INDEX; +import static org.opensearch.ml.common.CommonValue.ML_INDEX_INSIGHT_STORAGE_INDEX_MAPPING_PATH; +import static org.opensearch.ml.common.indexInsight.MLIndexInsightType.FIELD_DESCRIPTION; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED; +import static org.opensearch.ml.common.indexInsight.MLIndexInsightType.LOG_RELATED_INDEX_CHECK; +import static org.opensearch.ml.common.indexInsight.MLIndexInsightType.STATISTICAL_DATA; import java.time.Instant; @@ -15,20 +21,22 @@ import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.ml.common.MLIndex; -import org.opensearch.ml.common.indexInsight.FieldDescriptionTask; import org.opensearch.ml.common.indexInsight.IndexInsight; -import org.opensearch.ml.common.indexInsight.IndexInsightAccessControllerHelper; -import org.opensearch.ml.common.indexInsight.IndexInsightTask; import org.opensearch.ml.common.indexInsight.IndexInsightTaskStatus; -import org.opensearch.ml.common.indexInsight.LogRelatedIndexCheckTask; import org.opensearch.ml.common.indexInsight.MLIndexInsightType; -import org.opensearch.ml.common.indexInsight.StatisticalDataTask; +import org.opensearch.ml.common.memorycontainer.RemoteStore; import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; import org.opensearch.ml.common.transport.indexInsight.MLIndexInsightGetAction; import org.opensearch.ml.common.transport.indexInsight.MLIndexInsightGetRequest; import org.opensearch.ml.common.transport.indexInsight.MLIndexInsightGetResponse; import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.ml.helper.MemoryContainerHelper; +import org.opensearch.ml.helper.RemoteMemoryStoreHelper; +import org.opensearch.ml.indexInsight.FieldDescriptionTask; +import org.opensearch.ml.indexInsight.IndexInsightAccessControllerHelper; +import org.opensearch.ml.common.indexInsight.IndexInsightTask; +import org.opensearch.ml.indexInsight.LogRelatedIndexCheckTask; +import org.opensearch.ml.indexInsight.StatisticalDataTask; import org.opensearch.ml.utils.TenantAwareHelper; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.tasks.Task; @@ -41,7 +49,7 @@ public class GetIndexInsightTransportAction extends HandledTransportAction { private static final MLIndexInsightType[] ALL_TYPE_ORDER = { MLIndexInsightType.STATISTICAL_DATA, - MLIndexInsightType.FIELD_DESCRIPTION, + FIELD_DESCRIPTION, MLIndexInsightType.LOG_RELATED_INDEX_CHECK }; private final Client client; @@ -49,6 +57,8 @@ public class GetIndexInsightTransportAction extends HandledTransportAction { - ActionListener actionAfterDryRun = ActionListener.wrap(r -> { - executeTaskAndReturn(mlIndexInsightGetRequest, mlIndexInsightGetRequest.getTenantId(), actionListener); - }, actionListener::onFailure); - IndexInsightAccessControllerHelper.verifyAccessController(client, actionAfterDryRun, indexName); + String memoryContainerId = client.threadPool().getThreadContext().getHeader(MEMORY_CONTAINER_ID_FIELD); + String indexMappings = mlIndicesHandler.getMapping(ML_INDEX_INSIGHT_STORAGE_INDEX_MAPPING_PATH); + memoryContainerHelper.getMemoryContainer(memoryContainerId, ActionListener.wrap(mlMemoryContainer -> { + RemoteStore remoteStore = mlMemoryContainer.getConfiguration().getRemoteStore(); + if (remoteStore.getConnectorId() != null) { + remoteMemoryStoreHelper.createRemoteIndex(remoteStore.getConnectorId(), ML_INDEX_INSIGHT_STORAGE_INDEX, indexMappings, ActionListener.wrap(r2 -> { + ActionListener actionAfterDryRun = ActionListener.wrap(r -> { + executeTaskAndReturn(mlIndexInsightGetRequest, mlIndexInsightGetRequest.getTenantId(), actionListener); + }, actionListener::onFailure); + IndexInsightAccessControllerHelper.verifyAccessController(client, actionAfterDryRun, indexName); + }, e -> { + log.error("Failed to create index insight storage", e); + actionListener.onFailure(e); + })); + } + }, e -> { - log.error("Failed to create index insight storage", e); - actionListener.onFailure(e); + log.error("Failed to retrieve memory container", e); })); } @@ -192,7 +216,9 @@ IndexInsightTask createTask(MLIndexInsightGetRequest request) { client, sdkClient, request.getCmkRoleArn(), - request.getAssumeRoleArn() + request.getAssumeRoleArn(), + remoteMemoryStoreHelper, + memoryContainerHelper ); case FIELD_DESCRIPTION: return new FieldDescriptionTask( @@ -200,7 +226,9 @@ IndexInsightTask createTask(MLIndexInsightGetRequest request) { client, sdkClient, request.getCmkRoleArn(), - request.getAssumeRoleArn() + request.getAssumeRoleArn(), + remoteMemoryStoreHelper, + memoryContainerHelper ); case LOG_RELATED_INDEX_CHECK: return new LogRelatedIndexCheckTask( @@ -208,7 +236,9 @@ IndexInsightTask createTask(MLIndexInsightGetRequest request) { client, sdkClient, request.getCmkRoleArn(), - request.getAssumeRoleArn() + request.getAssumeRoleArn(), + remoteMemoryStoreHelper, + memoryContainerHelper ); default: throw new IllegalArgumentException("Unsupported task type: " + request.getTargetIndexInsight()); diff --git a/plugin/src/main/java/org/opensearch/ml/helper/RemoteMemoryStoreHelper.java b/plugin/src/main/java/org/opensearch/ml/helper/RemoteMemoryStoreHelper.java index 26805ac693..e5b08a7709 100644 --- a/plugin/src/main/java/org/opensearch/ml/helper/RemoteMemoryStoreHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/helper/RemoteMemoryStoreHelper.java @@ -41,6 +41,7 @@ import org.opensearch.action.search.SearchResponse; import org.opensearch.action.update.UpdateResponse; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Nullable; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; @@ -84,6 +85,7 @@ public class RemoteMemoryStoreHelper { public static final String CREATE_INGEST_PIPELINE_ACTION = "create_ingest_pipeline"; public static final String CREATE_INDEX_ACTION = "create_index"; public static final String WRITE_DOC_ACTION = "write_doc"; + public static final String WRITE_DOC_ACTION_WITH_ID = "write_doc_withID"; public static final String BULK_LOAD_ACTION = "bulk_load"; public static final String SEARCH_INDEX_ACTION = "search_index"; public static final String GET_DOC_ACTION = "get_doc"; @@ -471,6 +473,21 @@ public void writeDocument( } } + public void writeDocumentWithDocID( + RemoteStore remoteStore, + String docId, + String indexName, + Map documentSource, + ActionListener listener + ) { + // If connectorId is provided, use the existing method + if (remoteStore.getConnector() != null) { + writeDocument(remoteStore.getConnector(), indexName, documentSource, listener, docId); + } else { + listener.onFailure(new IllegalArgumentException("RemoteStore must have either connectorId or internal connector configured")); + } + } + public void writeDocument( String connectorId, String indexName, @@ -529,6 +546,38 @@ public void writeDocument( } } + public void writeDocument( + Connector connector, + String indexName, + Map documentSource, + ActionListener listener, + String docId + ) { + try { + // Prepare parameters for connector execution + Map parameters = new HashMap<>(); + parameters.put(INDEX_NAME_PARAM, indexName); + parameters.put(DOC_ID_PARAM, docId); + parameters.put(INPUT_PARAM, StringUtils.toJsonWithPlainNumbers(documentSource)); + + // Execute the connector action with write_doc action name + executeConnectorAction(connector, WRITE_DOC_ACTION_WITH_ID, parameters, ActionListener.wrap(response -> { + // Extract document ID from response + XContentParser parser = createParserFromTensorOutput(response); + IndexResponse indexResponse = IndexResponse.fromXContent(parser); + listener.onResponse(indexResponse); + }, e -> { + log.error("Failed to write document to remote index: {}", indexName, e); + listener.onFailure(e); + })); + + } catch (Exception e) { + log.error("Error preparing remote document write for index: {}", indexName, e); + listener.onFailure(e); + } + } + + /** * Performs bulk write operations to remote storage using RemoteStore configuration */ diff --git a/common/src/main/java/org/opensearch/ml/common/indexInsight/AbstractIndexInsightTask.java b/plugin/src/main/java/org/opensearch/ml/indexInsight/AbstractIndexInsightTask.java similarity index 86% rename from common/src/main/java/org/opensearch/ml/common/indexInsight/AbstractIndexInsightTask.java rename to plugin/src/main/java/org/opensearch/ml/indexInsight/AbstractIndexInsightTask.java index 4ee18707dc..3280186137 100644 --- a/common/src/main/java/org/opensearch/ml/common/indexInsight/AbstractIndexInsightTask.java +++ b/plugin/src/main/java/org/opensearch/ml/indexInsight/AbstractIndexInsightTask.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.common.indexInsight; +package org.opensearch.ml.indexInsight; import static org.opensearch.ml.common.CommonValue.INDEX_INSIGHT_AGENT_NAME; import static org.opensearch.ml.common.CommonValue.INDEX_INSIGHT_GENERATING_TIMEOUT; @@ -11,6 +11,7 @@ import static org.opensearch.ml.common.CommonValue.ML_INDEX_INSIGHT_STORAGE_INDEX; import static org.opensearch.ml.common.indexInsight.IndexInsight.INDEX_NAME_FIELD; import static org.opensearch.ml.common.indexInsight.IndexInsight.TASK_TYPE_FIELD; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.gson; import java.nio.charset.StandardCharsets; @@ -25,8 +26,10 @@ import org.opensearch.OpenSearchStatusException; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.support.WriteRequest; import org.opensearch.common.Numbers; import org.opensearch.common.regex.Regex; import org.opensearch.common.util.concurrent.ThreadContext; @@ -39,7 +42,12 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLConfig; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.indexInsight.IndexInsight; +import org.opensearch.ml.common.indexInsight.IndexInsightTask; +import org.opensearch.ml.common.indexInsight.IndexInsightTaskStatus; +import org.opensearch.ml.common.indexInsight.MLIndexInsightType; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; +import org.opensearch.ml.common.memorycontainer.RemoteStore; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -47,8 +55,8 @@ import org.opensearch.ml.common.transport.config.MLConfigGetRequest; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; -import org.opensearch.remote.metadata.client.GetDataObjectRequest; -import org.opensearch.remote.metadata.client.PutDataObjectRequest; +import org.opensearch.ml.helper.MemoryContainerHelper; +import org.opensearch.ml.helper.RemoteMemoryStoreHelper; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.client.SearchDataObjectRequest; import org.opensearch.remote.metadata.common.SdkClientUtils; @@ -73,6 +81,9 @@ public abstract class AbstractIndexInsightTask implements IndexInsightTask { protected final SdkClient sdkClient; protected final String cmkRoleArn; protected final String cmkAssumeRoleArn; + protected final RemoteMemoryStoreHelper remoteMemoryStoreHelper; + protected final MemoryContainerHelper memoryContainerHelper; + protected final String memoryContainerId; protected AbstractIndexInsightTask( MLIndexInsightType taskType, @@ -80,7 +91,9 @@ protected AbstractIndexInsightTask( Client client, SdkClient sdkClient, String cmkRoleArn, - String cmkAssumeRoleArn + String cmkAssumeRoleArn, + RemoteMemoryStoreHelper remoteMemoryStoreHelper, + MemoryContainerHelper memoryContainerHelper ) { this.taskType = taskType; this.sourceIndex = sourceIndex; @@ -88,6 +101,9 @@ protected AbstractIndexInsightTask( this.sdkClient = sdkClient; this.cmkRoleArn = cmkRoleArn; this.cmkAssumeRoleArn = cmkAssumeRoleArn; + this.remoteMemoryStoreHelper = remoteMemoryStoreHelper; + this.memoryContainerHelper = memoryContainerHelper; + this.memoryContainerId = client.threadPool().getThreadContext().getHeader(MEMORY_CONTAINER_ID_FIELD); } /** @@ -335,78 +351,40 @@ protected void handlePatternResult(Map patternSource, String ten } private void getIndexInsight(String docId, String tenantId, ActionListener listener) { - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - sdkClient - .getDataObjectAsync( - GetDataObjectRequest - .builder() - .tenantId(tenantId) - .index(ML_INDEX_INSIGHT_STORAGE_INDEX) - .id(docId) - .cmkRoleArn(cmkRoleArn) - .assumeRoleArn(cmkAssumeRoleArn) - .build() - ) - .whenComplete((r, throwable) -> { - context.restore(); - if (throwable != null) { - Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); - log.error("Failed to get index insight document", cause); - listener.onFailure(cause); - } else { - try { - GetResponse getResponse = r.getResponse(); - assert getResponse != null; - listener.onResponse(getResponse); - } catch (Exception e) { - listener.onFailure(e); - } + memoryContainerHelper.getMemoryContainer(memoryContainerId, ActionListener.wrap(mlMemoryContainer -> { + RemoteStore remoteStore = mlMemoryContainer.getConfiguration().getRemoteStore(); + remoteMemoryStoreHelper.getDocument(remoteStore, ML_INDEX_INSIGHT_STORAGE_INDEX, docId, ActionListener.wrap(listener::onResponse, e -> { + listener.onFailure(new RuntimeException("Fail to retrieve index insight", e)); } - }); - } catch (Exception e) { - listener.onFailure(e); - } + + )); + }, e -> { + listener.onFailure(new RuntimeException("Error happening when retrieve memory container", e)); + })); } + private void writeIndexInsight(IndexInsight indexInsight, String tenantId, ActionListener listener) { String docId = generateDocId(); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - sdkClient - .putDataObjectAsync( - PutDataObjectRequest - .builder() - .tenantId(tenantId) - .index(ML_INDEX_INSIGHT_STORAGE_INDEX) - .dataObject(indexInsight) - .id(docId) - .cmkRoleArn(cmkRoleArn) - .assumeRoleArn(cmkAssumeRoleArn) - .build() - ) - .whenComplete((r, throwable) -> { - context.restore(); - if (throwable != null) { - Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable); - log.error("Failed to write index insight document", cause); - listener.onFailure(cause); - } else { - try { - IndexResponse indexResponse = r.indexResponse(); - assert indexResponse != null; - if (indexResponse.getResult() == DocWriteResponse.Result.CREATED - || indexResponse.getResult() == DocWriteResponse.Result.UPDATED) { - listener.onResponse(true); - } else { - listener.onFailure(new RuntimeException("Failed to put generating index insight doc")); - } - } catch (Exception e) { - listener.onFailure(e); - } + IndexRequest indexRequest = new IndexRequest(ML_INDEX_INSIGHT_STORAGE_INDEX).id(docId).source(indexInsight); + indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + memoryContainerHelper.getMemoryContainer(memoryContainerId, ActionListener.wrap(mlMemoryContainer -> { + RemoteStore remoteStore = mlMemoryContainer.getConfiguration().getRemoteStore(); + remoteMemoryStoreHelper.writeDocumentWithDocID(remoteStore, docId, tenantId, indexRequest.sourceAsMap(), ActionListener.wrap(indexResponse -> { + if (indexResponse.getResult() == DocWriteResponse.Result.CREATED + || indexResponse.getResult() == DocWriteResponse.Result.UPDATED) { + listener.onResponse(true); + } else { + listener.onFailure(new RuntimeException("Failed to put generating index insight doc")); + } + }, e -> { + listener.onFailure(new RuntimeException("Error happening when putting doc", e)); } - }); - } catch (Exception e) { - listener.onFailure(e); - } + + )); + }, e -> { + listener.onFailure(new RuntimeException("Error happening when retrieve memory container", e)); + })); } protected static void getAgentIdToRun(Client client, String tenantId, ActionListener actionListener) { diff --git a/common/src/main/java/org/opensearch/ml/common/indexInsight/FieldDescriptionTask.java b/plugin/src/main/java/org/opensearch/ml/indexInsight/FieldDescriptionTask.java similarity index 94% rename from common/src/main/java/org/opensearch/ml/common/indexInsight/FieldDescriptionTask.java rename to plugin/src/main/java/org/opensearch/ml/indexInsight/FieldDescriptionTask.java index 1eda0c0390..ad7f036ed0 100644 --- a/common/src/main/java/org/opensearch/ml/common/indexInsight/FieldDescriptionTask.java +++ b/plugin/src/main/java/org/opensearch/ml/indexInsight/FieldDescriptionTask.java @@ -3,11 +3,11 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.common.indexInsight; +package org.opensearch.ml.indexInsight; import static java.util.concurrent.TimeUnit.SECONDS; -import static org.opensearch.ml.common.indexInsight.StatisticalDataTask.EXAMPLE_DOC_KEYWORD; -import static org.opensearch.ml.common.indexInsight.StatisticalDataTask.IMPORTANT_COLUMN_KEYWORD; +import static org.opensearch.ml.indexInsight.StatisticalDataTask.EXAMPLE_DOC_KEYWORD; +import static org.opensearch.ml.indexInsight.StatisticalDataTask.IMPORTANT_COLUMN_KEYWORD; import static org.opensearch.ml.common.utils.StringUtils.gson; import java.time.Instant; @@ -25,6 +25,12 @@ import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.indexInsight.IndexInsight; +import org.opensearch.ml.common.indexInsight.IndexInsightTask; +import org.opensearch.ml.common.indexInsight.IndexInsightTaskStatus; +import org.opensearch.ml.common.indexInsight.MLIndexInsightType; +import org.opensearch.ml.helper.MemoryContainerHelper; +import org.opensearch.ml.helper.RemoteMemoryStoreHelper; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.transport.client.Client; @@ -40,8 +46,9 @@ public class FieldDescriptionTask extends AbstractIndexInsightTask { private static final int BATCH_SIZE = 50; // Hard-coded value for now - public FieldDescriptionTask(String sourceIndex, Client client, SdkClient sdkClient, String cmkRoleArn, String cmkAssumeRoleArn) { - super(MLIndexInsightType.FIELD_DESCRIPTION, sourceIndex, client, sdkClient, cmkRoleArn, cmkAssumeRoleArn); + public FieldDescriptionTask(String sourceIndex, Client client, SdkClient sdkClient, String cmkRoleArn, String cmkAssumeRoleArn, RemoteMemoryStoreHelper remoteMemoryStoreHelper, + MemoryContainerHelper memoryContainerHelper) { + super(MLIndexInsightType.FIELD_DESCRIPTION, sourceIndex, client, sdkClient, cmkRoleArn, cmkAssumeRoleArn, remoteMemoryStoreHelper, memoryContainerHelper); } @Override @@ -330,7 +337,7 @@ private Map parseFieldDescription(String modelResponse) { @Override public IndexInsightTask createPrerequisiteTask(MLIndexInsightType prerequisiteType) { if (prerequisiteType == MLIndexInsightType.STATISTICAL_DATA) { - return new StatisticalDataTask(sourceIndex, client, sdkClient, cmkRoleArn, cmkAssumeRoleArn); + return new StatisticalDataTask(sourceIndex, client, sdkClient, cmkRoleArn, cmkAssumeRoleArn, remoteMemoryStoreHelper, memoryContainerHelper); } throw new IllegalStateException("Unsupported prerequisite type: " + prerequisiteType); } diff --git a/common/src/main/java/org/opensearch/ml/common/indexInsight/IndexInsightAccessControllerHelper.java b/plugin/src/main/java/org/opensearch/ml/indexInsight/IndexInsightAccessControllerHelper.java similarity index 97% rename from common/src/main/java/org/opensearch/ml/common/indexInsight/IndexInsightAccessControllerHelper.java rename to plugin/src/main/java/org/opensearch/ml/indexInsight/IndexInsightAccessControllerHelper.java index e0ccaf84e8..a246c68386 100644 --- a/common/src/main/java/org/opensearch/ml/common/indexInsight/IndexInsightAccessControllerHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/indexInsight/IndexInsightAccessControllerHelper.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.common.indexInsight; +package org.opensearch.ml.indexInsight; import org.opensearch.action.admin.indices.mapping.get.GetMappingsRequest; import org.opensearch.action.search.SearchRequest; diff --git a/common/src/main/java/org/opensearch/ml/common/indexInsight/IndexInsightConfig.java b/plugin/src/main/java/org/opensearch/ml/indexInsight/IndexInsightConfig.java similarity index 98% rename from common/src/main/java/org/opensearch/ml/common/indexInsight/IndexInsightConfig.java rename to plugin/src/main/java/org/opensearch/ml/indexInsight/IndexInsightConfig.java index ac6778a1dc..4c6b54949f 100644 --- a/common/src/main/java/org/opensearch/ml/common/indexInsight/IndexInsightConfig.java +++ b/plugin/src/main/java/org/opensearch/ml/indexInsight/IndexInsightConfig.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.common.indexInsight; +package org.opensearch.ml.indexInsight; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; diff --git a/common/src/main/java/org/opensearch/ml/common/indexInsight/LogRelatedIndexCheckTask.java b/plugin/src/main/java/org/opensearch/ml/indexInsight/LogRelatedIndexCheckTask.java similarity index 92% rename from common/src/main/java/org/opensearch/ml/common/indexInsight/LogRelatedIndexCheckTask.java rename to plugin/src/main/java/org/opensearch/ml/indexInsight/LogRelatedIndexCheckTask.java index 8df30a823c..8a3bbd8195 100644 --- a/common/src/main/java/org/opensearch/ml/common/indexInsight/LogRelatedIndexCheckTask.java +++ b/plugin/src/main/java/org/opensearch/ml/indexInsight/LogRelatedIndexCheckTask.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.common.indexInsight; +package org.opensearch.ml.indexInsight; import static org.opensearch.ml.common.utils.StringUtils.MAPPER; import static org.opensearch.ml.common.utils.StringUtils.gson; @@ -17,6 +17,11 @@ import org.opensearch.action.search.SearchRequest; import org.opensearch.core.action.ActionListener; import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.ml.common.indexInsight.IndexInsight; +import org.opensearch.ml.common.indexInsight.IndexInsightTask; +import org.opensearch.ml.common.indexInsight.MLIndexInsightType; +import org.opensearch.ml.helper.MemoryContainerHelper; +import org.opensearch.ml.helper.RemoteMemoryStoreHelper; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; @@ -81,8 +86,9 @@ Your task is to analyze the structure and semantics of this index, and determine - Your judgment should be based on both semantics and field patterns (e.g., field names like "message", "log", "trace", "span", etc). """; - public LogRelatedIndexCheckTask(String sourceIndex, Client client, SdkClient sdkClient, String cmkRoleArn, String cmkAssumeRoleArn) { - super(MLIndexInsightType.LOG_RELATED_INDEX_CHECK, sourceIndex, client, sdkClient, cmkRoleArn, cmkAssumeRoleArn); + public LogRelatedIndexCheckTask(String sourceIndex, Client client, SdkClient sdkClient, String cmkRoleArn, String cmkAssumeRoleArn, RemoteMemoryStoreHelper remoteMemoryStoreHelper, + MemoryContainerHelper memoryContainerHelper) { + super(MLIndexInsightType.LOG_RELATED_INDEX_CHECK, sourceIndex, client, sdkClient, cmkRoleArn, cmkAssumeRoleArn, remoteMemoryStoreHelper, memoryContainerHelper); } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/indexInsight/StatisticalDataTask.java b/plugin/src/main/java/org/opensearch/ml/indexInsight/StatisticalDataTask.java similarity index 96% rename from common/src/main/java/org/opensearch/ml/common/indexInsight/StatisticalDataTask.java rename to plugin/src/main/java/org/opensearch/ml/indexInsight/StatisticalDataTask.java index f44ee2ef3a..091cd0fdff 100644 --- a/common/src/main/java/org/opensearch/ml/common/indexInsight/StatisticalDataTask.java +++ b/plugin/src/main/java/org/opensearch/ml/indexInsight/StatisticalDataTask.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.common.indexInsight; +package org.opensearch.ml.indexInsight; import static org.opensearch.ml.common.utils.StringUtils.gson; @@ -26,7 +26,13 @@ import org.opensearch.cluster.metadata.MappingMetadata; import org.opensearch.core.action.ActionListener; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.indexInsight.IndexInsight; +import org.opensearch.ml.common.indexInsight.IndexInsightTask; +import org.opensearch.ml.common.indexInsight.IndexInsightTaskStatus; +import org.opensearch.ml.common.indexInsight.MLIndexInsightType; import org.opensearch.ml.common.utils.mergeMetaDataUtils.MergeRuleHelper; +import org.opensearch.ml.helper.MemoryContainerHelper; +import org.opensearch.ml.helper.RemoteMemoryStoreHelper; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.search.SearchHit; import org.opensearch.search.aggregations.Aggregation; @@ -88,8 +94,9 @@ public class StatisticalDataTask extends AbstractIndexInsightTask { detailed information: %s """; - public StatisticalDataTask(String sourceIndex, Client client, SdkClient sdkClient, String cmkRoleArn, String cmkAssumeRoleArn) { - super(MLIndexInsightType.STATISTICAL_DATA, sourceIndex, client, sdkClient, cmkRoleArn, cmkAssumeRoleArn); + public StatisticalDataTask(String sourceIndex, Client client, SdkClient sdkClient, String cmkRoleArn, String cmkAssumeRoleArn, RemoteMemoryStoreHelper remoteMemoryStoreHelper, + MemoryContainerHelper memoryContainerHelper) { + super(MLIndexInsightType.STATISTICAL_DATA, sourceIndex, client, sdkClient, cmkRoleArn, cmkAssumeRoleArn, remoteMemoryStoreHelper, memoryContainerHelper); } @Override diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java index bf6c6a060c..93945a62a0 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteAction.java @@ -8,6 +8,7 @@ import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_CONTAINER_ID_FIELD; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AG_UI_DISABLED_MESSAGE; import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_EXECUTE_TOOL_DISABLED_MESSAGE; import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; @@ -162,6 +163,7 @@ MLExecuteTaskRequest getRequest(RestRequest request, NodeClient client) throws I ); } putMcpRequestHeaders(request, client); + client.threadPool().getThreadContext().putHeader(MEMORY_CONTAINER_ID_FIELD, request.header(MEMORY_CONTAINER_ID_FIELD)); } } else if (uri.startsWith(ML_BASE_URI + "/tools/")) { if (!mlFeatureEnabledSetting.isToolExecuteEnabled()) {