/*
 * Decompiled with CFR 0.152.
 */
package deepboof.impl.backward.standard;

import deepboof.backward.DFunctionLinear;
import deepboof.impl.backward.standard.BaseDFunction;
import deepboof.impl.forward.standard.FunctionLinear_F64;
import deepboof.misc.TensorOps;
import deepboof.tensors.Tensor_F64;
import java.util.List;

/**
 * Linear (fully-connected) layer on {@link Tensor_F64} data that supports both the forward pass
 * and back-propagation of gradients.
 *
 * <p>Each sample in the mini-batch is treated as {@code D} consecutive values (where {@code D} is
 * the product of the input shape) and mapped to {@code M} outputs. Parameters are a weight tensor
 * declared with shape {@code (M, D)} and a bias tensor declared with shape {@code (M)}.
 *
 * <p>NOTE(review): the backward pass indexes raw arrays from {@code startIndex}, so tensors are
 * assumed to be stored contiguously — confirm against callers that may pass sub-tensors.
 */
public class DFunctionLinear_F64
extends BaseDFunction<Tensor_F64>
implements DFunctionLinear<Tensor_F64> {
    /** Number of values in one input sample; computed from the input shape in {@link #_initialize()}. */
    protected int D;
    /** Number of outputs produced for each input sample. */
    protected int M;
    /** Weight parameter, declared with shape (M, D). Set by {@link #_setParameters(List)}. */
    Tensor_F64 weight;
    /** Bias parameter, declared with shape (M). Set by {@link #_setParameters(List)}. */
    Tensor_F64 bias;

    /**
     * Creates a linear layer.
     *
     * @param numberOfOutputs number of outputs {@code M} produced for each input sample
     */
    public DFunctionLinear_F64(int numberOfOutputs) {
        this.M = numberOfOutputs;
    }

    /**
     * Returns the number of outputs produced for each input sample.
     *
     * @return the value passed to the constructor
     */
    @Override
    public int getNumberOfOutputs() {
        return this.M;
    }

    /**
     * Stores references (not copies) to the layer parameters.
     *
     * @param parameters element 0 is the weight tensor, element 1 is the bias tensor
     */
    @Override
    public void _setParameters(List<Tensor_F64> parameters) {
        this.weight = parameters.get(0);
        this.bias = parameters.get(1);
    }

    /**
     * Runs the forward pass by delegating to {@link FunctionLinear_F64#forwards}, so the forward
     * and backward implementations share the same parameter layout.
     *
     * @param input input tensor for the whole mini-batch
     * @param output tensor that receives the layer output
     */
    @Override
    public void _forward(Tensor_F64 input, Tensor_F64 output) {
        FunctionLinear_F64.forwards(input, output, this.weight, this.bias, this.miniBatchSize, this.D, this.M);
    }

    /**
     * Back-propagates the output gradient through the layer. The products computed here match
     * {@code y = W*x + b} for each sample.
     *
     * @param input input tensor that was used in the forward pass
     * @param dout gradient with respect to the output, read as {@code dout.get(sample, output)}
     * @param gradientInput receives the gradient with respect to the input; overwritten
     * @param gradientParameters element 0 receives the weight gradient and element 1 the bias
     *     gradient; both are overwritten
     */
    @Override
    protected void _backwards(Tensor_F64 input, Tensor_F64 dout, Tensor_F64 gradientInput, List<Tensor_F64> gradientParameters) {
        final Tensor_F64 inputD = gradientInput;
        final Tensor_F64 weightD = gradientParameters.get(0);
        final Tensor_F64 biasD = gradientParameters.get(1);

        // All gradients are accumulated with += below, so they must start at zero
        inputD.zero();
        weightD.zero();
        biasD.zero();

        for (int stack = 0; stack < this.miniBatchSize; ++stack) {
            // Offsets of this sample's input and input-gradient do not depend on the output element
            final int indexX = stack * this.D + input.startIndex;
            final int indexXD = stack * this.D + inputD.startIndex;

            for (int outputElement = 0; outputElement < this.M; ++outputElement) {
                final double gradOut = dout.get(stack, outputElement);
                // Start of row 'outputElement' in the weight and weight-gradient arrays
                final int indexW = outputElement * this.D + this.weight.startIndex;
                final int indexWD = outputElement * this.D + weightD.startIndex;

                for (int i = 0; i < this.D; ++i) {
                    // input gradient: sum over outputs of W[out,i] * dout[out]
                    inputD.d[indexXD + i] += this.weight.d[indexW + i] * gradOut;
                    // weight gradient: sum over samples of x[i] * dout[out]
                    weightD.d[indexWD + i] += input.d[indexX + i] * gradOut;
                }
                // bias gradient: sum over samples of dout[out]
                biasD.d[biasD.startIndex + outputElement] += gradOut;
            }
        }
    }

    /**
     * Computes {@code D} from the input shape and declares the parameter and output shapes.
     *
     * @throws IllegalArgumentException if the input shape has no dimensions
     */
    @Override
    public void _initialize() {
        if (this.shapeInput.length < 1) {
            throw new IllegalArgumentException("Input tensor shape must have a dimension of at least 1");
        }
        // Total number of values in one sample; the sample is treated as a vector of this length
        this.D = TensorOps.tensorLength(this.shapeInput);
        this.shapeParameters.add(new int[]{this.M, this.D}); // weight
        this.shapeParameters.add(new int[]{this.M});         // bias
        this.shapeOutput = new int[]{this.M};
    }

    /**
     * Returns the tensor type this layer operates on.
     *
     * @return {@code Tensor_F64.class}
     */
    @Override
    public Class<Tensor_F64> getTensorType() {
        return Tensor_F64.class;
    }
}

