/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.memorycontainer.memory;

import java.io.IOException;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.index.IndexResponse;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.action.memorycontainer.memory.FactSearchResult;
import org.opensearch.ml.action.memorycontainer.memory.MemoryInfo;
import org.opensearch.ml.action.memorycontainer.memory.MemoryOperationsService;
import org.opensearch.ml.action.memorycontainer.memory.MemoryProcessingService;
import org.opensearch.ml.action.memorycontainer.memory.MemorySearchService;
import org.opensearch.ml.common.memorycontainer.MLMemoryContainer;
import org.opensearch.ml.common.memorycontainer.MemoryConfiguration;
import org.opensearch.ml.common.memorycontainer.MemoryDecision;
import org.opensearch.ml.common.memorycontainer.MemoryStrategy;
import org.opensearch.ml.common.memorycontainer.PayloadType;
import org.opensearch.ml.common.settings.MLCommonsSettings;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesInput;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesRequest;
import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesResponse;
import org.opensearch.ml.common.transport.memorycontainer.memory.MemoryEvent;
import org.opensearch.ml.common.transport.memorycontainer.memory.MemoryResult;
import org.opensearch.ml.common.transport.memorycontainer.memory.MessageInput;
import org.opensearch.ml.helper.MemoryContainerHelper;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

/**
 * Transport action for {@code cluster:admin/opensearch/ml/memory_containers/memories/add}.
 *
 * <p>Flow: check that agentic memory is enabled, validate the request, load the target memory
 * container and verify the caller's access, optionally create a conversation session, index the
 * payload into the container's working-memory index and respond. When {@code infer} is requested,
 * long-term memory extraction then runs in the background on the
 * {@code opensearch_ml_agentic_memory} executor. The client has already been answered by then, so
 * the outcome of that background work is only logged, and failures are written to memory history.
 */
public class TransportAddMemoriesAction
extends HandledTransportAction<MLAddMemoriesRequest, MLAddMemoriesResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(TransportAddMemoriesAction.class);
    /** Thread pool executor used for background long-term memory extraction. */
    private static final String AGENTIC_MEMORY_EXECUTOR = "opensearch_ml_agentic_memory";
    /** Namespace key under which the conversation session id is stored. */
    private static final String SESSION_ID_FIELD = "session_id";

    private final Client client;
    private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
    private final MemoryContainerHelper memoryContainerHelper;
    private final MemoryProcessingService memoryProcessingService;
    private final MemorySearchService memorySearchService;
    private final MemoryOperationsService memoryOperationsService;
    private final ThreadPool threadPool;

    @Inject
    public TransportAddMemoriesAction(TransportService transportService, ActionFilters actionFilters, Client client, SdkClient sdkClient, NamedXContentRegistry xContentRegistry, MLFeatureEnabledSetting mlFeatureEnabledSetting, MemoryContainerHelper memoryContainerHelper, ThreadPool threadPool) {
        super("cluster:admin/opensearch/ml/memory_containers/memories/add", transportService, actionFilters, MLAddMemoriesRequest::new);
        this.client = client;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
        this.memoryContainerHelper = memoryContainerHelper;
        this.memoryProcessingService = new MemoryProcessingService(client, xContentRegistry, memoryContainerHelper);
        this.memorySearchService = new MemorySearchService(memoryContainerHelper);
        this.memoryOperationsService = new MemoryOperationsService(memoryContainerHelper);
        this.threadPool = threadPool;
    }

    /**
     * Entry point: validates the request, loads the container, checks access and continues with
     * session handling.
     *
     * @param task           the transport task (unused)
     * @param request        the add-memories request
     * @param actionListener completed with the response, or with FORBIDDEN / IllegalArgumentException
     *                       / lookup failures
     */
    @Override
    protected void doExecute(Task task, MLAddMemoriesRequest request, ActionListener<MLAddMemoriesResponse> actionListener) {
        if (!mlFeatureEnabledSetting.isAgenticMemoryEnabled()) {
            actionListener.onFailure(new OpenSearchStatusException(MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE, RestStatus.FORBIDDEN));
            return;
        }
        MLAddMemoriesInput input = request.getMlAddMemoryInput();
        // Validate before touching the input; setting the owner on a null input would NPE.
        if (input == null) {
            actionListener.onFailure(new IllegalArgumentException("Memory input is required"));
            return;
        }
        User user = RestActionUtils.getUserContext(client);
        input.setOwnerId(memoryContainerHelper.getOwnerId(user));
        String memoryContainerId = input.getMemoryContainerId();
        if (StringUtils.isBlank(memoryContainerId)) {
            actionListener.onFailure(new IllegalArgumentException("Memory container ID is required"));
            return;
        }
        memoryContainerHelper.getMemoryContainer(memoryContainerId, ActionListener.<MLMemoryContainer>wrap(container -> {
            if (!memoryContainerHelper.checkMemoryContainerAccess(user, container)) {
                actionListener.onFailure(new OpenSearchStatusException("User doesn't have permissions to add memory to this container", RestStatus.FORBIDDEN));
                return;
            }
            createNewSessionIfAbsent(input, container, user, actionListener);
        }, actionListener::onFailure));
    }

    /**
     * Creates a session document first when the input is a conversational payload without a
     * caller-supplied {@code session_id}, sessions are not disabled and an LLM is configured.
     * The LLM-generated summary of the messages is stored as the session summary, and the new
     * session id is written back into the input's namespace. Then continues with
     * {@link #processAndIndexMemory}.
     */
    private void createNewSessionIfAbsent(MLAddMemoriesInput input, MLMemoryContainer container, User user, ActionListener<MLAddMemoriesResponse> actionListener) {
        try {
            MemoryConfiguration configuration = container.getConfiguration();
            // Request-level parameters are merged into the container configuration in place.
            // NOTE(review): this mutates the loaded container object; presumably a per-request copy — confirm.
            if (input.getParameters() != null) {
                configuration.getParameters().putAll(input.getParameters());
            }
            Map<String, String> namespace = input.getNamespace();
            boolean userProvidedSessionId = namespace != null && namespace.containsKey(SESSION_ID_FIELD);
            boolean needsNewSession = !userProvidedSessionId
                && input.getPayloadType() == PayloadType.CONVERSATIONAL
                && !configuration.isDisableSession()
                && configuration.getLlmId() != null;
            if (!needsNewSession) {
                processAndIndexMemory(input, container, user, actionListener);
                return;
            }
            // The new session id must be recorded in the namespace. Fail before paying for the LLM
            // summarization call rather than NPE-ing afterwards.
            if (namespace == null) {
                actionListener.onFailure(new IllegalArgumentException("namespace is required to create a new session"));
                return;
            }
            List<MessageInput> messages = input.getMessages();
            IndexRequest indexRequest = new IndexRequest(configuration.getSessionIndexName());
            memoryProcessingService.summarizeMessages(configuration, messages, ActionListener.<String>wrap(summary -> {
                indexRequest.source(buildSessionSource(input, summary));
                memoryContainerHelper.indexData(configuration, indexRequest, ActionListener.<IndexResponse>wrap(r -> {
                    namespace.put(SESSION_ID_FIELD, r.getId());
                    processAndIndexMemory(input, container, user, actionListener);
                }, actionListener::onFailure));
            }, actionListener::onFailure));
        }
        catch (Exception e) {
            actionListener.onFailure(e);
        }
    }

    /**
     * Builds the session document source.
     *
     * <p>Uses a mutable map rather than {@code Map.of}, which throws on null values. Owner id,
     * summary or namespace may be null here — for example, the owner id when there is no user
     * context (TODO confirm what {@code getOwnerId(null)} returns).
     */
    private static Map<String, Object> buildSessionSource(MLAddMemoriesInput input, String summary) {
        long nowEpochSeconds = Instant.now().getEpochSecond();
        Map<String, Object> source = new HashMap<>();
        source.put("owner_id", input.getOwnerId());
        source.put("memory_container_id", input.getMemoryContainerId());
        source.put("summary", summary);
        source.put("namespace", input.getNamespace());
        source.put("created_time", nowEpochSeconds);
        source.put("last_updated_time", nowEpochSeconds);
        return source;
    }

    /**
     * Indexes the input into the working-memory index and responds with its document id. When
     * {@code infer} is set (which requires an LLM id in the configuration), schedules long-term
     * memory extraction after responding.
     */
    private void processAndIndexMemory(MLAddMemoriesInput input, MLMemoryContainer container, User user, ActionListener<MLAddMemoriesResponse> actionListener) {
        try {
            boolean infer = input.isInfer();
            MemoryConfiguration memoryConfig = container.getConfiguration();
            boolean hasLlmModel = memoryConfig != null && memoryConfig.getLlmId() != null;
            if (infer && !hasLlmModel) {
                actionListener.onFailure(new IllegalArgumentException("infer=true requires llm_model_id to be configured in memory storage"));
                return;
            }
            IndexRequest indexRequest = createWorkingMemoryRequest(memoryConfig.getWorkingMemoryIndexName(), input);
            memoryContainerHelper.indexData(memoryConfig, indexRequest, ActionListener.<IndexResponse>wrap(r -> {
                MLAddMemoriesResponse response = MLAddMemoriesResponse.builder()
                    .results(new ArrayList<>())
                    .sessionId(input.getSessionId())
                    .workingMemoryId(r.getId())
                    .build();
                actionListener.onResponse(response);
                if (infer) {
                    scheduleLongTermMemoryExtraction(input, container, user, memoryConfig);
                }
            }, actionListener::onFailure));
        }
        catch (Exception e) {
            log.error("Failed to add memory", e);
            actionListener.onFailure(e);
        }
    }

    /**
     * Submits long-term memory extraction to the agentic-memory executor.
     *
     * <p>This is called after the client listener has already been completed. Every failure here,
     * including executor rejection, is therefore logged and written to memory history rather than
     * propagated. Propagating it would make the caller's wrapping listener respond a second time.
     */
    private void scheduleLongTermMemoryExtraction(MLAddMemoriesInput input, MLMemoryContainer container, User user, MemoryConfiguration memoryConfig) {
        ActionListener<MLAddMemoriesResponse> resultLogger = ActionListener.wrap(
            res -> log.debug("Long term memory results: {}", res),
            e -> log.error("Failed to extract longTermMemory id from memory container", e));
        try {
            threadPool.executor(AGENTIC_MEMORY_EXECUTOR).execute(() -> {
                try {
                    extractLongTermMemory(input, container, user, resultLogger);
                }
                catch (Exception e) {
                    log.error("Long term memory extraction failed", e);
                    memoryOperationsService.writeErrorToMemoryHistory(memoryConfig, null, input, e);
                }
            });
        }
        catch (Exception e) {
            log.error("Failed to schedule long term memory extraction", e);
            memoryOperationsService.writeErrorToMemoryHistory(memoryConfig, null, input, e);
        }
    }

    /**
     * Serializes the input, via its three-argument {@code toXContent}, into an index request
     * targeting {@code workingMemoryIndex}.
     *
     * @throws RuntimeException if serialization fails
     */
    private IndexRequest createWorkingMemoryRequest(String workingMemoryIndex, MLAddMemoriesInput mlAddMemoriesInput) {
        try {
            XContentBuilder builder = XContentFactory.jsonBuilder();
            mlAddMemoriesInput.toXContent(builder, ToXContent.EMPTY_PARAMS, true);
            return new IndexRequest(workingMemoryIndex).source(builder);
        }
        catch (IOException e) {
            log.error("Failed to build index request source", e);
            throw new RuntimeException("Failed to build index request", e);
        }
    }

    /**
     * Runs every enabled strategy whose namespace keys are all present in the input namespace.
     *
     * <p>Each strategy completes {@code actionListener} independently, so the listener may be
     * called once per strategy. The only caller passes a logging listener, for which that is fine.
     */
    private void extractLongTermMemory(MLAddMemoriesInput input, MLMemoryContainer container, User user, ActionListener<MLAddMemoriesResponse> actionListener) {
        List<MessageInput> messages = input.getMessages();
        log.debug("Processing {} messages for fact extraction", messages.size());
        MemoryConfiguration memoryConfig = container.getConfiguration();
        List<MemoryStrategy> strategies = memoryConfig.getStrategies();
        if (strategies == null) {
            log.debug("No memory strategies configured; skipping long term memory extraction");
            return;
        }
        for (MemoryStrategy strategy : strategies) {
            if (!strategy.isEnabled()) {
                continue;
            }
            Map<String, String> strategyNameSpace = getStrategyNameSpace(strategy, input.getNamespace());
            if (strategyNameSpace.size() != strategy.getNamespace().size()) {
                log.info("Skipping strategy {} due to missing namespace", strategy.getId());
                continue;
            }
            memoryProcessingService.runMemoryStrategy(strategy, messages, memoryConfig, ActionListener.<List<String>>wrap(
                facts -> storeLongTermMemory(strategy, strategyNameSpace, input, user, facts, memoryConfig, actionListener),
                e -> {
                    log.error("Failed to extract facts with LLM", e);
                    memoryOperationsService.writeErrorToMemoryHistory(memoryConfig, strategyNameSpace, input, e);
                    actionListener.onFailure(new OpenSearchException("Failed to extract facts: " + e.getMessage(), e));
                }));
        }
    }

    /**
     * Projects {@code namespace} onto the keys the strategy declares. A null namespace yields an
     * empty map, so strategies that require keys are skipped instead of failing.
     */
    private Map<String, String> getStrategyNameSpace(MemoryStrategy strategy, Map<String, String> namespace) {
        Map<String, String> strategyNamespace = new HashMap<>();
        if (namespace == null) {
            return strategyNamespace;
        }
        for (String key : strategy.getNamespace()) {
            if (namespace.containsKey(key)) {
                strategyNamespace.put(key, namespace.get(key));
            }
        }
        return strategyNamespace;
    }

    /**
     * Persists extracted facts for one strategy.
     *
     * <p>With an LLM configured and non-empty facts, similar facts are searched first. If any are
     * found, the LLM decides the operations to apply; otherwise every fact is added. In all other
     * cases the facts are handed to {@code createFactMemoriesFromList} and the listener receives
     * an empty result.
     */
    private void storeLongTermMemory(MemoryStrategy strategy, Map<String, String> strategyNameSpace, MLAddMemoriesInput input, User user, List<String> facts, MemoryConfiguration memoryConfig, ActionListener<MLAddMemoriesResponse> actionListener) {
        boolean canDecide = !facts.isEmpty() && memoryConfig != null && memoryConfig.getLlmId() != null;
        if (!canDecide) {
            if (memoryConfig == null) {
                actionListener.onFailure(new IllegalArgumentException("Memory configuration is required to store long term memory"));
                return;
            }
            // The two output lists are filled by the helper but not consumed here.
            memoryOperationsService.createFactMemoriesFromList(facts, memoryConfig.getLongMemoryIndexName(), input, strategyNameSpace, user, strategy, new ArrayList<IndexRequest>(), new ArrayList<MemoryInfo>(), input.getMemoryContainerId());
            // Complete the listener so every path reports an outcome.
            actionListener.onResponse(MLAddMemoriesResponse.builder().results(new ArrayList<>()).build());
            return;
        }
        log.debug("Searching for similar facts in session to make memory decisions");
        memorySearchService.searchSimilarFactsForSession(strategy, input, facts, memoryConfig, ActionListener.<List<FactSearchResult>>wrap(searchResults -> {
            log.debug("Found {} total similar facts across all {} new facts", searchResults.size(), facts.size());
            if (searchResults.isEmpty()) {
                executeDecisions(toAddDecisions(facts), memoryConfig, strategyNameSpace, user, input, strategy, actionListener);
                return;
            }
            memoryProcessingService.makeMemoryDecisions(facts, searchResults, strategy, memoryConfig, ActionListener.<List<MemoryDecision>>wrap(
                decisions -> executeDecisions(decisions, memoryConfig, strategyNameSpace, user, input, strategy, actionListener),
                e -> {
                    log.error("Failed to make memory decisions", e);
                    actionListener.onFailure(new OpenSearchException("Failed to make memory decisions: " + e.getMessage(), e));
                }));
        }, e -> {
            log.error("Failed to search similar facts", e);
            actionListener.onFailure(new OpenSearchException("Failed to search similar facts: " + e.getMessage(), e));
        }));
    }

    /** Builds one ADD decision per fact, with ids {@code fact_0}, {@code fact_1}, and so on. */
    private static List<MemoryDecision> toAddDecisions(List<String> facts) {
        List<MemoryDecision> decisions = new ArrayList<>(facts.size());
        for (int i = 0; i < facts.size(); i++) {
            decisions.add(MemoryDecision.builder().id("fact_" + i).event(MemoryEvent.ADD).text(facts.get(i)).build());
        }
        return decisions;
    }

    /** Applies the decisions and completes the listener with the resulting memory operations. */
    private void executeDecisions(List<MemoryDecision> decisions, MemoryConfiguration memoryConfig, Map<String, String> strategyNameSpace, User user, MLAddMemoriesInput input, MemoryStrategy strategy, ActionListener<MLAddMemoriesResponse> actionListener) {
        memoryOperationsService.executeMemoryOperations(decisions, memoryConfig, strategyNameSpace, user, input, strategy, ActionListener.<List<MemoryResult>>wrap(
            operationResults -> actionListener.onResponse(MLAddMemoriesResponse.builder().results(new ArrayList<>(operationResults)).build()),
            actionListener::onFailure));
    }
}

