/*
 * Decompiled with CFR 0.152.
 */
package deepboof.impl.backward.standard;

import deepboof.DeepBoofConstants;
import deepboof.Tensor;
import deepboof.backward.DBatchNorm;
import deepboof.impl.backward.standard.BaseDFunction;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F64;
import java.util.List;

/**
 * Common state and accessors shared by the {@link Tensor_F64} implementations of the
 * batch-normalization layer that supports a backward pass ({@link DBatchNorm}).
 *
 * <p>Subclasses decide which axes statistics are computed over by implementing
 * {@link #createShapeVariables(int[])}. This base class allocates the per-variable work tensors
 * and, when enabled, a single packed tensor holding the learnable gamma/beta parameters.
 */
public abstract class BaseDBatchNorm_F64
extends BaseDFunction<Tensor_F64>
implements DBatchNorm<Tensor_F64> {
    /** If true the layer has learnable gamma/beta parameters, stored together in {@link #params}. */
    protected boolean requiresGammaBeta;
    /** Mean statistic, one value per variable; returned by {@link #getMean(Tensor_F64)}. */
    protected Tensor_F64 tensorMean = new Tensor_F64();
    /**
     * Per-variable value whose square minus {@link #EPS} is reported as the variance by
     * {@link #getVariance(Tensor_F64)}. Presumably subclasses store sqrt(variance + EPS) here.
     */
    protected Tensor_F64 tensorStd = new Tensor_F64();
    // NOTE(review): the following three are not shaped here; presumably subclasses size them.
    protected Tensor_F64 tensorXhat = new Tensor_F64();
    /** Per-variable work tensor (shaped like {@link #tensorMean}); presumably d(loss)/d(variance). */
    protected Tensor_F64 tensorDVar = new Tensor_F64();
    /** Per-variable work tensor (shaped like {@link #tensorMean}); presumably d(loss)/d(mean). */
    protected Tensor_F64 tensorDMean = new Tensor_F64();
    protected Tensor_F64 tensorDXhat = new Tensor_F64();
    protected Tensor_F64 tensorDiffX = new Tensor_F64();
    /** Per-variable scratch tensor (shaped like {@link #tensorMean}). */
    protected Tensor_F64 tensorTmp = new Tensor_F64();
    /** Shape of the per-variable statistics, as produced by {@link #createShapeVariables(int[])}. */
    protected int[] shapeVariables;
    /** Total number of elements in {@link #shapeVariables}. */
    protected int D;
    /** Packed gamma/beta parameters; only shaped/used when {@link #requiresGammaBeta} is true. */
    protected Tensor_F64 params = new Tensor_F64(0);
    /** Small constant added for numerical stability; defaults to a tenth of the F64 test tolerance. */
    protected double EPS = DeepBoofConstants.TEST_TOL_F64 * 0.1;

    /**
     * @param requiresGammaBeta true if the layer should have learnable gamma and beta parameters
     */
    public BaseDBatchNorm_F64(boolean requiresGammaBeta) {
        this.requiresGammaBeta = requiresGammaBeta;
    }

    /**
     * Sizes all per-variable work tensors from the input shape, declares the output shape
     * (identical to the input) and, if enabled, declares and allocates the gamma/beta tensor.
     */
    @Override
    public void _initialize() {
        this.shapeVariables = this.createShapeVariables(this.shapeInput);
        this.tensorMean.reshape(this.shapeVariables);
        this.tensorStd.reshape(this.shapeVariables);
        this.tensorDVar.reshape(this.shapeVariables);
        this.tensorDMean.reshape(this.shapeVariables);
        this.tensorTmp.reshape(this.shapeVariables);
        // Batch normalization does not change the shape of its input.
        this.shapeOutput = this.shapeInput.clone();
        if (this.requiresGammaBeta) {
            // NOTE(review): assumes TensorOps.WI extends the variable shape with a dimension of
            // size 2 so that gamma and beta share one tensor -- confirm against TensorOps.
            int[] shapeParam = TensorOps.WI(this.shapeVariables, 2);
            this.shapeParameters.add(shapeParam);
            this.params.reshape(shapeParam);
        }
        this.D = TensorOps.tensorLength(this.shapeVariables);
    }

    /**
     * Computes the shape of the per-variable statistics (mean/variance) for the given input shape.
     *
     * @param var1 shape of the input tensor
     * @return shape of the statistics tensors
     */
    protected abstract int[] createShapeVariables(int[] var1);

    /**
     * Copies the learnable parameters into this layer.
     *
     * @param parameters exactly one tensor holding gamma/beta when enabled, otherwise empty
     * @throws IllegalArgumentException if the number of parameter tensors does not match the
     *     configuration
     */
    @Override
    public void _setParameters(List<Tensor_F64> parameters) {
        if (this.requiresGammaBeta) {
            // Fail with a descriptive error instead of an IndexOutOfBoundsException, and don't
            // silently ignore extra tensors.
            if (parameters.size() != 1) {
                throw new IllegalArgumentException(
                        "Expected exactly one parameter tensor (gamma and beta), but got " + parameters.size());
            }
            this.params.setTo(parameters.get(0));
        } else if (parameters.size() != 0) {
            throw new IllegalArgumentException("There are no parameters since gamma and beta have been turned off");
        }
    }

    /** @return the numerical-stability constant currently in use */
    @Override
    public double getEPS() {
        return this.EPS;
    }

    /** @param EPS new numerical-stability constant */
    @Override
    public void setEPS(double EPS) {
        this.EPS = EPS;
    }

    /** @return true if this layer has learnable gamma/beta parameters */
    @Override
    public boolean hasGammaBeta() {
        return this.requiresGammaBeta;
    }

    @Override
    public Class<Tensor_F64> getTensorType() {
        return Tensor_F64.class;
    }

    /**
     * Copies the current mean statistic into {@code output}.
     *
     * @param output storage for the result; if null a new tensor is created
     * @return the tensor holding a copy of the mean
     */
    @Override
    public Tensor_F64 getMean(Tensor_F64 output) {
        if (output == null) {
            output = (Tensor_F64)this.tensorMean.createLike();
        }
        output.setTo(this.tensorMean);
        return output;
    }

    /**
     * Writes the variance into {@code output}, computed as {@code std*std - EPS} for every element
     * of {@link #tensorStd}.
     *
     * @param output storage for the result; if null a new tensor is created
     * @return the tensor holding the variance
     */
    @Override
    public Tensor_F64 getVariance(Tensor_F64 output) {
        if (output == null) {
            output = (Tensor_F64)this.tensorStd.createLike();
        }
        output.reshape(this.tensorStd.getShape());
        int indexOut = output.startIndex;
        // Honour the source tensor's start offset just like the destination's.
        int indexStd = this.tensorStd.startIndex;
        int length = this.tensorStd.length();
        for (int i = 0; i < length; ++i) {
            double d = this.tensorStd.d[indexStd++];
            // Undo the EPS that was folded into the stored standard deviation.
            output.d[indexOut++] = d * d - this.EPS;
        }
        return output;
    }
}

