package org.thema.graphab.model;

import java.util.Arrays;
import org.apache.commons.math.FunctionEvaluationException;
import org.apache.commons.math.MathException;
import org.apache.commons.math.analysis.MultivariateRealFunction;
import org.apache.commons.math.analysis.UnivariateRealFunction;
import org.apache.commons.math.distribution.ChiSquaredDistributionImpl;
import org.apache.commons.math.linear.ArrayRealVector;
import org.apache.commons.math.linear.LUDecompositionImpl;
import org.apache.commons.math.linear.MatrixUtils;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.commons.math.linear.RealVector;

/**
 * Binary logistic regression fitted by iteratively reweighted least squares (IRLS).
 *
 * <p>A column of ones (the intercept) is prepended to the explanatory variables. The
 * coefficient vector therefore has {@code nVar + 1} entries, with the intercept at index 0.
 * At every iteration a small ridge term (1e-5) is added to the diagonal of the
 * {@code A'WA} matrix. This keeps the linear system invertible even when the IRLS weights
 * vanish or the variables are collinear.
 *
 * <p>{@link #regression()} must be called before the model accessors. Instances hold mutable
 * state and are not synchronized.
 */
public class Logistic {
    /** Maximum number of IRLS iterations. */
    private static final int MAX_ITER = 500;
    /** Convergence threshold on the mean absolute change of the fitted probabilities. */
    private static final double EPSILON = 1.0E-10d;
    /** Ridge term added to every diagonal entry of A'WA. */
    private static final double RIDGE = 1.0E-5d;

    /** Logistic (sigmoid) function {@code 1 / (1 + e^-t)}. */
    private static final UnivariateRealFunction SIGMOID = new UnivariateRealFunction() {
        @Override
        public double value(double t) {
            return 1.0d / (1.0d + Math.exp(-t));
        }
    };

    /** Maps a probability {@code p} to the Bernoulli variance {@code p(1 - p)}, i.e. the IRLS weight. */
    private static final UnivariateRealFunction BERNOULLI_VARIANCE = new UnivariateRealFunction() {
        @Override
        public double value(double p) {
            return p * (1.0d - p);
        }
    };

    /** Fitted coefficients (intercept first); null until regression() has run. */
    private RealVector params;
    /** Design matrix: column 0 is the intercept (all ones), columns 1..nVar are the variables. */
    private final RealMatrix A;
    /** Observed responses, one per row of A. */
    private final RealVector Y;
    /** Number of explanatory variables (intercept excluded). */
    private final int nVar;
    /** Number of observations. */
    private final int n;
    /** Fitted probability function; null until regression() has run. */
    private LogisticFunction estim;
    /** Intercept-only reference model used by the likelihood-ratio statistics. */
    private Logistic constLog;

    /** Fitted logistic model {@code x -> 1 / (1 + exp(-beta . x))} for fixed coefficients. */
    public static class LogisticFunction implements MultivariateRealFunction {
        private final RealVector beta;

        /**
         * @param dArr model coefficients; the array is copied
         */
        public LogisticFunction(double[] dArr) {
            this.beta = MatrixUtils.createRealVector(dArr);
        }

        /**
         * Evaluates the fitted probability for one observation.
         *
         * @param dArr predictor values with the same length as the coefficients. For
         *             coefficients produced by {@link Logistic#regression()}, this means a leading
         *             {@code 1.0} for the intercept followed by the variable values.
         * @return the fitted probability, in [0, 1]
         */
        @Override
        public double value(double[] dArr) {
            return 1.0d / (1.0d + Math.exp(-this.beta.dotProduct(dArr)));
        }
    }

    /**
     * Creates a model for the given observations. An intercept column is added automatically.
     *
     * @param x explanatory variables: one row per observation, one column per variable
     *          (rectangular)
     * @param y response for each observation. The Bernoulli likelihood used by this class
     *          assumes values of 0 or 1.
     * @throws IllegalArgumentException if {@code x} has no rows or no columns, or if its row
     *         count differs from the length of {@code y}
     */
    public Logistic(double[][] x, double[] y) {
        if (x.length == 0 || x[0].length == 0) {
            throw new IllegalArgumentException("at least one observation and one variable are required");
        }
        if (x.length != y.length) {
            throw new IllegalArgumentException(
                    "row count of x (" + x.length + ") differs from length of y (" + y.length + ")");
        }
        this.Y = MatrixUtils.createRealVector(y);
        this.nVar = x[0].length;
        this.n = x.length;
        this.A = MatrixUtils.createRealMatrix(this.n, this.nVar + 1);
        this.A.setSubMatrix(x, 0, 1); // variables occupy columns 1..nVar
        for (int i = 0; i < this.n; i++) {
            this.A.setEntry(i, 0, 1.0d); // intercept column
        }
    }

    /**
     * Creates the intercept-only (null) model for the given responses.
     *
     * @param y observed responses
     */
    private Logistic(double[] y) {
        this.Y = MatrixUtils.createRealVector(y);
        this.nVar = 0;
        this.n = y.length;
        this.A = MatrixUtils.createRealMatrix(this.n, 1);
        for (int i = 0; i < this.n; i++) {
            this.A.setEntry(i, 0, 1.0d);
        }
    }

    /**
     * Returns the fitted coefficients.
     *
     * @return a copy of the coefficients: the intercept first, then one per variable
     * @throws IllegalStateException if {@link #regression()} has not been called
     */
    public double[] getCoefs() {
        checkFitted();
        return this.params.getData();
    }

    /**
     * Fits the model by IRLS and also fits the intercept-only reference model.
     *
     * <p>Iteration stops when the mean absolute change of the fitted probabilities falls below
     * 1e-10, when that change becomes NaN (divergence), or after 500 iterations. In the last
     * case the latest estimate is kept without any warning.
     *
     * @return the fitted coefficients: the intercept first, then one per variable
     * @throws FunctionEvaluationException propagated from the element-wise mappings
     */
    public double[] regression() throws FunctionEvaluationException {
        final int nParams = this.nVar + 1;
        double[] ridgeDiag = new double[nParams];
        Arrays.fill(ridgeDiag, RIDGE);
        final RealMatrix ridge = MatrixUtils.createRealDiagonalMatrix(ridgeDiag);
        // The design matrix is fixed, so its transpose is computed once.
        final RealMatrix aTransposed = this.A.transpose();

        RealVector beta = new ArrayRealVector(nParams);
        RealVector prevMu = null;
        for (int iter = 0; iter < MAX_ITER; iter++) {
            RealVector eta = this.A.operate(beta); // linear predictor
            RealVector mu = eta.copy();
            mu.mapToSelf(SIGMOID); // fitted probabilities
            RealVector w = mu.copy();
            w.mapToSelf(BERNOULLI_VARIANCE); // IRLS weights
            // Working response already multiplied by W: W*eta + (y - mu).
            RealVector z = w.ebeMultiply(eta).add(this.Y.subtract(mu));
            // A'WA built as A' * (diag(w) * A), so the n x n weight matrix is never materialised.
            RealMatrix normal = aTransposed.multiply(rowScaled(this.A, w)).add(ridge);
            // Solve the system directly instead of forming an explicit inverse.
            beta = new LUDecompositionImpl(normal).getSolver().solve(aTransposed.operate(z));
            if (prevMu != null) {
                double change = mu.subtract(prevMu).getL1Norm();
                // A NaN change means the iteration has blown up; stop instead of running to MAX_ITER.
                if (Double.isNaN(change) || change < this.n * EPSILON) {
                    break;
                }
            }
            prevMu = mu;
        }
        this.params = beta;
        this.estim = new LogisticFunction(beta.getData());
        if (this.nVar > 0) {
            Logistic nullModel = new Logistic(this.Y.getData());
            nullModel.regression();
            this.constLog = nullModel;
        } else {
            // An intercept-only model is its own null model.
            this.constLog = this;
        }
        return beta.getData();
    }

    /**
     * Returns {@code diag(w) * m}: a copy of {@code m} whose row i is multiplied by w[i].
     *
     * @param m matrix with one row per entry of {@code w}
     * @param w row scale factors
     * @return a new matrix; {@code m} is not modified
     */
    private static RealMatrix rowScaled(RealMatrix m, RealVector w) {
        RealMatrix out = m.copy();
        int rows = out.getRowDimension();
        int cols = out.getColumnDimension();
        for (int i = 0; i < rows; i++) {
            double wi = w.getEntry(i);
            for (int j = 0; j < cols; j++) {
                out.setEntry(i, j, out.getEntry(i, j) * wi);
            }
        }
        return out;
    }

    /**
     * @return the fitted probability function, or {@code null} if {@link #regression()} has not
     *         been called
     */
    public LogisticFunction getEstimFunction() {
        return this.estim;
    }

    /**
     * Returns the fitted probability for every observation.
     *
     * @return one fitted probability per observation, in input order
     * @throws IllegalStateException if {@link #regression()} has not been called
     */
    public double[] getEstimation() {
        checkFitted();
        double[] fitted = new double[this.n];
        for (int i = 0; i < this.n; i++) {
            fitted[i] = this.estim.value(this.A.getRow(i));
        }
        return fitted;
    }

    /**
     * Returns the Bernoulli likelihood of the observations under the fitted model.
     *
     * <p>This is a plain product of per-observation probabilities, so it underflows to 0 for
     * moderately large samples. Prefer {@link #getLogLikelihood()} for computations.
     *
     * @return the likelihood
     * @throws IllegalStateException if {@link #regression()} has not been called
     */
    public double getLikelihood() {
        checkFitted();
        double likelihood = 1.0d;
        for (int i = 0; i < this.n; i++) {
            double p = this.estim.value(this.A.getRow(i));
            double y = this.Y.getEntry(i);
            likelihood *= Math.pow(p, y) * Math.pow(1.0d - p, 1.0d - y);
        }
        return likelihood;
    }

    /**
     * Returns the Bernoulli log-likelihood of the observations under the fitted model.
     *
     * @return the log-likelihood (at most 0 for 0/1 responses)
     * @throws IllegalStateException if {@link #regression()} has not been called
     */
    public double getLogLikelihood() {
        checkFitted();
        double logLik = 0.0d;
        for (int i = 0; i < this.n; i++) {
            double p = this.estim.value(this.A.getRow(i));
            double y = this.Y.getEntry(i);
            // Math.log(0) is -Infinity and -Infinity * 0 is NaN. Skipping NaN terms therefore
            // treats 0 * log(0) as 0. This also drops terms where p itself is NaN.
            double positiveTerm = Math.log(p) * y;
            if (!Double.isNaN(positiveTerm)) {
                logLik += positiveTerm;
            }
            double negativeTerm = Math.log(1.0d - p) * (1.0d - y);
            if (!Double.isNaN(negativeTerm)) {
                logLik += negativeTerm;
            }
        }
        return logLik;
    }

    /**
     * Returns the likelihood-ratio statistic {@code -2 * (logL0 - logL)}. Here {@code logL0} is
     * the log-likelihood of the intercept-only model.
     *
     * @return the deviance reduction relative to the intercept-only model
     * @throws IllegalStateException if {@link #regression()} has not been called
     */
    public double getDiffLikelihood() {
        checkFitted();
        return (-2.0d) * (this.constLog.getLogLikelihood() - getLogLikelihood());
    }

    /**
     * Returns the p-value of the likelihood-ratio test against the intercept-only model. It uses
     * a chi-squared distribution with {@code nVar} degrees of freedom.
     *
     * @return the p-value, or NaN when the statistic is infinite
     * @throws MathException if the chi-squared CDF cannot be evaluated
     * @throws IllegalStateException if {@link #regression()} has not been called
     */
    public double getProbaTest() throws MathException {
        double diff = getDiffLikelihood();
        if (Double.isInfinite(diff)) {
            return Double.NaN;
        }
        return 1.0d - new ChiSquaredDistributionImpl(this.nVar).cumulativeProbability(diff);
    }

    /**
     * Returns McFadden's pseudo-R², {@code 1 - logL / logL0}.
     *
     * @return the pseudo-R²; NaN or infinite when {@code logL0} is 0
     * @throws IllegalStateException if {@link #regression()} has not been called
     */
    public double getR2() {
        checkFitted();
        return 1.0d - (getLogLikelihood() / this.constLog.getLogLikelihood());
    }

    /**
     * Returns the Akaike information criterion {@code 2 * nVar - 2 * logL}.
     *
     * <p>NOTE(review): only the explanatory variables are counted as parameters; the intercept
     * is not. The value is therefore 2 lower than the textbook AIC
     * ({@code 2k - 2 logL} with {@code k = nVar + 1}). Differences between models fitted with
     * this class are unaffected. The formula is kept unchanged so reported values stay
     * comparable.
     *
     * @return the AIC as defined above
     * @throws IllegalStateException if {@link #regression()} has not been called
     */
    public double getAIC() {
        return (2 * this.nVar) - (2.0d * getLogLikelihood());
    }

    /** Fails fast with a clear message instead of an NPE when the model is not yet fitted. */
    private void checkFitted() {
        if (this.estim == null) {
            throw new IllegalStateException("regression() must be called before querying the model");
        }
    }
}
