/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.sparse.data;

import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.RamUsageEstimator;
import org.opensearch.neuralsearch.sparse.common.IteratorWrapper;
import org.opensearch.neuralsearch.sparse.quantization.ByteQuantizationUtil;
import org.opensearch.neuralsearch.sparse.quantization.ByteQuantizer;

public class SparseVector
implements Accountable {
    private final short[] tokens;
    private final byte[] weights;

    public SparseVector(BytesRef bytesRef, ByteQuantizer byteQuantizer) throws IOException {
        this(SparseVector.readToMap(bytesRef), byteQuantizer);
    }

    public int getSize() {
        return this.tokens == null ? 0 : this.tokens.length;
    }

    public SparseVector(Map<Integer, Float> pairs, ByteQuantizer byteQuantizer) {
        this(pairs.entrySet().stream().map(t -> new Item((Integer)t.getKey(), byteQuantizer.quantize(((Float)t.getValue()).floatValue()))).collect(Collectors.toList()));
    }

    public SparseVector(List<Item> items) {
        List<Item> processedItems = this.processListItems(items);
        int size = processedItems.size();
        this.tokens = new short[size];
        this.weights = new byte[size];
        for (int i = 0; i < size; ++i) {
            this.tokens[i] = (short)processedItems.get(i).getToken();
            this.weights[i] = processedItems.get(i).getWeight();
        }
    }

    private List<Item> processListItems(List<Item> items) {
        ArrayList<Item> processedItems = new ArrayList<Item>();
        if (items.isEmpty()) {
            return processedItems;
        }
        items.sort(Comparator.comparingInt(item -> SparseVector.prepareTokenForShortType(item.getToken())));
        processedItems.add(new Item(SparseVector.prepareTokenForShortType(items.getFirst().getToken()), items.getFirst().getWeight()));
        for (int i = 1; i < items.size(); ++i) {
            int token = SparseVector.prepareTokenForShortType(items.get(i).getToken());
            if (token == ((Item)processedItems.getLast()).getToken()) {
                if (ByteQuantizationUtil.compareUnsignedByte(((Item)processedItems.getLast()).weight, items.get(i).getWeight()) >= 0) continue;
                ((Item)processedItems.getLast()).weight = items.get(i).getWeight();
                continue;
            }
            processedItems.add(new Item(token, items.get(i).getWeight()));
        }
        return processedItems;
    }

    public static int prepareTokenForShortType(int token) {
        return token % 65536;
    }

    private static Map<Integer, Float> readToMap(BytesRef bytesRef) throws IOException {
        HashMap<Integer, Float> map = new HashMap<Integer, Float>();
        try (ByteArrayInputStream bais = new ByteArrayInputStream(ArrayUtil.copyOfSubArray((byte[])bytesRef.bytes, (int)bytesRef.offset, (int)bytesRef.length));
             DataInputStream dis = new DataInputStream(bais);){
            while (bais.available() > 0) {
                int key = dis.readInt();
                float value = dis.readFloat();
                map.put(key, Float.valueOf(value));
            }
        }
        return map;
    }

    public byte[] toDenseVector() {
        int size = this.getSize();
        if (size == 0) {
            return new byte[0];
        }
        short maxToken = this.tokens[size - 1];
        byte[] denseVector = new byte[maxToken + 1];
        for (int i = 0; i < size; ++i) {
            denseVector[this.tokens[i]] = this.weights[i];
        }
        return denseVector;
    }

    public int dotProduct(byte[] denseVector) {
        int i;
        int score = 0;
        int size = this.getSize();
        if (size == 0 || denseVector == null || denseVector.length == 0) {
            return 0;
        }
        int unrollFactor = 4;
        int limit = size - size % 4;
        for (i = 0; i < limit && this.tokens[i] < denseVector.length; i += 4) {
            score += ByteQuantizationUtil.multiplyUnsignedByte(this.weights[i], denseVector[this.tokens[i]]);
            if (this.tokens[i + 1] >= denseVector.length) {
                ++i;
                break;
            }
            score += ByteQuantizationUtil.multiplyUnsignedByte(this.weights[i + 1], denseVector[this.tokens[i + 1]]);
            if (this.tokens[i + 2] >= denseVector.length) {
                i += 2;
                break;
            }
            score += ByteQuantizationUtil.multiplyUnsignedByte(this.weights[i + 2], denseVector[this.tokens[i + 2]]);
            if (this.tokens[i + 3] >= denseVector.length) {
                i += 3;
                break;
            }
            score += ByteQuantizationUtil.multiplyUnsignedByte(this.weights[i + 3], denseVector[this.tokens[i + 3]]);
        }
        while (i < size && this.tokens[i] < denseVector.length) {
            score += ByteQuantizationUtil.multiplyUnsignedByte(this.weights[i], denseVector[this.tokens[i]]);
            ++i;
        }
        return score;
    }

    public IteratorWrapper<Item> iterator() {
        return new IteratorWrapper<Item>(new Iterator<Item>(){
            private int size;
            private int current;
            {
                this.size = SparseVector.this.getSize();
                this.current = -1;
            }

            @Override
            public boolean hasNext() {
                return this.current + 1 < this.size;
            }

            @Override
            public Item next() {
                if (!this.hasNext()) {
                    return null;
                }
                ++this.current;
                return new Item(SparseVector.this.tokens[this.current], SparseVector.this.weights[this.current]);
            }
        });
    }

    public long ramBytesUsed() {
        return RamUsageEstimator.shallowSizeOfInstance(SparseVector.class) + RamUsageEstimator.sizeOf((short[])this.tokens) + RamUsageEstimator.sizeOf((byte[])this.weights);
    }

    @Generated
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof SparseVector)) {
            return false;
        }
        SparseVector other = (SparseVector)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!Arrays.equals(this.tokens, other.tokens)) {
            return false;
        }
        return Arrays.equals(this.weights, other.weights);
    }

    @Generated
    protected boolean canEqual(Object other) {
        return other instanceof SparseVector;
    }

    @Generated
    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        result = result * 59 + Arrays.hashCode(this.tokens);
        result = result * 59 + Arrays.hashCode(this.weights);
        return result;
    }

    public static class Item {
        int token;
        byte weight;

        static Item of(int token, byte weight) {
            return new Item(token, weight);
        }

        public int getIntWeight() {
            return ByteQuantizationUtil.getUnsignedByte(this.weight);
        }

        @Generated
        public Item(int token, byte weight) {
            this.token = token;
            this.weight = weight;
        }

        @Generated
        public int getToken() {
            return this.token;
        }

        @Generated
        public byte getWeight() {
            return this.weight;
        }

        @Generated
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Item)) {
                return false;
            }
            Item other = (Item)o;
            if (!other.canEqual(this)) {
                return false;
            }
            if (this.getToken() != other.getToken()) {
                return false;
            }
            return this.getWeight() == other.getWeight();
        }

        @Generated
        protected boolean canEqual(Object other) {
            return other instanceof Item;
        }

        @Generated
        public int hashCode() {
            int PRIME = 59;
            int result = 1;
            result = result * 59 + this.getToken();
            result = result * 59 + this.getWeight();
            return result;
        }
    }
}

