/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.tasks;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.task.MLCancelBatchJobRequest;
import org.opensearch.ml.common.transport.task.MLCancelBatchJobResponse;
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.MLNodeUtils;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.script.ScriptService;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;
import org.opensearch.transport.client.Client;

/**
 * Transport action that cancels a remote (connector-backed) batch prediction job.
 *
 * <p>The action loads the {@link MLTask} identified by the request's task id from the ML task
 * index, requires it to be a {@code BATCH_PREDICTION} task for a {@code REMOTE} function, checks
 * that the calling user has access to the task's model group, and finally executes the
 * connector's {@code CANCEL_BATCH_PREDICT} action. The outcome is reported through the supplied
 * {@link ActionListener}.
 */
public class CancelBatchJobTransportAction
extends HandledTransportAction<ActionRequest, MLCancelBatchJobResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(CancelBatchJobTransportAction.class);

    /** Index that stores ML task documents. */
    private static final String ML_TASK_INDEX = ".plugins-ml-task";
    /** Index that stores stand-alone connector documents. */
    private static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
    /** Connector parameter carrying the transform job name. */
    private static final String TRANSFORM_JOB_NAME = "TransformJobName";
    /** Connector parameter carrying the transform job ARN. */
    private static final String TRANSFORM_JOB_ARN = "TransformJobArn";

    Client client;
    NamedXContentRegistry xContentRegistry;
    ClusterService clusterService;
    ScriptService scriptService;
    ConnectorAccessControlHelper connectorAccessControlHelper;
    ModelAccessControlHelper modelAccessControlHelper;
    EncryptorImpl encryptor;
    MLModelManager mlModelManager;
    MLTaskManager mlTaskManager;
    private MLFeatureEnabledSetting mlFeatureEnabledSetting;

    @Inject
    public CancelBatchJobTransportAction(TransportService transportService, ActionFilters actionFilters, Client client, NamedXContentRegistry xContentRegistry, ClusterService clusterService, ScriptService scriptService, ConnectorAccessControlHelper connectorAccessControlHelper, ModelAccessControlHelper modelAccessControlHelper, EncryptorImpl encryptor, MLTaskManager mlTaskManager, MLModelManager mlModelManager, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
        super("cluster:admin/opensearch/ml/tasks/cancel", transportService, actionFilters, MLCancelBatchJobRequest::new);
        this.client = client;
        this.xContentRegistry = xContentRegistry;
        this.clusterService = clusterService;
        this.scriptService = scriptService;
        this.connectorAccessControlHelper = connectorAccessControlHelper;
        this.modelAccessControlHelper = modelAccessControlHelper;
        this.encryptor = encryptor;
        this.mlTaskManager = mlTaskManager;
        this.mlModelManager = mlModelManager;
        this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
    }

    /**
     * Loads the task named by the request and, if it is a remote batch prediction task, starts
     * the cancellation flow.
     *
     * @param task           the transport task (unused)
     * @param request        request convertible to {@link MLCancelBatchJobRequest}
     * @param actionListener receives the cancel response or the failure
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLCancelBatchJobResponse> actionListener) {
        MLCancelBatchJobRequest mlCancelBatchJobRequest = MLCancelBatchJobRequest.fromActionRequest(request);
        String taskId = mlCancelBatchJobRequest.getTaskId();
        GetRequest getRequest = new GetRequest(ML_TASK_INDEX).id(taskId);

        // The task document is read with a stashed thread context; the caller's context is
        // restored (runBefore) before the wrapped listener runs.
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            client.get(getRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
                log.debug("Completed Get Task Request, id:{}", taskId);
                if (r == null || !r.isExists()) {
                    actionListener.onFailure(new OpenSearchStatusException("Fail to find task", RestStatus.NOT_FOUND));
                    return;
                }
                try (XContentParser parser = MLNodeUtils.createXContentParserFromRegistry(xContentRegistry, r.getSourceAsBytesRef())) {
                    XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
                    MLTask mlTask = MLTask.parse(parser);
                    boolean isBatchPrediction = mlTask.getTaskType() == MLTaskType.BATCH_PREDICTION;
                    if (isBatchPrediction && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
                        // Routed to the listener by the catch below.
                        throw new IllegalStateException("Offline Batch Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_inference_enabled\" to true.");
                    }
                    if (isBatchPrediction && mlTask.getFunctionName() == FunctionName.REMOTE) {
                        // Exceptions thrown synchronously here are also routed by the catch below.
                        processRemoteBatchPrediction(mlTask, actionListener);
                    } else {
                        actionListener.onFailure(new IllegalArgumentException("The task ID you provided does not have any associated batch job"));
                    }
                } catch (Exception e) {
                    log.error("Failed to parse ml task {}", r.getId(), e);
                    actionListener.onFailure(e);
                }
            }, e -> {
                if (e instanceof IndexNotFoundException) {
                    actionListener.onFailure(new MLResourceNotFoundException("Fail to find task"));
                } else {
                    log.error("Failed to get ML task {}", taskId, e);
                    actionListener.onFailure(e);
                }
            }), context::restore));
        } catch (Exception e) {
            log.error("Failed to get ML task {}", taskId, e);
            actionListener.onFailure(e);
        }
    }

    /**
     * Builds the cancel input from the task's remote job, checks model-group access for the
     * current user, resolves the model's connector (inline or from the connector index) and
     * executes the cancel action on it.
     *
     * @param mlTask         a remote batch prediction task
     * @param actionListener receives the cancel response or the failure
     * @throws OpenSearchException if setting up the asynchronous model lookup fails synchronously;
     *                             the original exception is attached as the cause
     */
    private void processRemoteBatchPrediction(MLTask mlTask, ActionListener<MLCancelBatchJobResponse> actionListener) {
        Map<String, String> parameters = buildRemoteJobParameters(mlTask.getRemoteJob());
        // NOTE(review): the input is tagged BATCH_PREDICT_STATUS although CANCEL_BATCH_PREDICT is
        // the action executed below; kept unchanged — confirm this is intended.
        RemoteInferenceInputDataSet inferenceInputDataSet = new RemoteInferenceInputDataSet(parameters, ConnectorAction.ActionType.BATCH_PREDICT_STATUS, null);
        MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inferenceInputDataSet).build();
        String modelId = mlTask.getModelId();
        // Captured before stashing, because stashing clears the thread context that holds the user.
        User user = RestActionUtils.getUserContext(client);

        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            ActionListener<MLModel> getModelListener = ActionListener.wrap(
                model -> modelAccessControlHelper.validateModelGroupAccess(user, model.getModelGroupId(), client, ActionListener.wrap(access -> {
                    if (!access) {
                        actionListener.onFailure(new MLValidationException("You don't have permission to cancel this batch job"));
                    } else if (model.getConnector() != null) {
                        // Connector is embedded in the model definition.
                        executeConnector(model.getConnector(), mlInput, actionListener);
                    } else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
                        // Stand-alone connector: fetch it by id (with its own access check).
                        ActionListener<Connector> listener = ActionListener.wrap(
                            connector -> executeConnector(connector, mlInput, actionListener),
                            e -> {
                                log.error("Failed to get connector {}", model.getConnectorId(), e);
                                actionListener.onFailure(e);
                            }
                        );
                        try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
                            connectorAccessControlHelper.getConnector(client, model.getConnectorId(), ActionListener.runBefore(listener, threadContext::restore));
                        }
                    } else {
                        actionListener.onFailure(new ResourceNotFoundException("Can't find connector " + model.getConnectorId()));
                    }
                }, e -> {
                    log.error("Failed to validate Access for Model Group {}", model.getModelGroupId(), e);
                    actionListener.onFailure(e);
                })),
                e -> {
                    log.error("Failed to retrieve the ML model with the given ID", e);
                    actionListener.onFailure(new OpenSearchStatusException("Failed to retrieve the ML model for the given task ID", RestStatus.NOT_FOUND));
                }
            );
            mlModelManager.getModel(modelId, null, null, ActionListener.runBefore(getModelListener, context::restore));
        } catch (Exception e) {
            log.error("Unable to fetch cancel batch job in ml task ", e);
            // Chain the original exception so its stack trace is not lost.
            throw new OpenSearchException("Unable to fetch cancel batch job in ml task " + e.getMessage(), e);
        }
    }

    /**
     * Copies the String-valued entries of a task's remote job into connector parameters and,
     * when no transform job name was recorded, derives it from the text after the last '/' of the
     * transform job ARN (if present).
     *
     * @param remoteJob the task's remote job map; non-String values are skipped
     * @return a new mutable parameter map
     */
    private static Map<String, String> buildRemoteJobParameters(Map<String, Object> remoteJob) {
        Map<String, String> parameters = new HashMap<>();
        for (Map.Entry<String, Object> entry : remoteJob.entrySet()) {
            if (entry.getValue() instanceof String) {
                parameters.put(entry.getKey(), (String) entry.getValue());
            } else {
                log.debug("Value for key {} is not a String", entry.getKey());
            }
        }
        // A null result from the mapping function leaves the key absent.
        parameters.computeIfAbsent(TRANSFORM_JOB_NAME, key -> Optional.ofNullable(parameters.get(TRANSFORM_JOB_ARN))
            .map(jobArn -> jobArn.substring(jobArn.lastIndexOf('/') + 1))
            .orElse(null));
        return parameters;
    }

    /**
     * Ensures the connector has a usable {@code CANCEL_BATCH_PREDICT} action, decrypts its
     * credentials and executes the action through a {@link RemoteConnectorExecutor}.
     *
     * @param connector      connector to execute; may be mutated (action added, credentials decrypted)
     * @param mlInput        remote inference input carrying the job parameters
     * @param actionListener receives the cancel response or the failure
     */
    private void executeConnector(Connector connector, MLInput mlInput, ActionListener<MLCancelBatchJobResponse> actionListener) {
        String cancelActionName = ConnectorAction.ActionType.CANCEL_BATCH_PREDICT.name();
        Optional<ConnectorAction> cancelBatchPredictAction = connector.findAction(cancelActionName);
        // Synthesize a default cancel action when the connector has none, or one without a body.
        if (cancelBatchPredictAction.isEmpty() || cancelBatchPredictAction.get().getRequestBody() == null) {
            connector.addAction(ConnectorUtils.createConnectorAction(connector, ConnectorAction.ActionType.CANCEL_BATCH_PREDICT));
        }
        connector.decrypt(cancelActionName, (credential, tenantId) -> encryptor.decrypt(credential, null), null);

        RemoteConnectorExecutor connectorExecutor = (RemoteConnectorExecutor) MLEngineClassLoader.initInstance(connector.getProtocol(), connector, Connector.class);
        connectorExecutor.setScriptService(scriptService);
        connectorExecutor.setClusterService(clusterService);
        connectorExecutor.setClient(client);
        connectorExecutor.setXContentRegistry(xContentRegistry);
        connectorExecutor.executeAction(cancelActionName, mlInput, ActionListener.wrap(
            taskResponse -> processTaskResponse(taskResponse, actionListener),
            actionListener::onFailure
        ));
    }

    /**
     * Maps the remote cancel call's response to an {@link MLCancelBatchJobResponse}: success only
     * when the first model output reports HTTP 200.
     *
     * <p>Every path notifies {@code actionListener} exactly once. Previously an exception while
     * reading the response was only logged, leaving the request without an answer.
     *
     * @param taskResponse   response from the connector executor
     * @param actionListener receives the cancel response or the failure
     */
    private void processTaskResponse(MLTaskResponse taskResponse, ActionListener<MLCancelBatchJobResponse> actionListener) {
        ModelTensorOutput tensorOutput;
        try {
            tensorOutput = (ModelTensorOutput) taskResponse.getOutput();
        } catch (Exception e) {
            log.error("Unable to fetch status for ml task ", e);
            actionListener.onFailure(e);
            return;
        }

        if (tensorOutput == null || tensorOutput.getMlModelOutputs() == null || tensorOutput.getMlModelOutputs().isEmpty()) {
            log.debug("ML Model Outputs are null or empty.");
            actionListener.onFailure(new ResourceNotFoundException("Couldn't fetch status of the transform job"));
            return;
        }

        ModelTensors modelOutput = tensorOutput.getMlModelOutputs().get(0);
        Integer statusCode = modelOutput == null ? null : modelOutput.getStatusCode();
        if (statusCode != null && statusCode == RestStatus.OK.getStatus()) {
            actionListener.onResponse(new MLCancelBatchJobResponse(RestStatus.OK));
        } else {
            log.debug("The status code from remote service is: {}", statusCode);
            actionListener.onFailure(new OpenSearchException("Couldn't cancel the transform job. Please try again"));
        }
    }
}

