package blog.distrib;

import Jama.Matrix;
import blog.AbstractCondProbDistrib;
import blog.Type;
import java.util.List;

/* loaded from: input_file:blog/distrib/NatNumDistribWithTail.class */
/**
 * Distribution over the natural numbers built from an explicit "prefix" distribution over the
 * small values {@code 0 .. k-1} and a geometric "tail" over the values {@code k, k+1, ...}.
 *
 * <p>A Bernoulli mixing variable selects the component. When it is {@code true}, the value is
 * drawn from the prefix categorical distribution. When it is {@code false}, the value is
 * {@code k} plus a draw from the geometric distribution.
 *
 * <p>The distribution takes no conditioning arguments.
 */
public class NatNumDistribWithTail extends AbstractCondProbDistrib {
    /** Number of values covered by the explicit prefix distribution (the values 0 .. k-1). */
    int k;
    /** Mixing variable: {@code true} selects the prefix, {@code false} selects the tail. */
    Bernoulli mixDistrib;
    /** Explicit distribution over the small values 0 .. k-1. */
    Categorical prefixDistrib;
    /** Distribution of the offset above {@code k} for values in the tail. */
    Geometric geometric;

    /**
     * Creates the distribution from raw parameters.
     *
     * @param dArr probabilities of the values 0 .. dArr.length-1 in the prefix distribution
     * @param d probability of using the prefix distribution (the Bernoulli success probability)
     * @param d2 success probability of the geometric tail distribution
     */
    public NatNumDistribWithTail(double[] dArr, double d, double d2) {
        this.k = dArr.length;
        this.prefixDistrib = new Categorical(dArr);
        this.mixDistrib = new Bernoulli(d);
        this.geometric = new Geometric(d2);
    }

    /**
     * Creates the distribution from a BLOG parameter list.
     *
     * @param list exactly three parameters: a column-vector {@code Matrix} of prefix
     *     probabilities, a {@code Number} giving the probability of using the prefix, and a
     *     {@code Number} giving the geometric success probability
     * @throws IllegalArgumentException if the list has the wrong size or a parameter has the
     *     wrong type or shape
     */
    public NatNumDistribWithTail(List list) {
        if (list.size() != 3) {
            throw new IllegalArgumentException("NatNumDistribWithTail expects three parameters: the distribution over small numbers, the probability of using that distribution, and the success probability for a geometric distribution over larger numbers.");
        }
        if (!(list.get(0) instanceof Matrix) || ((Matrix) list.get(0)).getColumnDimension() != 1) {
            throw new IllegalArgumentException("First parameter to NatNumDistribWithTail should be a column vector.");
        }
        Matrix matrix = (Matrix) list.get(0);
        this.k = matrix.getRowDimension();
        this.prefixDistrib = new Categorical(matrix.getColumnPackedCopy());
        if (!(list.get(1) instanceof Number)) {
            throw new IllegalArgumentException("Second parameter to NatNumDistribWithTail should be a number (the probability of using the explicit distrib).");
        }
        this.mixDistrib = new Bernoulli(((Number) list.get(1)).doubleValue());
        if (!(list.get(2) instanceof Number)) {
            throw new IllegalArgumentException("Third parameter to NatNumDistribWithTail should be a number (the success prob for the geometric distrib).");
        }
        this.geometric = new Geometric(((Number) list.get(2)).doubleValue());
    }

    /**
     * Returns the probability of the value {@code i}.
     *
     * @param i the value whose probability is requested
     * @return 0 for negative {@code i}; otherwise the mixture probability of the prefix
     *     component (for {@code i < k}) or of the geometric tail at offset {@code i - k}
     */
    public double getProb(int i) {
        if (i < 0) {
            // Negative integers are outside the support of a distribution over naturals;
            // without this guard they would be handed to the prefix distribution as an index.
            return 0.0;
        }
        if (i < this.k) {
            return this.mixDistrib.getProb(true) * this.prefixDistrib.getProb(i);
        }
        return this.mixDistrib.getProb(false) * this.geometric.getProb(i - this.k);
    }

    /**
     * Returns the natural log of the probability of the value {@code i}.
     *
     * @param i the value whose log-probability is requested
     * @return {@code Double.NEGATIVE_INFINITY} for negative {@code i}; otherwise the sum of the
     *     mixing log-probability and the selected component's log-probability
     */
    public double getLogProb(int i) {
        if (i < 0) {
            // log(0): negative values have zero probability (see getProb(int)).
            return Double.NEGATIVE_INFINITY;
        }
        if (i < this.k) {
            return this.mixDistrib.getLogProb(true) + this.prefixDistrib.getLogProb(i);
        }
        return this.mixDistrib.getLogProb(false) + this.geometric.getLogProb(i - this.k);
    }

    /**
     * Returns the probability of {@code obj}, which must be an {@link Integer}.
     *
     * @param list conditioning arguments; must be empty
     * @param obj the value whose probability is requested
     * @throws IllegalArgumentException if {@code list} is non-empty, or if {@code obj} is null
     *     or not an {@code Integer}
     */
    @Override // blog.CondProbDistrib
    public double getProb(List list, Object obj) {
        if (!list.isEmpty()) {
            throw new IllegalArgumentException("NatNumDistribWithTail expects no arguments.");
        }
        if (obj instanceof Integer) {
            return getProb(((Integer) obj).intValue());
        }
        // Avoid a NullPointerException from obj.getClass() so a null value gets the same
        // IllegalArgumentException as any other non-Integer value.
        String actual = (obj == null) ? "null" : String.valueOf(obj.getClass());
        throw new IllegalArgumentException("NatNumDistribWithTail defines distribution over objects of class Integer, not " + actual);
    }

    /**
     * Draws a value.
     *
     * <p>The mixing variable selects the component. On {@code true}, the result is a draw from
     * the prefix distribution. On {@code false}, it is {@code k} plus a geometric draw.
     *
     * @return the sampled natural number
     */
    public int sampleVal() {
        return this.mixDistrib.sampleVal() ? this.prefixDistrib.sampleVal() : this.k + this.geometric.sampleVal();
    }

    /**
     * Draws a value as a boxed {@link Integer}.
     *
     * @param list conditioning arguments; must be empty
     * @param type the requested type (not consulted)
     * @return the sampled natural number
     * @throws IllegalArgumentException if {@code list} is non-empty
     */
    @Override // blog.CondProbDistrib
    public Object sampleVal(List list, Type type) {
        if (list.isEmpty()) {
            // Integer.valueOf replaces the deprecated Integer(int) constructor.
            return Integer.valueOf(sampleVal());
        }
        throw new IllegalArgumentException("NatNumDistribWithTail expects no arguments.");
    }
}
