package blog.distrib;

import Jama.Matrix;
import blog.AbstractCondProbDistrib;
import blog.Model;
import blog.Type;
import common.Util;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/* loaded from: input_file:blog/distrib/Categorical.class */
/**
 * Categorical (discrete) distribution over the indices {@code 0 .. n-1}.
 *
 * <p>The probability vector is either fixed when the object is constructed or,
 * if the distribution was built from an empty parameter list, read from the
 * first CPD argument (a column {@code Matrix}) on every call to
 * {@link #getProb(List, Object)} and {@link #sampleVal(List, Type)}. When used
 * as a CPD, index {@code i} corresponds to the object whose
 * {@code Model.getObjectIndex} is {@code i}, and a sampled index is mapped back
 * to an object through {@code Type.getGuaranteedObject}.
 */
public class Categorical extends AbstractCondProbDistrib {
    /** No flags: the weights are linear and are normalized by dividing by their sum. */
    public static final int NO_FLAGS = 0;
    /** Flag: the weights already form a normalized distribution, so no rescaling is done. */
    public static final int NORMALIZED = 1;
    /** Flag: the weights are natural logarithms and are converted with {@code Math.exp}. */
    public static final int LOG = 2;
    /**
     * Flag used with {@link #LOG} (and without {@link #NORMALIZED}): element 0 is taken as
     * the maximum log weight, so the scan for the maximum is skipped. Presumably this means
     * the log weights are sorted in descending order; callers must ensure that holds.
     */
    public static final int SORTED = 4;

    /**
     * True when no probability vector was supplied at construction time, in which case it
     * must be passed as the first CPD argument.
     */
    boolean expectProbsAsArg = true;

    /** Probability of each index; stays null until supplied when {@code expectProbsAsArg} is true. */
    private double[] probs;

    /**
     * Creates a uniform distribution over {@code numValues} indices.
     *
     * @param numValues number of possible values
     */
    public Categorical(int numValues) {
        this.probs = new double[numValues];
        Arrays.fill(this.probs, 1.0d / numValues);
        this.expectProbsAsArg = false;
    }

    /**
     * Creates a distribution from an explicit probability vector, which is copied as-is
     * (no normalization is performed).
     *
     * @param probs probability of each index
     */
    public Categorical(double[] probs) {
        this.probs = probs.clone();
        this.expectProbsAsArg = false;
    }

    /**
     * Creates a distribution from a weight vector interpreted according to {@code flags}.
     *
     * <ul>
     *   <li>{@code NORMALIZED}: used as-is.</li>
     *   <li>{@code NORMALIZED | LOG}: each entry is exponentiated.</li>
     *   <li>{@code NO_FLAGS}: divided by the sum of the weights.</li>
     *   <li>{@code LOG}: normalized with a max-shifted log-sum-exp
     *       ({@code SORTED} skips the max scan).</li>
     * </ul>
     *
     * @param weights weight (or log weight) of each index; copied, never modified
     * @param flags bitwise OR of {@link #NORMALIZED}, {@link #LOG} and {@link #SORTED}
     * @throws IllegalArgumentException if normalization is requested for a non-empty vector
     *     that cannot be normalized (non-positive linear sum, or a non-finite maximum
     *     log weight), which previously yielded a silent vector of NaNs
     */
    public Categorical(double[] weights, int flags) {
        this(weights);
        boolean normalized = (flags & NORMALIZED) != 0;
        boolean log = (flags & LOG) != 0;
        if (normalized) {
            if (log) {
                for (int i = 0; i < this.probs.length; i++) {
                    this.probs[i] = Math.exp(this.probs[i]);
                }
            }
        } else if (log) {
            normalizeLogWeights((flags & SORTED) != 0);
        } else {
            normalizeLinearWeights();
        }
    }

    /**
     * Creates a distribution from a list of {@link Number} probabilities. An empty list
     * means the probability vector will instead be passed as a CPD argument.
     *
     * @param list probabilities (each element must be a {@link Number}), or empty
     */
    public Categorical(List list) {
        if (list.isEmpty()) {
            // Probabilities will be supplied later through ensureProbsInited.
            return;
        }
        this.probs = new double[list.size()];
        for (int i = 0; i < this.probs.length; i++) {
            this.probs[i] = ((Number) list.get(i)).doubleValue();
        }
        this.expectProbsAsArg = false;
    }

    /** Divides every weight by the sum of all weights. */
    private void normalizeLinearWeights() {
        double sum = 0.0d;
        for (double weight : this.probs) {
            sum += weight;
        }
        // A zero (or negative/NaN) total would turn every entry into NaN or garbage.
        if (this.probs.length > 0 && !(sum > 0.0d)) {
            throw new IllegalArgumentException("Categorical weights must have a positive sum, got: " + sum);
        }
        for (int i = 0; i < this.probs.length; i++) {
            this.probs[i] /= sum;
        }
    }

    /**
     * Converts log weights to normalized probabilities using log-sum-exp shifted by the
     * maximum log weight, which avoids overflow in {@code Math.exp}.
     *
     * @param firstIsMax if true, element 0 is trusted to be the maximum log weight
     */
    private void normalizeLogWeights(boolean firstIsMax) {
        if (this.probs.length == 0) {
            return;
        }
        double max = this.probs[0];
        if (!firstIsMax) {
            for (int i = 1; i < this.probs.length; i++) {
                if (this.probs[i] > max) {
                    max = this.probs[i];
                }
            }
        }
        // With an infinite or NaN shift, exp(x - max) evaluates to NaN for every entry.
        if (Double.isNaN(max) || Double.isInfinite(max)) {
            throw new IllegalArgumentException("Categorical log weights must have a finite maximum, got: " + max);
        }
        double shiftedSum = 0.0d;
        for (double logWeight : this.probs) {
            shiftedSum += Math.exp(logWeight - max);
        }
        double logNormalizer = Math.log(shiftedSum) + max;
        for (int i = 0; i < this.probs.length; i++) {
            this.probs[i] = Math.exp(this.probs[i] - logNormalizer);
        }
    }

    /** Fails if this distribution has no fixed probability vector. */
    private void requireFixedProbs() {
        if (this.expectProbsAsArg) {
            throw new IllegalStateException("Categorical distribution was constructed without a probability vector.");
        }
    }

    /**
     * Returns the number of indices in the fixed probability vector.
     *
     * @throws IllegalStateException if the probabilities are expected as a CPD argument
     */
    public int getNumValues() {
        requireFixedProbs();
        return this.probs.length;
    }

    /**
     * Returns the probability of index {@code i}; indices past the end have probability 0.
     *
     * @throws IllegalStateException if the probabilities are expected as a CPD argument
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public double getProb(int i) {
        requireFixedProbs();
        if (i < 0) {
            throw new IllegalArgumentException("Negative index passed to Categorical.getProb: " + i);
        }
        return i >= this.probs.length ? 0.0d : this.probs[i];
    }

    /**
     * Returns the natural log of the probability of index {@code i}; indices past the end
     * yield {@code Double.NEGATIVE_INFINITY}.
     *
     * @throws IllegalStateException if the probabilities are expected as a CPD argument
     * @throws IllegalArgumentException if {@code i} is negative
     */
    public double getLogProb(int i) {
        requireFixedProbs();
        if (i < 0) {
            throw new IllegalArgumentException("Negative index passed to Categorical.getLogProb: " + i);
        }
        return i >= this.probs.length ? Double.NEGATIVE_INFINITY : Math.log(this.probs[i]);
    }

    /**
     * Returns the probability of {@code obj}, looked up by its model object index.
     * Objects without an index (-1) or beyond the vector have probability 0.
     */
    @Override
    public double getProb(List list, Object obj) {
        ensureProbsInited(list);
        int objectIndex = Model.getObjectIndex(obj);
        if (objectIndex == -1 || objectIndex >= this.probs.length) {
            return 0.0d;
        }
        return this.probs[objectIndex];
    }

    /**
     * Draws an index using {@code Util.sampleWithProbs} over the fixed probability vector.
     *
     * @throws IllegalStateException if the probabilities are expected as a CPD argument
     */
    public int sampleVal() {
        requireFixedProbs();
        return Util.sampleWithProbs(this.probs);
    }

    /**
     * Samples an index and returns the corresponding guaranteed object of {@code type}.
     * If no such object exists, a warning is printed to stderr and {@code Model.NULL}
     * is returned.
     */
    @Override
    public Object sampleVal(List list, Type type) {
        ensureProbsInited(list);
        Object guaranteedObject = type.getGuaranteedObject(Util.sampleWithProbs(this.probs));
        if (guaranteedObject == null) {
            // Arrays.toString renders the same "[p0, p1, ...]" text as a List<Double> would.
            System.err.println("Warning: distribution does not sum to 1 over the guaranteed objects of type " + type + ": " + Arrays.toString(this.probs));
            guaranteedObject = Model.NULL;
        }
        return guaranteedObject;
    }

    /**
     * Validates the CPD arguments and, when probabilities are expected as an argument,
     * loads them from the first argument, which must be a single-column {@code Matrix}.
     *
     * @throws IllegalArgumentException if the arguments do not match how this
     *     distribution was constructed
     */
    private void ensureProbsInited(List list) {
        if (!this.expectProbsAsArg) {
            if (!list.isEmpty()) {
                throw new IllegalArgumentException("Categorical CPD expects no arguments (probabilities were specified as CPD parameters).");
            }
            return;
        }
        if (list.isEmpty()) {
            throw new IllegalArgumentException("Arguments to Categorical CPD should consist of a probability vector, since the probabilities were not specified as CPD parameters.");
        }
        Object arg = list.get(0);
        if (!(arg instanceof Matrix) || ((Matrix) arg).getColumnDimension() != 1) {
            throw new IllegalArgumentException("Argument to Categorical CPD should be a column vector of probabilities, not: " + arg);
        }
        Matrix column = (Matrix) arg;
        this.probs = new double[column.getRowDimension()];
        for (int i = 0; i < this.probs.length; i++) {
            this.probs[i] = column.get(i, 0);
        }
    }
}
