package umontreal.iro.lecuyer.probdistmulti;

import umontreal.iro.lecuyer.functions.MathFunction;
import umontreal.iro.lecuyer.util.Num;
import umontreal.iro.lecuyer.util.RootFinder;

/* loaded from: input_file:lib/ssj.jar:umontreal/iro/lecuyer/probdistmulti/NegativeMultinomialDist.class */
/**
 * Negative multinomial distribution with parameters {@code gamma > 0} and
 * {@code p = (p_1, ..., p_d)}. Each {@code p_i} lies in {@code [0, 1)} and
 * {@code p_1 + ... + p_d < 1}.
 *
 * <p>Let {@code p_0 = 1 - (p_1 + ... + p_d)}. For non-negative integer components,
 * the probability mass at {@code x = (x_1, ..., x_d)} is
 * <pre>
 *   Gamma(gamma + x_1 + ... + x_d) / (Gamma(gamma) * x_1! * ... * x_d!)
 *     * p_0^gamma * p_1^x_1 * ... * p_d^x_d
 * </pre>
 */
public class NegativeMultinomialDist extends DiscreteDistributionIntMulti {
    protected double gamma;
    protected double[] p;

    /**
     * Score equation in {@code gamma} whose root is the maximum likelihood estimate
     * of {@code gamma} computed by {@link NegativeMultinomialDist#getMLE}.
     *
     * <p>For a candidate {@code g}, {@link #evaluate(double)} returns
     * {@code sum_{l < M} Fl[l] / (g + l) - log(1 + sumUps / (n * g))}.
     */
    public static class Function implements MathFunction {
        /** Fl[l] = proportion of observations whose total count exceeds l. */
        protected double[] Fl;
        /** Total count (sum of all components) of each observation. */
        protected int[] ups;
        /** Number of observations. */
        protected int n;
        /** Largest total count; number of terms used from {@code Fl}. */
        protected int M;
        /** Sum of all entries of {@code ups}. */
        protected int sumUps;

        /**
         * @param n   number of observations
         * @param M   number of terms of {@code Fl} to sum in {@link #evaluate(double)}
         * @param ups per-observation total counts (copied)
         * @param Fl  tail proportions (copied)
         */
        public Function(int n, int M, int[] ups, double[] Fl) {
            this.n = n;
            this.M = M;
            this.Fl = Fl.clone();
            this.ups = ups.clone();
            this.sumUps = 0;
            for (int u : ups) {
                this.sumUps += u;
            }
        }

        @Override // umontreal.iro.lecuyer.functions.MathFunction
        public double evaluate(double gamma) {
            double sum = 0.0;
            for (int l = 0; l < M; l++) {
                sum += Fl[l] / (gamma + l);
            }
            return sum - Math.log1p(sumUps / (n * gamma));
        }
    }

    /**
     * Creates a negative multinomial distribution.
     *
     * @param gamma shape parameter, must be {@code > 0}
     * @param p     probabilities {@code p_1..p_d}; each in {@code [0, 1)} with sum {@code < 1}
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public NegativeMultinomialDist(double gamma, double[] p) {
        setParams(gamma, p);
    }

    @Override // umontreal.iro.lecuyer.probdistmulti.DiscreteDistributionIntMulti
    public double prob(int[] x) {
        return prob_(this.gamma, this.p, x);
    }

    @Override // umontreal.iro.lecuyer.probdistmulti.DiscreteDistributionIntMulti
    public double[] getMean() {
        return getMean_(this.gamma, this.p);
    }

    @Override // umontreal.iro.lecuyer.probdistmulti.DiscreteDistributionIntMulti
    public double[][] getCovariance() {
        return getCovariance_(this.gamma, this.p);
    }

    @Override // umontreal.iro.lecuyer.probdistmulti.DiscreteDistributionIntMulti
    public double[][] getCorrelation() {
        return getCorrelation_(this.gamma, this.p);
    }

    /**
     * Validates the parameters.
     *
     * <p>The comparisons are written in negated form so that NaN, which fails every
     * comparison, is rejected as well.
     *
     * @throws IllegalArgumentException if {@code gamma} is not {@code > 0}, if some
     *         {@code p_i} is not in {@code [0, 1)}, or if the {@code p_i} sum to {@code >= 1}
     */
    private static void verifParam(double gamma, double[] p) {
        if (!(gamma > 0.0d)) {
            throw new IllegalArgumentException("gamma <= 0");
        }
        double sum = 0.0d;
        for (double pi : p) {
            if (!(pi >= 0.0d && pi < 1.0d)) {
                throw new IllegalArgumentException("p is not a probability vector");
            }
            sum += pi;
        }
        if (sum >= 1.0d) {
            throw new IllegalArgumentException("p is not a probability vector");
        }
    }

    /** Returns {@code p_0 = 1 - (p_1 + ... + p_d)}. */
    private static double complementOfSum(double[] p) {
        double sum = 0.0d;
        for (double pi : p) {
            sum += pi;
        }
        return 1.0d - sum;
    }

    private static double prob_(double gamma, double[] p, int[] x) {
        if (x.length != p.length) {
            throw new IllegalArgumentException("x and p must have the same size");
        }
        double sumP = 0.0d;
        double sumX = 0.0d;
        double sumLnFactX = 0.0d;
        double sumXLnP = 0.0d;
        for (int i = 0; i < p.length; i++) {
            if (x[i] < 0) {
                return 0.0d; // outside the support
            }
            sumP += p[i];
            sumX += x[i];
            sumLnFactX += Num.lnFactorial(x[i]);
            // p_i^0 == 1 even when p_i == 0; skipping avoids 0 * log(0) = NaN.
            if (x[i] != 0) {
                sumXLnP += x[i] * Math.log(p[i]);
            }
        }
        return Math.exp(Num.lnGamma(gamma + sumX) - (Num.lnGamma(gamma) + sumLnFactX)
                + gamma * Math.log1p(-sumP) + sumXLnP);
    }

    /**
     * Computes the probability mass at {@code x}.
     *
     * @param gamma shape parameter
     * @param p     probabilities {@code p_1..p_d}
     * @param x     point, of the same length as {@code p}
     * @return the probability of {@code x}; 0 if some component is negative
     * @throws IllegalArgumentException if the parameters are invalid or the lengths differ
     */
    public static double prob(double gamma, double[] p, int[] x) {
        verifParam(gamma, p);
        return prob_(gamma, p, x);
    }

    private static double cdf_(double gamma, double[] p, int[] x) {
        throw new UnsupportedOperationException("cdf not implemented");
    }

    /**
     * Not implemented.
     *
     * @throws IllegalArgumentException      if the parameters are invalid
     * @throws UnsupportedOperationException always, when the parameters are valid
     */
    public static double cdf(double gamma, double[] p, int[] x) {
        verifParam(gamma, p);
        return cdf_(gamma, p, x);
    }

    private static double[] getMean_(double gamma, double[] p) {
        double p0 = complementOfSum(p);
        double[] mean = new double[p.length];
        for (int i = 0; i < p.length; i++) {
            mean[i] = (gamma * p[i]) / p0;
        }
        return mean;
    }

    /**
     * Returns the mean vector, with components {@code gamma * p_i / p_0}.
     *
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public static double[] getMean(double gamma, double[] p) {
        verifParam(gamma, p);
        return getMean_(gamma, p);
    }

    private static double[][] getCovariance_(double gamma, double[] p) {
        double p0 = complementOfSum(p);
        double p0Sq = p0 * p0;
        double[][] cov = new double[p.length][p.length];
        for (int i = 0; i < p.length; i++) {
            for (int j = 0; j < p.length; j++) {
                cov[i][j] = ((gamma * p[i]) * p[j]) / p0Sq;
            }
            cov[i][i] = ((gamma * p[i]) * (p[i] + p0)) / p0Sq;
        }
        return cov;
    }

    /**
     * Returns the covariance matrix.
     *
     * <p>Off-diagonal entries are {@code gamma p_i p_j / p_0^2}; diagonal entries are
     * {@code gamma p_i (p_i + p_0) / p_0^2}.
     *
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public static double[][] getCovariance(double gamma, double[] p) {
        verifParam(gamma, p);
        return getCovariance_(gamma, p);
    }

    private static double[][] getCorrelation_(double gamma, double[] p) {
        double p0 = complementOfSum(p);
        double[][] corr = new double[p.length][p.length];
        for (int i = 0; i < p.length; i++) {
            for (int j = 0; j < p.length; j++) {
                corr[i][j] = Math.sqrt((p[i] * p[j]) / ((p0 + p[i]) * (p0 + p[j])));
            }
            corr[i][i] = 1.0d;
        }
        return corr;
    }

    /**
     * Returns the correlation matrix.
     *
     * <p>Off-diagonal entries are {@code sqrt(p_i p_j / ((p_0 + p_i)(p_0 + p_j)))}, and
     * the diagonal is 1. The value does not depend on {@code gamma}.
     *
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public static double[][] getCorrelation(double gamma, double[] p) {
        verifParam(gamma, p);
        return getCorrelation_(gamma, p);
    }

    /**
     * Same as {@link #getMLE(int[][], int, int)}.
     *
     * @deprecated use {@link #getMLE(int[][], int, int)} instead
     */
    @Deprecated
    public static double[] getMaximumLikelihoodEstimate(int[][] x, int n, int d) {
        return getMLE(x, n, d);
    }

    /**
     * Estimates the parameters by maximum likelihood from {@code n} observations.
     *
     * <p>{@code gamma} is taken as the root of {@link Function}. The root is found
     * with {@code RootFinder.brentDekker} over {@code [1e-9, 1e9]} with tolerance
     * {@code 1e-5}. Then, with {@code lambda_i = mean_i / gamma}, each
     * {@code p_i = lambda_i / (1 + sum_k lambda_k)}.
     *
     * @param x observations; {@code x[j][i]} is component {@code i} of observation {@code j}
     * @param n number of observations to use (rows of {@code x})
     * @param d dimension (columns of {@code x} to use)
     * @return an array of length {@code d + 1}: {@code [gamma, p_1, ..., p_d]}
     * @throws IllegalArgumentException if {@code n <= 0} or {@code d <= 0}, if the
     *         largest observation total equals {@code Integer.MAX_VALUE}, or if an
     *         estimated {@code p_i > 1}
     */
    public static double[] getMLE(int[][] x, int n, int d) {
        if (n <= 0) {
            throw new IllegalArgumentException("n <= 0");
        }
        if (d <= 0) {
            throw new IllegalArgumentException("d <= 0");
        }
        double[] par = new double[d + 1];
        int[] ups = new int[n];
        double[] mean = new double[d];

        // ups[j] = total count of observation j; mean[i] = sample mean of component i.
        for (int j = 0; j < n; j++) {
            for (int i = 0; i < d; i++) {
                ups[j] += x[j][i];
                mean[i] += x[j][i];
            }
        }
        for (int i = 0; i < d; i++) {
            mean[i] /= n;
        }

        int maxUps = ups[0];
        for (int j = 1; j < n; j++) {
            maxUps = Math.max(maxUps, ups[j]);
        }
        if (maxUps >= Integer.MAX_VALUE) {
            throw new IllegalArgumentException("gamma/p_i too large");
        }

        double[] tail = tailProportions(ups, maxUps);
        par[0] = RootFinder.brentDekker(1.0E-9d, 1.0E9d, new Function(n, maxUps, ups, tail), 1.0E-5d);

        double[] lambda = new double[d];
        double sumLambda = 0.0d;
        for (int i = 0; i < d; i++) {
            lambda[i] = mean[i] / par[0];
            sumLambda += lambda[i];
        }
        for (int i = 0; i < d; i++) {
            par[i + 1] = lambda[i] / (1.0d + sumLambda);
            if (par[i + 1] > 1.0d) {
                throw new IllegalArgumentException("p_i > 1");
            }
        }
        return par;
    }

    /**
     * Returns {@code F} of length {@code m}, where {@code F[l]} is the fraction of
     * entries of {@code ups} that are strictly greater than {@code l}.
     */
    private static double[] tailProportions(int[] ups, int m) {
        double[] tail = new double[m];
        for (int l = 0; l < m; l++) {
            int exceeding = 0;
            for (int u : ups) {
                if (u > l) {
                    exceeding++;
                }
            }
            // Floating-point division: an int/int quotient would truncate to 0 or 1.
            tail[l] = (double) exceeding / ups.length;
        }
        return tail;
    }

    /** Returns the shape parameter {@code gamma}. */
    public double getGamma() {
        return this.gamma;
    }

    /**
     * Returns the probability vector {@code p}.
     *
     * <p>This is the internal array, not a copy. Modifying it changes this
     * distribution without revalidation.
     */
    public double[] getP() {
        return this.p;
    }

    /**
     * Sets the parameters of this distribution.
     *
     * <p>All parameters are validated before any field is modified, so on failure
     * the distribution is left unchanged. {@code p} is copied.
     *
     * @param gamma shape parameter, must be {@code > 0}
     * @param p     probabilities {@code p_1..p_d}; each in {@code [0, 1)} with sum {@code < 1}
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public void setParams(double gamma, double[] p) {
        verifParam(gamma, p);
        this.gamma = gamma;
        this.dimension = p.length;
        this.p = p.clone();
    }
}
