package blog;

import common.Timer;
import common.Util;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Properties;
import java.util.Set;

/* loaded from: input_file:blog/MHSampler.class */
/**
 * Metropolis-Hastings sampler over partial worlds.
 *
 * <p>Each call to {@link #nextSample()} asks the configured {@link Proposer} to modify the
 * current world in place. It then computes the log acceptance ratio as the log probability
 * ratio between the proposed and saved worlds plus the log proposal ratio returned by the
 * proposer. Finally it either saves the proposed world (accept) or reverts it (reject).
 *
 * <p>The proposer class is read from the {@code proposerClass} property (default
 * {@code blog.GenericProposer}). It is instantiated reflectively through a public
 * {@code (Model, Properties)} constructor.
 */
public class MHSampler extends Sampler {
    /** Accumulates time spent computing acceptance ratios. */
    protected Timer acceptProbTimer;
    /** Accumulates time spent saving accepted worlds. */
    protected Timer worldUpdateTimer;
    /** Proposal distribution; built by {@link #constructProposer(Properties)}. */
    protected Proposer proposer;
    /** Current world: the saved state plus any pending (not yet accepted) changes. */
    protected PartialWorldDiff curWorld;
    /** Number of times {@link #initialize(Evidence, List)} has been called. */
    protected int numTrials;
    /** Number of proposals made across all trials. */
    protected int totalNumSamples;
    /** Number of proposals accepted across all trials. */
    protected int totalNumAccepted;
    /** Number of proposals made in the current trial. */
    protected int numSamplesThisTrial;
    /** Number of proposals accepted in the current trial. */
    protected int numAcceptedThisTrial;

    /**
     * Creates a sampler for {@code model} and builds its proposer from {@code properties}.
     *
     * @param model the model to sample from
     * @param properties configuration; {@code proposerClass} selects the proposer
     */
    public MHSampler(Model model, Properties properties) {
        super(model);
        this.acceptProbTimer = new Timer();
        this.worldUpdateTimer = new Timer();
        this.numTrials = 0;
        this.totalNumSamples = 0;
        this.totalNumAccepted = 0;
        this.numSamplesThisTrial = 0;
        this.numAcceptedThisTrial = 0;
        // NOTE(review): constructProposer is overridable and runs before any subclass
        // constructor body, so overrides must not depend on subclass fields.
        constructProposer(properties);
    }

    /**
     * Instantiates the proposer named by the {@code proposerClass} property.
     *
     * <p>Any reflective failure (class not found, no matching constructor, constructor
     * throwing) is passed to {@code Util.fatalError}.
     *
     * @param properties configuration passed through to the proposer's constructor
     */
    protected void constructProposer(Properties properties) {
        String className = properties.getProperty("proposerClass", "blog.GenericProposer");
        System.out.println("Constructing M-H proposer of class " + className);
        try {
            this.proposer = (Proposer) Class.forName(className)
                    .getConstructor(Model.class, Properties.class)
                    .newInstance(this.model, properties);
        } catch (Exception e) {
            Util.fatalError(e);
        }
    }

    /**
     * Starts a new trial by building, saving and validating an initial world.
     *
     * @param evidence the evidence to condition on
     * @param list passed unchanged to the superclass and the proposer
     * @throws IllegalStateException if the evidence is not true in the initial world
     */
    @Override // blog.Sampler
    public void initialize(Evidence evidence, List list) {
        super.initialize(evidence, list);
        this.numTrials++;
        this.numSamplesThisTrial = 0;
        this.numAcceptedThisTrial = 0;

        if (Util.verbose()) {
            System.out.println("Creating initial world...");
        }
        this.curWorld = this.proposer.initialize(evidence, list);

        if (Util.verbose()) {
            System.out.println("Saving initial world...");
        }
        Timer saveTimer = new Timer();
        saveTimer.start();
        this.curWorld.save();
        if (Util.verbose()) {
            System.out.println("Saving initial world took " + saveTimer.elapsedTime() + " s");
            System.out.println("Validating initial world...");
        }

        if (!validateIdentifiers(this.curWorld)) {
            Util.fatalError("Fatal identifier errors in initial world.", false);
        }
        if (!evidence.isTrue(this.curWorld)) {
            throw new IllegalStateException("Error: evidence is not true in initial world.");
        }
    }

    /**
     * Replaces the current world.
     *
     * <p>A {@link PartialWorldDiff} is used directly. Any other world is wrapped in a new
     * {@code PartialWorldDiff}.
     *
     * @param partialWorld the world to continue sampling from
     */
    public void setWorld(PartialWorld partialWorld) {
        if (partialWorld instanceof PartialWorldDiff) {
            this.curWorld = (PartialWorldDiff) partialWorld;
        } else {
            this.curWorld = new PartialWorldDiff(partialWorld);
        }
    }

    /**
     * Performs one Metropolis-Hastings step on the current world.
     *
     * <p>If the proposal is accepted, the changes are saved. Otherwise they are reverted.
     * The per-trial and total counters and the proposer's stats are updated either way.
     */
    @Override // blog.Sampler
    public void nextSample() {
        this.totalNumSamples++;
        this.numSamplesThisTrial++;
        if (Util.verbose()) {
            System.out.println("Proposing world...");
        }
        // The proposer mutates curWorld in place. It returns the log proposal ratio that is
        // added to the log probability ratio below.
        double logProposalRatio = this.proposer.proposeNextState(this.curWorld);
        if (Util.verbose()) {
            System.out.println();
            System.out.println("\tlog proposal ratio: " + logProposalRatio);
        }
        if (!validateIdentifiers(this.curWorld)) {
            Util.fatalError("Fatal identifier errors in proposed world.", false);
        }

        this.acceptProbTimer.start();
        double logProbRatio = computeLogProbRatio(this.curWorld.getSaved(), this.curWorld);
        if (Util.verbose()) {
            System.out.println("\tlog probability ratio: " + logProbRatio);
        }
        double logAcceptRatio = logProbRatio + logProposalRatio;
        if (Util.verbose()) {
            System.out.println("\tlog acceptance ratio: " + logAcceptRatio);
        }
        this.acceptProbTimer.stop();

        // Accept with probability min(1, exp(logAcceptRatio)). A non-negative ratio is always
        // accepted without drawing a random number.
        // NOTE(review): a NaN ratio (e.g. -Infinity + Infinity) fails the first test and is
        // therefore accepted; confirm this cannot arise with the proposers in use.
        if (logAcceptRatio < 0.0d && Util.random() >= Math.exp(logAcceptRatio)) {
            this.curWorld.revert();
            if (Util.verbose()) {
                System.out.println("\trejected");
            }
            this.proposer.updateStats(false);
            return;
        }

        this.worldUpdateTimer.start();
        this.curWorld.save();
        this.worldUpdateTimer.stop();
        if (Util.verbose()) {
            System.out.println("\taccepted");
        }
        this.totalNumAccepted++;
        this.numAcceptedThisTrial++;
        this.proposer.updateStats(true);
    }

    /**
     * Performs one M-H step starting from {@code partialWorld} and returns the resulting world.
     *
     * <p>The sampler's own current world is restored afterwards, even if the step throws.
     *
     * @param partialWorld the world to step from
     * @return the saved world after the step
     */
    public PartialWorld nextSample(PartialWorld partialWorld) {
        PartialWorldDiff previousWorld = this.curWorld;
        try {
            setWorld(partialWorld);
            nextSample();
            return getLatestWorld();
        } finally {
            this.curWorld = previousWorld;
        }
    }

    /**
     * Returns the most recently saved world.
     *
     * @return the saved state of the current world
     */
    @Override // blog.Sampler
    public PartialWorld getLatestWorld() {
        return this.curWorld.getSaved();
    }

    /**
     * Reports identifier problems in {@code world} to {@code System.err}.
     *
     * @param world the world to check
     * @return {@code true} if there are no newly overloaded number variables and no newly
     *     floating identifiers
     */
    private boolean validateIdentifiers(PartialWorldDiff world) {
        Set<?> overloadedVars = world.getNewlyOverloadedNumberVars();
        for (Object numberVar : overloadedVars) {
            System.err.println("Error: Number variable " + numberVar
                    + " is satisfied by too many identifiers.");
        }
        Set<?> floatingIds = world.getNewlyFloatingIds();
        for (Object id : floatingIds) {
            System.err.println("Error: Identifier " + id
                    + " is not the value of any basic variable.");
        }
        return overloadedVars.isEmpty() && floatingIds.isEmpty();
    }

    /**
     * Computes the log of P(newWorld) / P(oldWorld) over the variables whose probabilities
     * changed, including the multiplier term from {@link #computeLogMultRatio}.
     *
     * <p>An evidence variable whose value differs from its observed value, or is absent, is
     * treated as having log probability -Infinity in that world.
     *
     * @param oldWorld the saved world
     * @param newWorld the proposed world
     * @return the log probability ratio
     */
    private double computeLogProbRatio(PartialWorld oldWorld, PartialWorldDiff newWorld) {
        double logRatio = computeLogMultRatio(oldWorld, newWorld);
        for (BayesNetVar var : newWorld.getVarsWithChangedProbs()) {
            double oldLogProb = oldWorld.getLogProbOfValue(var);
            if (Util.verbose() && oldLogProb == Double.NEGATIVE_INFINITY) {
                System.out.println("Zero probability in old world: " + var + " = "
                        + oldWorld.getValue(var));
            }
            double newLogProb = newWorld.getLogProbOfValue(var);
            if (Util.verbose() && newLogProb == Double.NEGATIVE_INFINITY) {
                System.out.println("Zero probability in proposed world: " + var + " = "
                        + newWorld.getValue(var));
            }

            if (this.evidence.getEvidenceVars().contains(var)) {
                Object observedValue = this.evidence.getObservedValue(var);
                // Objects.equals: a missing (null) value counts as violating the evidence
                // instead of throwing NullPointerException.
                if (!Objects.equals(oldWorld.getValue(var), observedValue)) {
                    oldLogProb = Double.NEGATIVE_INFINITY;
                }
                if (!Objects.equals(newWorld.getValue(var), observedValue)) {
                    newLogProb = Double.NEGATIVE_INFINITY;
                }
            }

            if (Util.verbose()) {
                System.out.println("Variable " + var + " going from log prob " + oldLogProb
                        + " to log prob " + newLogProb);
            }
            // Skipping equal terms also avoids -Infinity - (-Infinity) = NaN when the
            // variable has zero probability in both worlds.
            if (oldLogProb != newLogProb) {
                logRatio = (logRatio - oldLogProb) + newLogProb;
            }
        }
        return logRatio;
    }

    /**
     * Computes the log ratio of the identifier multipliers for number variables whose
     * multipliers changed between {@code oldWorld} and {@code newWorld}.
     *
     * <p>For each such variable, with S satisfiers and I asserted identifiers, the multiplier
     * is taken as S! / (S - I)!, assuming {@code Util.logPartialFactorial(n, k)} returns
     * log(n! / (n - k)!) (TODO confirm).
     *
     * @param oldWorld the saved world
     * @param newWorld the proposed world
     * @return the summed log multiplier ratio
     */
    private double computeLogMultRatio(PartialWorld oldWorld, PartialWorldDiff newWorld) {
        double logRatio = 0.0d;
        for (NumberVar numberVar : newWorld.getVarsWithChangedMultipliers()) {
            int oldNumSat = oldWorld.getValue(numberVar) == null
                    ? 0 : oldWorld.getSatisfiers(numberVar).size();
            int oldNumIds = oldWorld.getAssertedIdsForPOPApp(numberVar).size();
            int newNumSat = newWorld.getValue(numberVar) == null
                    ? 0 : newWorld.getSatisfiers(numberVar).size();
            int newNumIds = newWorld.getAssertedIdsForPOPApp(numberVar).size();
            if (Util.verbose()) {
                System.out.println("For " + numberVar + ":");
                System.out.println("\tcurrently " + oldNumSat + " satisfiers, " + oldNumIds + " IDs");
                System.out.println("\tproposed " + newNumSat + " satisfiers, " + newNumIds + " IDs");
            }

            int newNonIdSat = newNumSat - newNumIds;
            int oldNonIdSat = oldNumSat - oldNumIds;
            if (newNonIdSat >= oldNumSat || oldNonIdSat >= newNumSat) {
                // Little cancellation available: compute both partial factorials directly.
                logRatio = (logRatio + Util.logPartialFactorial(newNumSat, newNumIds))
                        - Util.logPartialFactorial(oldNumSat, oldNumIds);
            } else {
                // Regroup the ratio as [newNumSat! / oldNumSat!] * [oldNonIdSat! / newNonIdSat!]
                // and compute only the non-cancelling factors of each bracket.
                if (newNumSat > oldNumSat) {
                    logRatio += Util.logPartialFactorial(newNumSat, newNumSat - oldNumSat);
                } else if (oldNumSat > newNumSat) {
                    logRatio -= Util.logPartialFactorial(oldNumSat, oldNumSat - newNumSat);
                }
                if (newNonIdSat < oldNonIdSat) {
                    logRatio += Util.logPartialFactorial(oldNonIdSat, oldNonIdSat - newNonIdSat);
                } else if (oldNonIdSat < newNonIdSat) {
                    logRatio -= Util.logPartialFactorial(newNonIdSat, newNonIdSat - oldNonIdSat);
                }
            }
        }
        return logRatio;
    }

    /**
     * Prints acceptance rates and timing, followed by the proposer's own stats.
     */
    @Override // blog.Sampler
    public void printStats() {
        System.out.println("======== MH Trial Stats ========");
        if (this.totalNumSamples > 0) {
            if (this.numSamplesThisTrial > 0) {
                // Cast before dividing: int / int would truncate the fraction to 0 or 1.
                double trialFraction =
                        (double) this.numAcceptedThisTrial / this.numSamplesThisTrial;
                System.out.println("Fraction of proposals accepted (this trial): " + trialFraction);
            }
            double overallFraction = (double) this.totalNumAccepted / this.totalNumSamples;
            System.out.println("Fraction of proposals accepted (running avg, all trials): "
                    + overallFraction);
            System.out.println("Time spent computing acceptance probs: "
                    + this.acceptProbTimer.elapsedTime() + " s");
            System.out.println("Time spent updating world: "
                    + this.worldUpdateTimer.elapsedTime() + " s");
        } else {
            System.out.println("No samples yet.");
        }
        this.proposer.printStats();
    }
}
