/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.onnxruntime.zoo.nlp.textgeneration;

import ai.djl.modality.nlp.generate.CausalLMOutput;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslatorContext;

public class OrtGptTranslator
implements NoBatchifyTranslator<NDList, CausalLMOutput> {
    private long kvDim;
    private int numAttentionHeads;
    private int numLayers;

    public OrtGptTranslator(long kvDim, int numAttentionHeads, int numLayers) {
        this.kvDim = kvDim;
        this.numAttentionHeads = numAttentionHeads;
        this.numLayers = numLayers;
    }

    public NDList processInput(TranslatorContext ctx, NDList input) throws Exception {
        int offset;
        NDList inputNew;
        NDArray useCacheBranch;
        NDManager manager = ctx.getNDManager();
        NDArray inputIds = (NDArray)input.get(0);
        inputIds.setName("input_ids");
        NDArray attentionMask = (NDArray)input.get(2);
        attentionMask.setName("attention_mask");
        if (input.size() == 3) {
            useCacheBranch = manager.create(new boolean[]{false}, new Shape(new long[]{1L}));
            useCacheBranch.setName("use_cache_branch");
            inputNew = new NDList(new NDArray[]{inputIds, attentionMask, useCacheBranch});
            this.initialDummyPastKeyValues(inputIds, manager, inputNew);
        } else {
            useCacheBranch = manager.create(new boolean[]{true}, new Shape(new long[]{1L}));
            useCacheBranch.setName("use_cache_branch");
            inputNew = new NDList(new NDArray[]{inputIds, attentionMask, useCacheBranch});
            inputNew.addAll(input.subNDList(3));
        }
        for (int i = offset = 3; i < this.numLayers * 2 + offset; i += 2) {
            int order = (i - offset) / 2;
            ((NDArray)inputNew.get(i)).setName(String.format("past_key_values.%s.key", order));
            ((NDArray)inputNew.get(i + 1)).setName(String.format("past_key_values.%s.value", order));
        }
        return inputNew;
    }

    public CausalLMOutput processOutput(TranslatorContext ctx, NDList output) throws Exception {
        return new CausalLMOutput((NDArray)output.get(0), output.subNDList(1));
    }

    private void initialDummyPastKeyValues(NDArray inputIds, NDManager manager, NDList list) {
        long numBatch = inputIds.getShape().get(0);
        for (int i = 0; i < this.numLayers * 2; ++i) {
            NDArray array = manager.zeros(new Shape(new long[]{numBatch, this.numAttentionHeads, 1L, this.kvDim}));
            list.add((Object)array);
        }
    }
}

