/*
 * Decompiled with CFR 0.152.
 */
package com.hankcs.hanlp.model.perceptron.model;

import com.hankcs.hanlp.HanLP;
import com.hankcs.hanlp.algorithm.MaxHeap;
import com.hankcs.hanlp.classification.utilities.io.ConsoleLogger;
import com.hankcs.hanlp.collection.trie.datrie.MutableDoubleArrayTrieInteger;
import com.hankcs.hanlp.corpus.io.ByteArray;
import com.hankcs.hanlp.corpus.io.ByteArrayStream;
import com.hankcs.hanlp.corpus.io.ICacheAble;
import com.hankcs.hanlp.corpus.io.IOUtil;
import com.hankcs.hanlp.model.perceptron.common.TaskType;
import com.hankcs.hanlp.model.perceptron.feature.FeatureMap;
import com.hankcs.hanlp.model.perceptron.feature.FeatureSortItem;
import com.hankcs.hanlp.model.perceptron.feature.ImmutableFeatureMDatMap;
import com.hankcs.hanlp.model.perceptron.instance.Instance;
import com.hankcs.hanlp.model.perceptron.tagset.TagSet;
import com.hankcs.hanlp.utility.MathUtility;
import java.io.BufferedOutputStream;
import java.io.BufferedWriter;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.Collection;
import java.util.Comparator;
import java.util.Map;
import java.util.Set;

public class LinearModel
implements ICacheAble {
    public FeatureMap featureMap;
    public float[] parameter;

    public LinearModel(FeatureMap featureMap, float[] parameter) {
        this.featureMap = featureMap;
        this.parameter = parameter;
    }

    public LinearModel(FeatureMap featureMap) {
        this.featureMap = featureMap;
        this.parameter = new float[featureMap.size() * featureMap.tagSet.size()];
    }

    public LinearModel(String modelFile) throws IOException {
        this.load(modelFile);
    }

    public LinearModel compress(double ratio) {
        return this.compress(ratio, 0.001f);
    }

    public LinearModel compress(double ratio, double threshold) {
        if (ratio < 0.0 || ratio >= 1.0) {
            throw new IllegalArgumentException("\u538b\u7f29\u6bd4\u5fc5\u987b\u4ecb\u4e8e 0 \u548c 1 \u4e4b\u95f4");
        }
        if (ratio == 0.0) {
            return this;
        }
        Set<Map.Entry<String, Integer>> featureIdSet = this.featureMap.entrySet();
        TagSet tagSet = this.featureMap.tagSet;
        MaxHeap<FeatureSortItem> heap = new MaxHeap<FeatureSortItem>((int)((double)(featureIdSet.size() - tagSet.sizeIncludingBos()) * (1.0 - ratio)), new Comparator<FeatureSortItem>(){

            @Override
            public int compare(FeatureSortItem o1, FeatureSortItem o2) {
                return Float.compare(o1.total, o2.total);
            }
        });
        ConsoleLogger.logger.start("\u88c1\u526a\u7279\u5f81...\n", new Object[0]);
        int logEvery = (int)Math.ceil((float)this.featureMap.size() / 10000.0f);
        int n = 0;
        for (Map.Entry<String, Integer> entry : featureIdSet) {
            if (++n % logEvery == 0 || n == this.featureMap.size()) {
                ConsoleLogger.logger.out("\r%.2f%% ", MathUtility.percentage(n, this.featureMap.size()));
            }
            if (entry.getValue() < tagSet.sizeIncludingBos()) continue;
            FeatureSortItem item = new FeatureSortItem(entry, this.parameter, tagSet.size());
            if ((double)item.total < threshold) continue;
            heap.add(item);
        }
        ConsoleLogger.logger.finish("\n\u88c1\u526a\u5b8c\u6bd5\n", new Object[0]);
        int size = heap.size() + tagSet.sizeIncludingBos();
        float[] parameter = new float[size * tagSet.size()];
        MutableDoubleArrayTrieInteger mdat = new MutableDoubleArrayTrieInteger();
        for (Map.Entry<String, Integer> tag : tagSet) {
            mdat.add("BL=" + tag.getKey());
        }
        mdat.add("BL=_BL_");
        for (int i = 0; i < tagSet.size() * tagSet.sizeIncludingBos(); ++i) {
            parameter[i] = this.parameter[i];
        }
        ConsoleLogger.logger.start("\u6784\u5efa\u53cc\u6570\u7ec4trie\u6811...\n", new Object[0]);
        logEvery = (int)Math.ceil((float)heap.size() / 10000.0f);
        n = 0;
        for (FeatureSortItem item : heap) {
            if (++n % logEvery == 0 || n == heap.size()) {
                ConsoleLogger.logger.out("\r%.2f%% ", MathUtility.percentage(n, heap.size()));
            }
            int id = mdat.size();
            mdat.put(item.key, id);
            for (int i = 0; i < tagSet.size(); ++i) {
                parameter[id * tagSet.size() + i] = this.parameter[item.id * tagSet.size() + i];
            }
        }
        ConsoleLogger.logger.finish("\n\u6784\u5efa\u5b8c\u6bd5\n", new Object[0]);
        this.featureMap = new ImmutableFeatureMDatMap(mdat, tagSet);
        this.parameter = parameter;
        return this;
    }

    public void save(String modelFile) throws IOException {
        DataOutputStream out = new DataOutputStream(new BufferedOutputStream(IOUtil.newOutputStream(modelFile)));
        this.save(out);
        out.close();
    }

    public void save(String modelFile, double ratio) throws IOException {
        this.save(modelFile, this.featureMap.entrySet(), ratio);
    }

    public void save(String modelFile, Set<Map.Entry<String, Integer>> featureIdSet, double ratio) throws IOException {
        this.save(modelFile, featureIdSet, ratio, false);
    }

    public void save(String modelFile, Set<Map.Entry<String, Integer>> featureIdSet, double ratio, boolean text) throws IOException {
        float[] parameter = this.parameter;
        this.compress(ratio, 0.001f);
        DataOutputStream out = new DataOutputStream(new BufferedOutputStream(IOUtil.newOutputStream(modelFile)));
        this.save(out);
        out.close();
        if (text) {
            BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(IOUtil.newOutputStream(modelFile + ".txt"), "UTF-8"));
            TagSet tagSet = this.featureMap.tagSet;
            for (Map.Entry<String, Integer> entry : featureIdSet) {
                bw.write(entry.getKey());
                if (featureIdSet.size() == parameter.length) {
                    bw.write("\t");
                    bw.write(String.valueOf(parameter[entry.getValue()]));
                } else {
                    for (int i = 0; i < tagSet.size(); ++i) {
                        bw.write("\t");
                        bw.write(String.valueOf(parameter[entry.getValue() * tagSet.size() + i]));
                    }
                }
                bw.newLine();
            }
            bw.close();
        }
    }

    public void update(Collection<Integer> x, int y) {
        assert (y == 1 || y == -1) : "\u611f\u77e5\u673a\u7684\u6807\u7b7ey\u5fc5\u987b\u662f\u00b11";
        for (Integer f : x) {
            int n = f;
            this.parameter[n] = this.parameter[n] + (float)y;
        }
    }

    public int decode(Collection<Integer> x) {
        float y = 0.0f;
        for (Integer f : x) {
            y += this.parameter[f];
        }
        return y < 0.0f ? -1 : 1;
    }

    public double viterbiDecode(Instance instance) {
        return this.viterbiDecode(instance, instance.tagArray);
    }

    public double viterbiDecode(Instance instance, int[] guessLabel) {
        int[] allLabel = this.featureMap.allLabels();
        int bos = this.featureMap.bosTag();
        int sentenceLength = instance.tagArray.length;
        int labelSize = allLabel.length;
        int[][] preMatrix = new int[sentenceLength][labelSize];
        double[][] scoreMatrix = new double[2][labelSize];
        for (int i = 0; i < sentenceLength; ++i) {
            int _i = i & 1;
            int _i_1 = 1 - _i;
            int[] allFeature = instance.getFeatureAt(i);
            int transitionFeatureIndex = allFeature.length - 1;
            if (0 == i) {
                allFeature[transitionFeatureIndex] = bos;
                for (int j = 0; j < allLabel.length; ++j) {
                    double score;
                    preMatrix[0][j] = j;
                    scoreMatrix[0][j] = score = this.score(allFeature, j);
                }
                continue;
            }
            for (int curLabel = 0; curLabel < allLabel.length; ++curLabel) {
                double maxScore = -2.147483648E9;
                for (int preLabel = 0; preLabel < allLabel.length; ++preLabel) {
                    allFeature[transitionFeatureIndex] = preLabel;
                    double score = this.score(allFeature, curLabel);
                    double curScore = scoreMatrix[_i_1][preLabel] + score;
                    if (!(maxScore < curScore)) continue;
                    maxScore = curScore;
                    preMatrix[i][curLabel] = preLabel;
                    scoreMatrix[_i][curLabel] = maxScore;
                }
            }
        }
        int maxIndex = 0;
        double maxScore = scoreMatrix[sentenceLength - 1 & 1][0];
        for (int index = 1; index < allLabel.length; ++index) {
            if (!(maxScore < scoreMatrix[sentenceLength - 1 & 1][index])) continue;
            maxIndex = index;
            maxScore = scoreMatrix[sentenceLength - 1 & 1][index];
        }
        for (int i = sentenceLength - 1; i >= 0; --i) {
            guessLabel[i] = allLabel[maxIndex];
            maxIndex = preMatrix[i][maxIndex];
        }
        return maxScore;
    }

    public double score(int[] featureVector, int currentTag) {
        double score = 0.0;
        for (int index : featureVector) {
            if (index == -1) continue;
            if (index < -1 || index >= this.featureMap.size()) {
                throw new IllegalArgumentException("\u5728\u6253\u5206\u65f6\u4f20\u5165\u4e86\u975e\u6cd5\u7684\u4e0b\u6807");
            }
            index = index * this.featureMap.tagSet.size() + currentTag;
            score += (double)this.parameter[index];
        }
        return score;
    }

    public void load(String modelFile) throws IOException {
        ByteArrayStream byteArray;
        if (HanLP.Config.DEBUG) {
            ConsoleLogger.logger.start("\u52a0\u8f7d %s ... ", modelFile);
        }
        if (!this.load(byteArray = ByteArrayStream.createByteArrayStream(modelFile))) {
            throw new IOException(String.format("%s \u52a0\u8f7d\u5931\u8d25", modelFile));
        }
        if (HanLP.Config.DEBUG) {
            ConsoleLogger.logger.finish(" \u52a0\u8f7d\u5b8c\u6bd5\n", new Object[0]);
        }
    }

    public TagSet tagSet() {
        return this.featureMap.tagSet;
    }

    @Override
    public void save(DataOutputStream out) throws IOException {
        if (!(this.featureMap instanceof ImmutableFeatureMDatMap)) {
            this.featureMap = new ImmutableFeatureMDatMap(this.featureMap.entrySet(), this.tagSet());
        }
        this.featureMap.save(out);
        for (float aParameter : this.parameter) {
            out.writeFloat(aParameter);
        }
    }

    @Override
    public boolean load(ByteArray byteArray) {
        if (byteArray == null) {
            return false;
        }
        this.featureMap = new ImmutableFeatureMDatMap();
        this.featureMap.load(byteArray);
        int size = this.featureMap.size();
        TagSet tagSet = this.featureMap.tagSet;
        if (tagSet.type == TaskType.CLASSIFICATION) {
            this.parameter = new float[size];
            for (int i = 0; i < size; ++i) {
                this.parameter[i] = byteArray.nextFloat();
            }
        } else {
            this.parameter = new float[size * tagSet.size()];
            for (int i = 0; i < size; ++i) {
                for (int j = 0; j < tagSet.size(); ++j) {
                    this.parameter[i * tagSet.size() + j] = byteArray.nextFloat();
                }
            }
        }
        if (!byteArray.hasMore()) {
            byteArray.close();
        }
        return true;
    }

    public TaskType taskType() {
        return this.featureMap.tagSet.type;
    }
}

