package com.github.keenon.loglinear.inference;

import com.github.keenon.loglinear.model.ConcatVector;
import com.github.keenon.loglinear.model.GraphicalModel;
import java.util.IdentityHashMap;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/keenon/loglinear/inference/CliqueTree.class */
/**
 * Inference over a {@link GraphicalModel} under a fixed weight vector, via message passing
 * (presumably over a clique tree built from the model's factors — the construction lives in
 * {@link #messagePassing}, which is not recoverable in this decompiled copy).
 *
 * <p>Variables can be clamped to an observed value by storing the integer value (as a string)
 * under {@link #VARIABLE_OBSERVED_VALUE} in the variable's metadata.
 *
 * <p>NOTE(review): {@link #messagePassing} failed to decompile and currently throws
 * {@link UnsupportedOperationException}, so every public inference method on this class throws
 * until that body is restored from the original source.
 */
public class CliqueTree {
    private static final Logger log = LoggerFactory.getLogger(CliqueTree.class);

    GraphicalModel model;
    ConcatVector weights;

    /**
     * Variable-metadata key whose value, parsed with {@link Integer#parseInt}, is treated as the
     * observed assignment of that variable.
     */
    public static final String VARIABLE_OBSERVED_VALUE = "inference.CliqueTree.VARIABLE_OBSERVED_VALUE";

    private static final boolean CACHE_MESSAGES = true;

    // Cache keyed by factor identity (not equals/hashCode), see IdentityHashMap.
    private IdentityHashMap<GraphicalModel.Factor, CachedFactorWithObservations> cachedFactors = new IdentityHashMap<>();
    private TableFactor[] cachedCliqueList;
    private TableFactor[][] cachedMessages;
    private boolean[][] cachedBackwardPassedMessages;

    /**
     * A cached table factor together with the observations it was built under.
     *
     * <p>NOTE: declared private in the original source; the decompiler widened it to public.
     */
    public static class CachedFactorWithObservations {
        TableFactor cachedFactor;
        int[] observations;
        boolean impossibleObservation;

        private CachedFactorWithObservations() {
        }
    }

    /** Result of a message-passing run. */
    public static class MarginalResult {
        /** Per-variable marginals, indexed by variable; entries may be {@code null}. */
        public double[][] marginals;
        public double partitionFunction;
        /** Joint marginals keyed by factor. */
        public Map<GraphicalModel.Factor, TableFactor> jointMarginals;

        public MarginalResult(double[][] marginals, double partitionFunction,
                              Map<GraphicalModel.Factor, TableFactor> jointMarginals) {
            this.marginals = marginals;
            this.partitionFunction = partitionFunction;
            this.jointMarginals = jointMarginals;
        }
    }

    /**
     * How a variable is eliminated from a message: summed out or maxed out.
     *
     * <p>NOTE: declared private in the original source; the decompiler widened it to public.
     */
    public enum MarginalizationMethod {
        SUM,
        MAX
    }

    /**
     * @param graphicalModel the model to run inference on (held by reference)
     * @param concatVector   the weights; a deep clone is stored, so later mutation of the
     *                       argument does not affect this tree
     */
    public CliqueTree(GraphicalModel graphicalModel, ConcatVector concatVector) {
        this.model = graphicalModel;
        this.weights = concatVector.deepClone();
    }

    /** Runs sum-product message passing, including joint factor marginals. */
    public MarginalResult calculateMarginals() {
        return messagePassing(MarginalizationMethod.SUM, true);
    }

    /** Runs sum-product message passing and returns only the per-variable marginals. */
    public double[][] calculateMarginalsJustSingletons() {
        return messagePassing(MarginalizationMethod.SUM, false).marginals;
    }

    /**
     * Computes a MAP assignment by max-product message passing.
     *
     * @return one value per variable: the index of the largest max-marginal entry (first index
     *         on ties, 0 if the variable has no marginal), overridden by the observed value when
     *         the variable's metadata contains {@link #VARIABLE_OBSERVED_VALUE}
     */
    public int[] calculateMAP() {
        double[][] maxMarginals = messagePassing(MarginalizationMethod.MAX, false).marginals;
        int[] assignment = new int[maxMarginals.length];
        for (int variable = 0; variable < assignment.length; variable++) {
            if (maxMarginals[variable] != null) {
                assignment[variable] = argmax(maxMarginals[variable]);
            }
            // Observations always win over whatever inference produced.
            if (this.model.getVariableMetaDataByReference(variable).containsKey(VARIABLE_OBSERVED_VALUE)) {
                assignment[variable] = Integer.parseInt(
                        this.model.getVariableMetaDataByReference(variable).get(VARIABLE_OBSERVED_VALUE));
            }
        }
        return assignment;
    }

    /** Index of the first maximal entry of {@code values}; 0 for an empty array. */
    private static int argmax(double[] values) {
        int best = 0;
        for (int j = 1; j < values.length; j++) {
            if (values[j] > values[best]) {
                best = j;
            }
        }
        return best;
    }

    /*
     * NOTE(review): the body below could not be decompiled, so this method always throws.
     * The JADX warnings that follow are the only fragments of the original control flow that
     * survived; they are kept to help reconstruct the method from the original source.
     */
    /* JADX WARN: Can't fix incorrect switch cases order, some code will duplicate */
    /* JADX WARN: Code restructure failed: missing block: B:487:0x0dfd, code lost:
    
        r40 = 0;
     */
    /* JADX WARN: Code restructure failed: missing block: B:489:0x0e08, code lost:
    
        if (r40 >= r36.neighborIndices.length) goto L611;
     */
    /* JADX WARN: Code restructure failed: missing block: B:490:0x0e0b, code lost:
    
        r0 = r36.neighborIndices[r40];
     */
    /* JADX WARN: Code restructure failed: missing block: B:491:0x0e1a, code lost:
    
        if (r0[r0] != 0) goto L627;
     */
    /* JADX WARN: Code restructure failed: missing block: B:492:0x0e1d, code lost:
    
        r0[r0] = r39[r40];
     */
    /* JADX WARN: Code restructure failed: missing block: B:494:0x0e27, code lost:
    
        r40 = r40 + 1;
     */
    /* JADX WARN: Code restructure failed: missing block: B:579:0x0ef1, code lost:
    
        r37 = 0;
     */
    /* JADX WARN: Code restructure failed: missing block: B:581:0x0efc, code lost:
    
        if (r37 >= r35.neighborIndices.length) goto L639;
     */
    /* JADX WARN: Code restructure failed: missing block: B:582:0x0eff, code lost:
    
        r0 = r35.neighborIndices[r37];
     */
    /* JADX WARN: Code restructure failed: missing block: B:583:0x0f0e, code lost:
    
        if (r0[r0] != 0) goto L648;
     */
    /* JADX WARN: Code restructure failed: missing block: B:584:0x0f11, code lost:
    
        r0[r0] = r36[r37];
     */
    /* JADX WARN: Code restructure failed: missing block: B:586:0x0f1b, code lost:
    
        r37 = r37 + 1;
     */
    /* JADX WARN: Code restructure failed: missing block: B:624:0x0fe1, code lost:
    
        r37 = 0;
     */
    /* JADX WARN: Code restructure failed: missing block: B:626:0x0fec, code lost:
    
        if (r37 >= r35.neighborIndices.length) goto L651;
     */
    /* JADX WARN: Code restructure failed: missing block: B:627:0x0fef, code lost:
    
        r0 = r35.neighborIndices[r37];
     */
    /* JADX WARN: Code restructure failed: missing block: B:628:0x0ffe, code lost:
    
        if (r0[r0] != 0) goto L660;
     */
    /* JADX WARN: Code restructure failed: missing block: B:629:0x1001, code lost:
    
        r0[r0] = r36[r37];
     */
    /* JADX WARN: Code restructure failed: missing block: B:631:0x100b, code lost:
    
        r37 = r37 + 1;
     */
    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v539, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v64, types: [double[], double[][]] */
    /*
        Code decompiled incorrectly, please refer to instructions dump.
        To view partially-correct add '--show-bad-code' argument
    */
    private MarginalResult messagePassing(MarginalizationMethod method, boolean computeJointMarginals) {
        /*
            Method dump skipped, instructions count: 4297
            To view this dump add '--comments-level debug' option
        */
        throw new UnsupportedOperationException("Method not decompiled: com.github.keenon.loglinear.inference.CliqueTree.messagePassing(com.github.keenon.loglinear.inference.CliqueTree$MarginalizationMethod, boolean):com.github.keenon.loglinear.inference.CliqueTree$MarginalResult");
    }

    /**
     * Collects the observed value of each variable the factor touches.
     *
     * @return an array parallel to {@code factor.neigborIndices}: the parsed observed value, or
     *         -1 when the variable has no {@link #VARIABLE_OBSERVED_VALUE} metadata
     */
    private int[] getObservedAssignments(GraphicalModel.Factor factor) {
        // Field name spelling ("neigborIndices") comes from GraphicalModel.Factor.
        int[] observed = new int[factor.neigborIndices.length];
        for (int i = 0; i < observed.length; i++) {
            int variable = factor.neigborIndices[i];
            if (this.model.getVariableMetaDataByReference(variable).containsKey(VARIABLE_OBSERVED_VALUE)) {
                observed[i] = Integer.parseInt(
                        this.model.getVariableMetaDataByReference(variable).get(VARIABLE_OBSERVED_VALUE));
            } else {
                observed[i] = -1;
            }
        }
        return observed;
    }

    /**
     * Eliminates every variable of {@code tableFactor} that is not listed in {@code iArr}.
     *
     * @param tableFactor           the message to reduce
     * @param iArr                  variables to keep
     * @param marginalizationMethod SUM to sum variables out, MAX to max them out
     * @return the reduced factor (the input itself if nothing needed eliminating)
     */
    private TableFactor marginalizeMessage(TableFactor tableFactor, int[] iArr, MarginalizationMethod marginalizationMethod) {
        TableFactor result = tableFactor;
        // Iterate the original domain; result's domain shrinks as variables are eliminated.
        for (int variable : tableFactor.neighborIndices) {
            if (containsVariable(iArr, variable)) {
                continue;
            }
            switch (marginalizationMethod) {
                case SUM:
                    result = result.sumOut(variable);
                    break;
                case MAX:
                    result = result.maxOut(variable);
                    break;
                default:
                    throw new IllegalStateException("Unknown marginalization method: " + marginalizationMethod);
            }
        }
        return result;
    }

    /** Whether the two factors share at least one variable. */
    private boolean domainsOverlap(TableFactor tableFactor, TableFactor tableFactor2) {
        for (int variable : tableFactor.neighborIndices) {
            if (containsVariable(tableFactor2.neighborIndices, variable)) {
                return true;
            }
        }
        return false;
    }

    /** Linear membership test over a small array of variable indices. */
    private static boolean containsVariable(int[] variables, int variable) {
        for (int candidate : variables) {
            if (candidate == variable) {
                return true;
            }
        }
        return false;
    }

    /** Returns {@code true} iff Java assertions are enabled for this class (e.g. {@code -ea}). */
    private boolean assertsEnabled() {
        boolean enabled = false;
        // The side-effecting assignment only runs when assertions are enabled.
        assert enabled = true;
        return enabled;
    }
}
