/*
 * Decompiled with CFR 0.152.
 */
package deepboof.impl.backward.standard;

import deepboof.backward.DFunctionBatchNorm;
import deepboof.impl.backward.standard.BaseDBatchNorm_F64;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F64;
import java.util.List;

/**
 * Batch normalization with support for backpropagation, implemented for {@link Tensor_F64}.
 *
 * <p>Dimension 0 of the input is the minibatch (size {@code M = miniBatchSize}). Each minibatch
 * element contributes {@code D} values, and statistics are kept separately for each of those
 * {@code D} positions.</p>
 *
 * <p>In learning mode the mean and standard deviation are computed from the minibatch itself.
 * The variance is the unbiased estimate (sum of squared deviations divided by {@code M-1}), and
 * {@code std = sqrt(var + EPS)}. In evaluation mode the values currently held in
 * {@code tensorMean} and {@code tensorStd} are used as-is. How they are populated for evaluation
 * is not shown here; presumably the base class handles it (TODO confirm).</p>
 *
 * <p>When {@code requiresGammaBeta} is true, {@code params} stores interleaved
 * {@code (gamma, beta)} pairs, one pair per variable. The output is then
 * {@code gamma * xhat + beta}.</p>
 */
public class DFunctionBatchNorm_F64
extends BaseDBatchNorm_F64
implements DFunctionBatchNorm<Tensor_F64> {

    /**
     * Creates the batch-norm function.
     *
     * @param requiresGammaBeta if true, the normalized output is scaled by gamma and shifted by
     *                          beta, taken from {@code params}
     */
    public DFunctionBatchNorm_F64(boolean requiresGammaBeta) {
        super(requiresGammaBeta);
    }

    /**
     * Returns {@code shapeInput} unchanged. The per-variable shape is the input shape passed in.
     */
    @Override
    protected int[] createShapeVariables(int[] shapeInput) {
        return shapeInput;
    }

    /**
     * Applies batch normalization to {@code input} and writes the result into {@code output}.
     *
     * @throws IllegalArgumentException if dimension 0 (the minibatch) has one or fewer elements
     */
    @Override
    public void _forward(Tensor_F64 input, Tensor_F64 output) {
        // Learning mode divides by (M-1) to get the unbiased variance, so M must exceed 1.
        // The check is applied in both modes.
        if (input.length(0) <= 1) {
            throw new IllegalArgumentException("There must be more than 1 minibatch");
        }
        if (this.learningMode) {
            this.forwardsLearning(input, output);
        } else {
            this.forwardsEvaluate(input, output);
        }
    }

    /**
     * Learning-mode forward pass.
     *
     * <p>Computes minibatch statistics and caches the deviations ({@code tensorDiffX}) and the
     * normalized values ({@code tensorXhat}) for use by the backward pass.</p>
     */
    private void forwardsLearning(Tensor_F64 input, Tensor_F64 output) {
        this.tensorDiffX.reshape(input.shape);
        this.tensorXhat.reshape(input.shape);
        this.computeStatisticsAndNormalize(input);
        if (this.requiresGammaBeta) {
            this.applyGammaBeta(output);
        } else {
            output.setTo(this.tensorXhat);
        }
    }

    /**
     * Evaluation-mode forward pass.
     *
     * <p>Normalizes each element using the mean and standard deviation currently stored in
     * {@code tensorMean} and {@code tensorStd}. If gamma/beta are enabled, it then applies
     * {@code gamma * xhat + beta}.</p>
     *
     * @param input  input tensor; dimension 0 is the minibatch
     * @param output destination tensor, written starting at {@code output.startIndex}
     */
    public void forwardsEvaluate(Tensor_F64 input, Tensor_F64 output) {
        // Number of values per minibatch element (product of all dimensions after the first).
        int D2 = TensorOps.outerLength(input.shape, 1);
        int indexIn = input.startIndex;
        int indexOut = output.startIndex;
        if (this.requiresGammaBeta) {
            for (int batch = 0; batch < this.miniBatchSize; ++batch) {
                int indexP = this.params.startIndex;
                for (int indexVar = 0; indexVar < D2; ++indexVar) {
                    double mean = this.tensorMean.d[indexVar];
                    double stdev_eps = this.tensorStd.d[indexVar];
                    double gamma = this.params.d[indexP++];
                    double beta = this.params.d[indexP++];
                    output.d[indexOut++] = (input.d[indexIn++] - mean) * (gamma / stdev_eps) + beta;
                }
            }
        } else {
            for (int batch = 0; batch < this.miniBatchSize; ++batch) {
                for (int indexVar = 0; indexVar < D2; ++indexVar) {
                    double mean = this.tensorMean.d[indexVar];
                    double stdev_eps = this.tensorStd.d[indexVar];
                    output.d[indexOut++] = (input.d[indexIn++] - mean) / stdev_eps;
                }
            }
        }
    }

    /**
     * Writes {@code gamma * xhat + beta} into {@code output}, using the interleaved
     * {@code (gamma, beta)} pairs in {@code params}.
     */
    private void applyGammaBeta(Tensor_F64 output) {
        int indexOut = output.startIndex;
        int indexTensor = 0;
        // params may be a view into a larger array. The loop starts at startIndex, so the end
        // bound must be offset by it too; otherwise the wrong number of pairs is processed.
        int end = this.params.startIndex + this.params.length();
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            int indexParam = this.params.startIndex;
            while (indexParam < end) {
                double gamma = this.params.d[indexParam++];
                double beta = this.params.d[indexParam++];
                output.d[indexOut++] = gamma * this.tensorXhat.d[indexTensor++] + beta;
            }
        }
    }

    /**
     * Computes per-variable minibatch statistics and the normalized input.
     *
     * <p>On return:</p>
     * <ul>
     *   <li>{@code tensorMean} holds the mean over the minibatch.</li>
     *   <li>{@code tensorDiffX} holds {@code x - mean}.</li>
     *   <li>{@code tensorStd} holds {@code sqrt(sum((x-mean)^2)/(M-1) + EPS)}.</li>
     *   <li>{@code tensorXhat} holds {@code (x - mean) / std}.</li>
     * </ul>
     */
    private void computeStatisticsAndNormalize(Tensor_F64 input) {
        this.tensorMean.zero();
        this.tensorStd.zero();
        this.tensorXhat.zero();
        // Divisor for the unbiased variance estimate.
        double mMinusOne = this.miniBatchSize - 1;

        // Mean: accumulate over the minibatch, then divide by M.
        int indexIn = input.startIndex;
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            for (int indexVar = 0; indexVar < this.D; ++indexVar) {
                this.tensorMean.d[indexVar] += input.d[indexIn++];
            }
        }
        for (int indexVar = 0; indexVar < this.D; ++indexVar) {
            this.tensorMean.d[indexVar] /= (double) this.miniBatchSize;
        }

        // Deviations from the mean (cached for backprop) and the sum of their squares.
        indexIn = input.startIndex;
        int indexTensor = 0;
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            for (int indexVar = 0; indexVar < this.D; ++indexVar, ++indexTensor) {
                double diff = input.d[indexIn++] - this.tensorMean.d[indexVar];
                this.tensorDiffX.d[indexTensor] = diff;
                this.tensorStd.d[indexVar] += diff * diff;
            }
        }
        for (int indexVar = 0; indexVar < this.D; ++indexVar) {
            this.tensorStd.d[indexVar] = Math.sqrt(this.tensorStd.d[indexVar] / mMinusOne + this.EPS);
        }

        // Normalized input.
        indexTensor = 0;
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            for (int indexVar = 0; indexVar < this.D; ++indexVar, ++indexTensor) {
                this.tensorXhat.d[indexTensor] = this.tensorDiffX.d[indexTensor] / this.tensorStd.d[indexVar];
            }
        }
    }

    /**
     * Backward pass. Relies on the statistics cached by the most recent learning-mode forward pass.
     *
     * @param input              the input given to the forward pass; only its shape is used here
     * @param dout               gradient of the loss with respect to the output
     * @param gradientInput      receives the gradient with respect to the input
     * @param gradientParameters when gamma/beta are enabled, element 0 receives the parameter
     *                           gradients, interleaved as {@code (dgamma, dbeta)} per variable
     */
    @Override
    protected void _backwards(Tensor_F64 input, Tensor_F64 dout, Tensor_F64 gradientInput, List<Tensor_F64> gradientParameters) {
        this.tensorDXhat.reshape(input.shape);
        if (this.requiresGammaBeta) {
            this.partialXHat(dout);
        } else {
            // Without gamma/beta the output is xhat itself, so dL/dxhat = dout.
            this.tensorDXhat.setTo(dout);
        }
        // Order matters: partialMean reads tensorDVar, and partialX reads both.
        this.partialVariance();
        this.partialMean();
        this.partialX(gradientInput);
        if (this.requiresGammaBeta) {
            this.partialParameters(gradientParameters.get(0), dout);
        }
    }

    /**
     * Computes {@code dgamma = sum(dout * xhat)} and {@code dbeta = sum(dout)} over the minibatch.
     * The results are written interleaved into {@code tensorDParam}, matching the layout of
     * {@code params}.
     */
    private void partialParameters(Tensor_F64 tensorDParam, Tensor_F64 dout) {
        tensorDParam.zero();
        int indexDOut = dout.startIndex;
        int indexTensor = 0;
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            // Honor the destination's offset, as partialX does for its output tensor.
            int indexDParam = tensorDParam.startIndex;
            for (int indexVar = 0; indexVar < this.D; ++indexVar, ++indexTensor, ++indexDOut) {
                double d = dout.d[indexDOut];
                tensorDParam.d[indexDParam++] += d * this.tensorXhat.d[indexTensor];
                tensorDParam.d[indexDParam++] += d;
            }
        }
    }

    /**
     * Computes {@code dL/dxhat = dout * gamma} into {@code tensorDXhat}.
     */
    private void partialXHat(Tensor_F64 dout) {
        int indexDOut = dout.startIndex;
        // gamma for variable i is at params.startIndex + 2*i (pairs are interleaved gamma, beta).
        int paramStart = this.params.startIndex;
        int indexTensor = 0;
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            for (int indexVar = 0; indexVar < this.D; ++indexVar, ++indexTensor) {
                this.tensorDXhat.d[indexTensor] = dout.d[indexDOut++] * this.params.d[paramStart + indexVar * 2];
            }
        }
    }

    /**
     * Computes the gradient with respect to the input and writes it into {@code tensorDX}:
     * {@code dxhat/std + dvar * 2*(x-mean)/(M-1) + dmean/M}.
     */
    private void partialX(Tensor_F64 tensorDX) {
        double mMinusOne = this.miniBatchSize - 1;
        double m = this.miniBatchSize;
        int indexDX = tensorDX.startIndex;
        int indexTensor = 0;
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            for (int indexVar = 0; indexVar < this.D; ++indexVar, ++indexTensor, ++indexDX) {
                double fromXhat = this.tensorDXhat.d[indexTensor] / this.tensorStd.d[indexVar];
                double fromStats = this.tensorDVar.d[indexVar] * 2.0 * this.tensorDiffX.d[indexTensor] / mMinusOne
                        + this.tensorDMean.d[indexVar] / m;
                tensorDX.d[indexDX] = fromXhat + fromStats;
            }
        }
    }

    /**
     * Computes the gradient with respect to the mean:
     * {@code -sum(dxhat)/std - 2*dvar*sum(x-mean)/(M-1)}.
     * {@code tensorTmp} is used as scratch space for {@code sum(x-mean)}.
     */
    private void partialMean() {
        this.tensorDMean.zero();
        this.tensorTmp.zero();
        double mMinusOne = this.miniBatchSize - 1;
        int indexTensor = 0;
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            for (int indexVar = 0; indexVar < this.D; ++indexVar, ++indexTensor) {
                this.tensorTmp.d[indexVar] += this.tensorDiffX.d[indexTensor];
                this.tensorDMean.d[indexVar] -= this.tensorDXhat.d[indexTensor];
            }
        }
        for (int indexVar = 0; indexVar < this.D; ++indexVar) {
            this.tensorDMean.d[indexVar] /= this.tensorStd.d[indexVar];
            this.tensorDMean.d[indexVar] -= 2.0 * this.tensorDVar.d[indexVar] * this.tensorTmp.d[indexVar] / mMinusOne;
        }
    }

    /**
     * Computes the gradient with respect to the variance:
     * {@code sum(dxhat * (x-mean)) * (-1/2) * (var+EPS)^(-3/2)}.
     * Here {@code (var+EPS)^(3/2)} is {@code std^3}, since {@code std = sqrt(var+EPS)}.
     */
    private void partialVariance() {
        this.tensorDVar.zero();
        int indexTensor = 0;
        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            for (int indexVar = 0; indexVar < this.D; ++indexVar, ++indexTensor) {
                this.tensorDVar.d[indexVar] += this.tensorDXhat.d[indexTensor] * this.tensorDiffX.d[indexTensor];
            }
        }
        for (int indexVar = 0; indexVar < this.D; ++indexVar) {
            double sigma = this.tensorStd.d[indexVar];
            double sigmaPow3 = sigma * sigma * sigma;
            this.tensorDVar.d[indexVar] /= (-2.0 * sigmaPow3);
        }
    }
}

