/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.remote.streaming;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.TransportRequest;

public class StreamPredictActionListener<Response extends TransportResponse, Request extends TransportRequest>
implements ActionListener<Response> {
    @Generated
    private static final Logger log = LogManager.getLogger(StreamPredictActionListener.class);
    private final TransportChannel channel;
    private final ActionListener<Response> agentListener;
    private final String memoryId;
    private final String parentInteractionId;

    public StreamPredictActionListener(TransportChannel channel) {
        this(channel, null, null, null);
    }

    public StreamPredictActionListener(TransportChannel channel, ActionListener<Response> agentListener, String memoryId, String parentInteractionId) {
        this.channel = channel;
        this.agentListener = agentListener;
        this.memoryId = memoryId;
        this.parentInteractionId = parentInteractionId;
    }

    public void onStreamResponse(Response response, boolean isLastBatch) {
        assert (response != null);
        Response responseWithMetadata = this.addMetadataToResponse(response);
        this.channel.sendResponseBatch(responseWithMetadata);
        if (isLastBatch) {
            this.channel.completeStream();
        }
    }

    public final void onResponse(Response response) {
        this.onStreamResponse(response, false);
        if (this.agentListener != null) {
            this.agentListener.onResponse(response);
        }
    }

    public void onFailure(Exception e) {
        try {
            MLTaskResponse errorResponse = this.createErrorResponse(e);
            this.channel.sendResponseBatch((TransportResponse)errorResponse);
            this.channel.completeStream();
        }
        catch (Exception exc) {
            try {
                this.channel.completeStream();
            }
            catch (Exception streamException) {
                log.error("Failed to complete stream", (Throwable)streamException);
            }
        }
    }

    private Response addMetadataToResponse(Response response) {
        if (!(response instanceof MLTaskResponse)) {
            return response;
        }
        if (this.agentListener == null) {
            return response;
        }
        MLTaskResponse mlResponse = (MLTaskResponse)response;
        if (mlResponse.getOutput() instanceof ModelTensorOutput) {
            ModelTensorOutput output = (ModelTensorOutput)mlResponse.getOutput();
            ArrayList<ModelTensors> updatedOutputs = new ArrayList<ModelTensors>();
            for (ModelTensors tensors : output.getMlModelOutputs()) {
                ArrayList<ModelTensor> updatedTensors = new ArrayList<ModelTensor>();
                updatedTensors.add(ModelTensor.builder().name("memory_id").result(this.memoryId).build());
                updatedTensors.add(ModelTensor.builder().name("parent_interaction_id").result(this.parentInteractionId).build());
                updatedTensors.addAll(tensors.getMlModelTensors());
                updatedOutputs.add(ModelTensors.builder().mlModelTensors(updatedTensors).build());
            }
            ModelTensorOutput updatedOutput = ModelTensorOutput.builder().mlModelOutputs(updatedOutputs).build();
            return (Response)new MLTaskResponse((MLOutput)updatedOutput);
        }
        return response;
    }

    private MLTaskResponse createErrorResponse(Exception error) {
        String errorMessage = error.getMessage();
        if (errorMessage == null || errorMessage.trim().isEmpty()) {
            errorMessage = "Request failed";
        }
        LinkedHashMap<String, Object> errorData = new LinkedHashMap<String, Object>();
        errorData.put("error", errorMessage);
        errorData.put("is_last", true);
        ModelTensor errorTensor = ModelTensor.builder().name("error").dataAsMap(errorData).build();
        ModelTensors errorTensors = ModelTensors.builder().mlModelTensors(List.of(errorTensor)).build();
        ModelTensorOutput errorOutput = ModelTensorOutput.builder().mlModelOutputs(List.of(errorTensors)).build();
        return MLTaskResponse.builder().output((MLOutput)errorOutput).build();
    }
}

