package blog;

import blog.Function;
import common.Util;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Properties;
import java.util.Set;

/* loaded from: input_file:blog/UrnBallsSplitMergeNoIds.class */
/**
 * Split-merge proposer for the urn-and-balls model.
 *
 * <p>The model is expected to declare types {@code Ball} and {@code Draw} and random functions
 * {@code BallDrawn(Draw)}, {@code Color(Ball)} and {@code ObsColor(Draw)}. The first POP of
 * {@code Ball} is used as the ball population (presumably the only one -- TODO confirm).
 *
 * <p>Each call to {@link #proposeNextState} performs one move and returns the log of
 * (backward proposal probability / forward proposal probability):
 * <ul>
 *   <li>with probability {@code NUM_BALLS_MOVE_PROB}, add or remove one ball (symmetric, so the
 *       returned log ratio is 0);</li>
 *   <li>otherwise two distinct draws are chosen; if both came from the same ball, that ball is
 *       split by moving the second draw (and each other draw of the ball with probability 1/2)
 *       to a uniformly chosen unused ball; if they came from different balls, the second ball
 *       is merged into the first;</li>
 *   <li>if there is exactly one draw, its ball's color is resampled.</li>
 * </ul>
 * Whenever a ball's color is (re)sampled, it is drawn from weights proportional to the prior
 * probability of the color times the probability of the {@code ObsColor} value of every draw
 * currently assigned to that ball.
 */
public class UrnBallsSplitMergeNoIds implements Proposer {
    /** Probability that a proposal changes the number of balls instead of splitting/merging. */
    private static final double NUM_BALLS_MOVE_PROB = 0.25d;

    /** Log-probability of the fair coin flip that decides the side of each extra draw in a split. */
    private static final double LOG_HALF = Math.log(0.5d);

    private Model model;
    private Evidence evidence;
    private List queries;
    private Type ballType;
    private Type drawType;
    private Type colorType;
    private POP ballPOP;
    private RandomFunction fBallDrawn;
    private RandomFunction fColor;
    private RandomFunction fObsColor;

    /**
     * Looks up the types, POP and functions of the urn-and-balls model.
     *
     * @param model model declaring {@code Ball}, {@code Draw}, {@code BallDrawn}, {@code Color}
     *     and {@code ObsColor}
     * @param properties unused by this proposer
     */
    public UrnBallsSplitMergeNoIds(Model model, Properties properties) {
        this.model = model;
        this.ballType = model.getType("Ball");
        this.drawType = model.getType("Draw");
        this.ballPOP = (POP) this.ballType.getPOPs().iterator().next();
        this.fBallDrawn = (RandomFunction) model.getFunction(new Function.Sig("BallDrawn", this.drawType));
        this.fColor = (RandomFunction) model.getFunction(new Function.Sig("Color", this.ballType));
        this.fObsColor = (RandomFunction) model.getFunction(new Function.Sig("ObsColor", this.drawType));
        this.colorType = this.fColor.getRetType();
    }

    /**
     * Builds the initial world: evidence variables take their observed values, there are as many
     * balls as draws, the k-th draw (in guaranteed-object order) comes from ball k, and each such
     * ball's color is sampled.
     *
     * @param evidence evidence to fix in the world
     * @param list query list; stored but not used by this proposer
     * @return the freshly initialized world
     */
    @Override // blog.Proposer
    public PartialWorldDiff initialize(Evidence evidence, List list) {
        this.evidence = evidence;
        this.queries = list;
        PartialWorldDiff world = new PartialWorldDiff(new DefaultPartialWorld(Collections.EMPTY_SET));

        for (BasicVar evidenceVar : evidence.getEvidenceVars()) {
            world.setValue(evidenceVar, evidence.getObservedValue(evidenceVar));
        }

        List draws = this.drawType.getGuaranteedObjects();
        world.setValue(new NumberVar(this.ballPOP, Collections.EMPTY_LIST), Integer.valueOf(draws.size()));

        int ballIndex = 0;
        for (Object draw : draws) {
            ballIndex++; // ball indices are 1-based
            Object ball = NonGuaranteedObject.get(this.ballPOP, new Object[0], ballIndex);
            world.setValue(new RandFuncAppVar(this.fBallDrawn, Collections.singletonList(draw)), ball);
            if (Util.verbose()) {
                System.out.println("Color probability for ball " + ball + ":");
            }
            sampleColor(world, ball);
        }
        return world;
    }

    /**
     * Proposes the next state in place.
     *
     * @param partialWorldDiff world to modify
     * @return log(backward proposal probability) - log(forward proposal probability)
     */
    @Override // blog.Proposer
    public double proposeNextState(PartialWorldDiff partialWorldDiff) {
        if (Util.random() < NUM_BALLS_MOVE_PROB) {
            return doNumBallsMove(partialWorldDiff);
        }

        List chosenDraws = Util.sampleWithoutReplacement(this.drawType.getGuaranteedObjects(), 2);
        if (chosenDraws.isEmpty()) {
            if (Util.verbose()) {
                System.out.println("No draws, so can't change anything.");
            }
            return 0.0d;
        }

        if (chosenDraws.size() == 1) {
            // Nothing to split or merge; just resample the color of the single draw's ball.
            if (Util.verbose()) {
                System.out.println("Only one draw; resampling ball color.");
            }
            Object ball = this.fBallDrawn.getValueSingleArg(chosenDraws.get(0), partialWorldDiff);
            double logBackward = getColorSamplingLogProb(partialWorldDiff, ball);
            return logBackward - sampleColor(partialWorldDiff, ball);
        }

        return doSplitMerge(partialWorldDiff, chosenDraws.get(0), chosenDraws.get(1));
    }

    /** No acceptance statistics are tracked by this proposer. */
    @Override // blog.Proposer
    public void updateStats(boolean z) {
    }

    /** No statistics are tracked, so there is nothing to print. */
    @Override // blog.Proposer
    public void printStats() {
    }

    /**
     * Adds or removes one ball with equal probability. The move is symmetric, so the returned
     * log proposal ratio is always 0.
     */
    private double doNumBallsMove(PartialWorld partialWorld) {
        if (Util.verbose()) {
            System.out.println("Changing number of balls...");
        }
        adjustNumBalls(partialWorld, Util.random() < 0.5d ? -1 : 1);
        return 0.0d;
    }

    /**
     * Splits the common ball of the two draws if they share one, otherwise merges their balls.
     *
     * @param obj first chosen draw; it keeps its ball in either move
     * @param obj2 second chosen draw; it moves to a new ball (split) or its ball is absorbed (merge)
     * @return log(backward proposal probability) - log(forward proposal probability)
     */
    private double doSplitMerge(PartialWorld partialWorld, Object obj, Object obj2) {
        BasicVar keptDrawVar = new RandFuncAppVar(this.fBallDrawn, Collections.singletonList(obj));
        BasicVar otherDrawVar = new RandFuncAppVar(this.fBallDrawn, Collections.singletonList(obj2));
        Object ball = partialWorld.getValue(keptDrawVar);
        Object otherBall = partialWorld.getValue(otherDrawVar);

        boolean sameBall = (ball == null) ? otherBall == null : ball.equals(otherBall);
        if (sameBall) {
            return doSplit(partialWorld, keptDrawVar, otherDrawVar, ball);
        }
        return doMerge(partialWorld, ball, otherBall);
    }

    /**
     * Splits {@code ball}: {@code movedDrawVar} moves to a uniformly chosen unused ball, each
     * other draw of {@code ball} (except {@code keptDrawVar}) follows with probability 1/2, and
     * both balls' colors are resampled. Leaves the world unchanged and returns 0 if no ball is
     * unused.
     */
    private double doSplit(PartialWorld world, BasicVar keptDrawVar, BasicVar movedDrawVar, Object ball) {
        Set unusedBalls = getUnusedBalls(world);
        if (unusedBalls.isEmpty()) {
            return 0.0d;
        }
        if (Util.verbose()) {
            System.out.println("Splitting ball " + ball);
        }
        if (Util.verbose()) {
            System.out.println("Backward color sampling prob: ");
        }
        double logRatio = getColorSamplingLogProb(world, ball);

        // Snapshot the draws of this ball BEFORE any reassignment: the loop below changes their
        // values, and the split probability must count the draws the ball had before the split.
        List<BasicVar> drawVars = new ArrayList<BasicVar>(world.getVarsWithValue(ball));

        Object newBall = Util.uniformSample(unusedBalls);
        logRatio = logRatio + Math.log(unusedBalls.size()); // forward: uniform choice of new ball
        world.setValue(movedDrawVar, newBall);
        for (BasicVar drawVar : drawVars) {
            if (!drawVar.equals(keptDrawVar) && !drawVar.equals(movedDrawVar) && Util.random() < 0.5d) {
                world.setValue(drawVar, newBall);
            }
        }

        // Every draw other than the two chosen ones made an independent fair coin flip.
        double logSplitProb = (drawVars.size() - 2) * LOG_HALF;
        logRatio = logRatio - logSplitProb;
        if (Util.verbose()) {
            System.out.println("Split probability " + Math.exp(logSplitProb));
        }

        if (Util.verbose()) {
            System.out.println("Forward color sampling prob for " + ball);
        }
        logRatio = logRatio - sampleColor(world, ball);
        if (Util.verbose()) {
            System.out.println("Forward color sampling prob for " + newBall);
        }
        logRatio = logRatio - sampleColor(world, newBall);
        return logRatio;
    }

    /**
     * Merges {@code removedBall} into {@code keptBall}: every draw of either ball is assigned to
     * {@code keptBall}, the color of {@code removedBall} is unset, and the color of
     * {@code keptBall} is resampled. The backward probability is that of the corresponding split.
     */
    private double doMerge(PartialWorld world, Object keptBall, Object removedBall) {
        if (Util.verbose()) {
            System.out.println("Merging balls " + keptBall + " and " + removedBall);
        }
        if (Util.verbose()) {
            System.out.println("Backward color sampling prob for " + keptBall);
        }
        double logRatio = getColorSamplingLogProb(world, keptBall);
        if (Util.verbose()) {
            System.out.println("Backward color sampling prob for " + removedBall);
        }
        logRatio = logRatio + getColorSamplingLogProb(world, removedBall);

        // Snapshot both balls' draws before reassigning any of them.
        List<BasicVar> drawVars = new ArrayList<BasicVar>(world.getVarsWithValue(keptBall));
        drawVars.addAll(world.getVarsWithValue(removedBall));
        for (BasicVar drawVar : drawVars) {
            world.setValue(drawVar, keptBall);
        }
        world.setValue(new RandFuncAppVar(this.fColor, Collections.singletonList(removedBall)), null);

        // Backward split: uniform choice among the (now larger) set of unused balls, plus a fair
        // coin flip for each draw other than the two chosen ones.
        double logSplitProb = (drawVars.size() - 2) * LOG_HALF;
        logRatio = logRatio - Math.log(getUnusedBalls(world).size()) + logSplitProb;
        if (Util.verbose()) {
            System.out.println("Backward split sampling prob: " + Math.exp(logSplitProb));
        }
        if (Util.verbose()) {
            System.out.println("Forward color sampling prob: ");
        }
        return logRatio - sampleColor(world, keptBall);
    }

    /**
     * Samples a new color for ball {@code obj} from {@link #computeGibbsProbsForColor} and sets it.
     *
     * @return log-probability of the sampled color
     */
    private double sampleColor(PartialWorld partialWorld, Object obj) {
        RandFuncAppVar colorVar = new RandFuncAppVar(this.fColor, Collections.singletonList(obj));
        List colors = this.colorType.getGuaranteedObjects();
        double[] probs = computeGibbsProbsForColor(partialWorld, obj);
        int chosen = Util.sampleWithProbs(probs);
        partialWorld.setValue(colorVar, colors.get(chosen));
        if (Util.verbose()) {
            System.out.println("\t" + probs[chosen]);
        }
        return Math.log(probs[chosen]);
    }

    /**
     * Returns the log-probability that {@link #sampleColor} would choose the color ball
     * {@code obj} currently has, without changing the world.
     */
    private double getColorSamplingLogProb(PartialWorld partialWorld, Object obj) {
        RandFuncAppVar colorVar = new RandFuncAppVar(this.fColor, Collections.singletonList(obj));
        double[] probs = computeGibbsProbsForColor(partialWorld, obj);
        int current = this.colorType.getGuaranteedObjIndex(partialWorld.getValue(colorVar));
        if (Util.verbose()) {
            System.out.println("\t" + probs[current]);
        }
        return Math.log(probs[current]);
    }

    /**
     * Computes the normalized sampling distribution over colors for ball {@code obj}.
     *
     * <p>For each color c (in {@code colorType} guaranteed-object order), the weight is
     * P(Color(obj) = c) times the product of P(ObsColor(d)) over every draw d currently assigned
     * to {@code obj}. Weights are accumulated in log space and shifted by their maximum before
     * exponentiating, so many draws do not underflow every weight to 0.
     *
     * <p>The world's color of {@code obj} is temporarily overwritten and restored (even on
     * exception) before returning.
     *
     * @return probabilities indexed like {@code colorType.getGuaranteedObjects()}; NaN entries if
     *     every color has probability 0
     */
    private double[] computeGibbsProbsForColor(PartialWorld partialWorld, Object obj) {
        BasicVar colorVar = new RandFuncAppVar(this.fColor, Collections.singletonList(obj));
        Object originalColor = partialWorld.getValue(colorVar);
        List colors = this.colorType.getGuaranteedObjects();

        // ObsColor variables of the draws currently assigned to this ball.
        Set<BasicVar> obsVars = new HashSet<BasicVar>();
        for (Object drawVar : partialWorld.getVarsWithValue(obj)) {
            obsVars.add(new RandFuncAppVar(this.fObsColor, Collections.singletonList(((BasicVar) drawVar).args()[0])));
        }

        double[] weights = new double[colors.size()];
        double maxLogWeight = Double.NEGATIVE_INFINITY;
        try {
            for (int c = 0; c < colors.size(); c++) {
                partialWorld.setValue(colorVar, colors.get(c));
                double logWeight = partialWorld.getLogProbOfValue(colorVar);
                for (BasicVar obsVar : obsVars) {
                    logWeight += partialWorld.getLogProbOfValue(obsVar);
                }
                weights[c] = logWeight;
                maxLogWeight = Math.max(maxLogWeight, logWeight);
            }
        } finally {
            partialWorld.setValue(colorVar, originalColor);
        }

        // If every weight is -infinity there is no finite max to shift by; the result is NaN.
        double shift = Double.isInfinite(maxLogWeight) ? 0.0d : maxLogWeight;
        double total = 0.0d;
        for (int c = 0; c < weights.length; c++) {
            weights[c] = Math.exp(weights[c] - shift);
            total += weights[c];
        }
        for (int c = 0; c < weights.length; c++) {
            weights[c] /= total;
        }
        return weights;
    }

    /**
     * Changes the number of balls by {@code i}. When the count decreases, every variable whose
     * value is a removed ball is set to {@link Model#NULL}. A change that would make the count
     * negative leaves the world unchanged.
     *
     * <p>NOTE(review): the {@code Color} variable of a removed ball is left set; confirm this is
     * harmless for the world's probability computation.
     */
    private void adjustNumBalls(PartialWorld partialWorld, int i) {
        NumberVar numberVar = new NumberVar(this.ballPOP, Collections.EMPTY_LIST);
        int oldCount = ((Integer) partialWorld.getValue(numberVar)).intValue();
        int newCount = oldCount + i;
        if (newCount < 0) {
            // Staying put is a valid self-transition for this symmetric +/-1 move.
            if (Util.verbose()) {
                System.out.println("Cannot have fewer than zero balls; leaving number unchanged.");
            }
            return;
        }
        partialWorld.setValue(numberVar, Integer.valueOf(newCount));

        for (int index = oldCount; index > newCount; index--) {
            Object removedBall = NonGuaranteedObject.get(numberVar, index);
            // Snapshot: setting these variables changes the set returned by getVarsWithValue.
            List<BasicVar> referringVars = new ArrayList<BasicVar>(partialWorld.getVarsWithValue(removedBall));
            for (BasicVar var : referringVars) {
                if (Util.verbose()) {
                    System.out.println("Changing " + var + " from " + removedBall + " to null.");
                }
                partialWorld.setValue(var, Model.NULL);
            }
        }
    }

    /**
     * Returns the existing balls that are not the value of {@code BallDrawn} for any draw, in the
     * order the world's satisfier set lists them.
     */
    private Set getUnusedBalls(PartialWorld partialWorld) {
        Set unused = new LinkedHashSet(partialWorld.getSatisfiers(new NumberVar(this.ballPOP, Collections.EMPTY_LIST)));
        for (Object draw : this.drawType.getGuaranteedObjects()) {
            unused.remove(partialWorld.getValue(new RandFuncAppVar(this.fBallDrawn, Collections.singletonList(draw))));
        }
        return unused;
    }
}
