/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.function.DoubleUnaryOperator;
import java.util.logging.Logger;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.Tensor;

/**
 * AdaGrad stochastic gradient optimiser.
 * <p>
 * Keeps one accumulator tensor per parameter tensor. On every step the
 * accumulator is updated with a function of the squared, weighted gradient,
 * and the gradient is then rescaled by
 * {@code weight * initialLearningRate / (epsilon + sqrt(accumulator))}.
 * <p>
 * {@link #initialise(Parameters)} must be called before {@link #step(Tensor[], double)},
 * and again after {@link #reset()}, because the accumulator is {@code null} until then.
 */
public class AdaGrad implements StochasticGradientOptimiser {
    private static final Logger logger = Logger.getLogger(AdaGrad.class.getName());

    /** Base learning rate; the numerator of the per-element scaling factor. */
    @Config(mandatory=true, description="Initial learning rate used to scale the gradients.")
    private double initialLearningRate;

    /** Added to the square root of the accumulator to keep the division away from zero. */
    @Config(description="Epsilon for numerical stability around zero.")
    private double epsilon = 1.0E-6;

    /** Value added to every accumulator element in {@link #initialise(Parameters)}. */
    @Config(description="Initial value for the gradient accumulator.")
    private double initialValue = 0.0;

    /** Per-parameter accumulators; {@code null} until initialised and after a reset. */
    private Tensor[] gradsSquared;

    /**
     * Creates an AdaGrad optimiser with all settings supplied explicitly.
     *
     * @param initialLearningRate the learning rate used to scale the gradients.
     * @param epsilon the stability term added to the square root of the accumulator.
     * @param initialValue the value the accumulator is seeded with on initialisation.
     */
    public AdaGrad(double initialLearningRate, double epsilon, double initialValue) {
        this.initialLearningRate = initialLearningRate;
        this.epsilon = epsilon;
        this.initialValue = initialValue;
    }

    /**
     * Creates an AdaGrad optimiser whose accumulator is seeded with 0.0.
     *
     * @param initialLearningRate the learning rate used to scale the gradients.
     * @param epsilon the stability term added to the square root of the accumulator.
     */
    public AdaGrad(double initialLearningRate, double epsilon) {
        this(initialLearningRate, epsilon, 0.0);
    }

    /**
     * Creates an AdaGrad optimiser with epsilon 1e-6 and an accumulator seeded with 0.0.
     *
     * @param initialLearningRate the learning rate used to scale the gradients.
     */
    public AdaGrad(double initialLearningRate) {
        this(initialLearningRate, 1.0E-6, 0.0);
    }

    /**
     * No-argument constructor; leaves the fields at their declared defaults.
     * Presumably used by the configuration system, which populates the
     * {@code @Config} fields afterwards - TODO confirm.
     */
    private AdaGrad() {
    }

    /**
     * Allocates the accumulator from {@code parameters.getEmptyCopy()} and,
     * when {@code initialValue} is non-zero, adds it to every accumulator tensor.
     *
     * @param parameters the model parameters this optimiser will update.
     */
    @Override
    public void initialise(Parameters parameters) {
        this.gradsSquared = parameters.getEmptyCopy();
        if (this.initialValue != 0.0) {
            for (Tensor t : this.gradsSquared) {
                t.scalarAddInPlace(this.initialValue);
            }
        }
    }

    /**
     * Applies one AdaGrad step, modifying {@code updates} in place.
     *
     * @param updates the gradients, one tensor per accumulator tensor; mutated by this call.
     * @param weight the weight applied to this step's gradients.
     * @return the same {@code updates} array that was passed in.
     */
    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        // Maps a gradient element g to (weight * g)^2.
        DoubleUnaryOperator square = a -> weight * weight * a * a;
        // Maps an accumulator element s to weight * lr / (epsilon + sqrt(s)).
        DoubleUnaryOperator scale = a -> weight * this.initialLearningRate / (this.epsilon + Math.sqrt(a));
        for (int i = 0; i < updates.length; ++i) {
            Tensor curGradsSquared = this.gradsSquared[i];
            Tensor curGrad = updates[i];
            // Presumably adds square(g) into the accumulator at the elements present
            // in the gradient - verify against Tensor.intersectAndAddInPlace.
            curGradsSquared.intersectAndAddInPlace(curGrad, square);
            // Presumably multiplies each gradient element by scale(accumulator element)
            // - verify against Tensor.hadamardProductInPlace.
            curGrad.hadamardProductInPlace(curGradsSquared, scale);
        }
        return updates;
    }

    /**
     * Returns a description listing the three configured values.
     */
    @Override
    public String toString() {
        return "AdaGrad(initialLearningRate=" + this.initialLearningRate + ",epsilon=" + this.epsilon + ",initialValue=" + this.initialValue + ")";
    }

    /**
     * Discards the accumulator; {@link #initialise(Parameters)} must be called
     * before the next {@link #step(Tensor[], double)}.
     */
    @Override
    public void reset() {
        this.gradsSquared = null;
    }

    /**
     * Creates a new optimiser with the same configuration and no accumulator state.
     *
     * @return an uninitialised AdaGrad with identical learning rate, epsilon and initial value.
     */
    @Override
    public AdaGrad copy() {
        // Pass initialValue through: the two-argument constructor would reset it to 0.0,
        // so the copy would otherwise behave differently from this instance.
        return new AdaGrad(this.initialLearningRate, this.epsilon, this.initialValue);
    }

    /**
     * Builds the provenance for this optimiser from its configured fields.
     */
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl((Configurable)this, "StochasticGradientOptimiser");
    }
}

