/*
 * Decompiled with CFR 0.152.
 */
package deepboof.impl.backward.standard;

import deepboof.DeepBoofConstants;
import deepboof.Function;
import deepboof.backward.NumericalGradient;
import deepboof.misc.TensorOps;
import deepboof.misc.TensorOps_F64;
import deepboof.tensors.Tensor_F64;
import java.util.List;

/**
 * Numerically approximates the gradient of a {@link Function} using a symmetric
 * (central) finite difference with step size {@code T}.
 *
 * <p>For every element of the input tensor and of each parameter tensor, the element is
 * shifted by {@code +T} and {@code -T}, the function is re-evaluated, and the gradient is
 * taken as {@code (f(+T) - f(-T)) / (2*T)}. The scalar {@code f} is the result of applying
 * {@code TensorOps_F64.elementMult(output, dout, output)} followed by
 * {@code TensorOps_F64.elementSum(output)} to the forward output (presumably the
 * {@code dout}-weighted sum of all output elements; confirm against TensorOps_F64).</p>
 *
 * <p>The tensors being differentiated are modified in place while a sample is taken and are
 * restored afterwards, including when the function under test throws.</p>
 */
public class NumericalGradient_F64
implements NumericalGradient<Tensor_F64> {
    /** Function being differentiated; must be assigned via {@link #setFunction} before use. */
    Function<Tensor_F64> function;
    /** Finite-difference step size. */
    double T = DeepBoofConstants.TEST_TOL_A_F64;
    /** Scratch storage for the forward output, reshaped on every call to differentiate. */
    Tensor_F64 output = new Tensor_F64();
    /** Input tensor of the differentiate call currently being processed. */
    Tensor_F64 input;
    /** Parameter tensors of the differentiate call currently being processed. */
    List<Tensor_F64> parameters;

    /**
     * Sets the finite-difference step size.
     *
     * @param T step size; must be finite and strictly positive
     * @throws IllegalArgumentException if {@code T} is NaN, infinite, zero or negative
     */
    @Override
    public void configure(double T) {
        // "T <= 0.0" is false for NaN and +Infinity, both of which would silently turn
        // every computed gradient into NaN, so they are rejected explicitly.
        if (Double.isNaN(T) || Double.isInfinite(T)) {
            throw new IllegalArgumentException("T must be finite, got " + T);
        }
        if (T <= 0.0) {
            throw new IllegalArgumentException("T must be > 0");
        }
        this.T = T;
    }

    /**
     * Assigns the function whose gradient will be approximated.
     *
     * @param function function to differentiate
     */
    @Override
    public void setFunction(Function<Tensor_F64> function) {
        this.function = function;
    }

    /**
     * Approximates the gradient with respect to the input and to every parameter tensor.
     *
     * @param input function input; temporarily modified and restored
     * @param parameters function parameters; each is temporarily modified and restored
     * @param dout tensor multiplied element-wise (via TensorOps_F64.elementMult) with the
     *             forward output before summation
     * @param gradientInput receives the gradient with respect to {@code input}; written at
     *                      the same element offsets as {@code input}, not reshaped here
     * @param gradientParameters receives one gradient per entry of {@code parameters}, matched
     *                           by list index; not reshaped here
     */
    @Override
    public void differentiate(Tensor_F64 input, List<Tensor_F64> parameters, Tensor_F64 dout, Tensor_F64 gradientInput, List<Tensor_F64> gradientParameters) {
        // NOTE(review): length(0) is presumably the mini-batch size and WI presumably
        // prepends it to the function's output shape -- confirm against TensorOps.
        int N = input.length(0);
        this.output.reshape(TensorOps.WI(N, this.function.getOutputShape()));
        this.input = input;
        this.parameters = parameters;
        this.process(input, dout, gradientInput);
        for (int i = 0; i < parameters.size(); ++i) {
            this.process(parameters.get(i), dout, gradientParameters.get(i));
        }
    }

    /**
     * Computes the central difference for every element of {@code target}.
     *
     * @param target tensor whose elements are perturbed one at a time
     * @param dout weights applied to the forward output
     * @param gradientTarget destination for the per-element gradient
     */
    private void process(Tensor_F64 target, Tensor_F64 dout, Tensor_F64 gradientTarget) {
        int length = target.length();
        for (int i = 0; i < length; ++i) {
            int indexTarget = target.startIndex + i;
            double v = target.d[indexTarget];
            double plus_T;
            double minus_T;
            try {
                target.d[indexTarget] = v + this.T;
                plus_T = this.evaluate(dout);
                target.d[indexTarget] = v - this.T;
                minus_T = this.evaluate(dout);
            } finally {
                // Always undo the perturbation so a failure inside the function under test
                // cannot leave the caller's tensor corrupted.
                target.d[indexTarget] = v;
            }
            int indexGradient = gradientTarget.startIndex + i;
            gradientTarget.d[indexGradient] = (plus_T - minus_T) / (2.0 * this.T);
        }
    }

    /**
     * Runs the function on the current input and parameters and reduces its output to a
     * scalar. Parameters are handed to the function again on every call because one of
     * them may have just been perturbed.
     *
     * @param dout weights applied to the forward output
     * @return value of TensorOps_F64.elementSum after the output was multiplied by {@code dout}
     */
    private double evaluate(Tensor_F64 dout) {
        this.function.setParameters(this.parameters);
        this.function.forward(this.input, this.output);
        TensorOps_F64.elementMult(this.output, dout, this.output);
        return TensorOps_F64.elementSum(this.output);
    }
}

