package com.github.keenon.loglinear.learning;

import com.github.keenon.loglinear.inference.CliqueTree;
import com.github.keenon.loglinear.learning.AbstractFunction;
import com.github.keenon.loglinear.model.ConcatVector;
import com.github.keenon.loglinear.model.GraphicalModel;
import java.util.Iterator;

/* loaded from: input_file:com/github/keenon/loglinear/learning/LogLikelihoodFunction.class */
/**
 * Computes, for a single {@link GraphicalModel}, a function value and a gradient with respect to
 * the supplied weight vector.
 *
 * <p>The value is the sum over all factors of the dot product between the weights and the feature
 * vector of the chosen assignment, minus the log of the partition function reported by
 * {@link CliqueTree}. The gradient is the sum of the chosen assignments' feature vectors minus the
 * feature vectors of every assignment weighted by its joint marginal.
 */
public class LogLikelihoodFunction extends AbstractFunction<GraphicalModel> {
    /**
     * Metadata key under which each variable stores its training value as a decimal integer string.
     * It is read only for variables whose marginal is not already deterministic.
     */
    public static final String VARIABLE_TRAINING_VALUE = "learning.LogLikelihoodFunction.VARIABLE_TRAINING_VALUE";

    /**
     * Evaluates the function value and gradient for one model at the given weights.
     *
     * @param graphicalModel the model whose factors and variable metadata are read
     * @param concatVector the weight vector at which to evaluate
     * @return a summary holding the value and a gradient with the same number of components as
     *     {@code concatVector}
     * @throws NumberFormatException if a non-deterministic variable has no parseable
     *     {@link #VARIABLE_TRAINING_VALUE} metadata entry
     */
    @Override // com.github.keenon.loglinear.learning.AbstractFunction
    public AbstractFunction.FunctionSummaryAtPoint getSummaryForInstance(GraphicalModel graphicalModel, ConcatVector concatVector) {
        double unnormalizedScore = 0.0d;
        ConcatVector gradient = new ConcatVector(concatVector.getNumberOfComponents());
        CliqueTree.MarginalResult marginalResult = new CliqueTree(graphicalModel, concatVector).calculateMarginals();

        for (GraphicalModel.Factor factor : graphicalModel.factors) {
            factor.featuresTable.cacheVectors();
        }

        // First pass: accumulate the features and score of the chosen assignment of each factor.
        for (GraphicalModel.Factor factor : graphicalModel.factors) {
            int[] assignment = new int[factor.neigborIndices.length];
            for (int i = 0; i < assignment.length; i++) {
                int variable = factor.neigborIndices[i];
                int deterministicValue = getDeterministicAssignment(marginalResult.marginals[variable]);
                if (deterministicValue != -1) {
                    // The marginal already pins this variable to one value; use it directly.
                    assignment[i] = deterministicValue;
                } else {
                    assignment[i] = Integer.parseInt(graphicalModel.getVariableMetaDataByReference(variable).get(VARIABLE_TRAINING_VALUE));
                }
            }
            ConcatVector features = factor.featuresTable.getAssignmentValue(assignment).get();
            gradient.addVectorInPlace(features, 1.0d);
            unnormalizedScore += features.dotProduct(concatVector);
        }

        double logLikelihood = unnormalizedScore - Math.log(marginalResult.partitionFunction);

        // Second pass: subtract the features of every assignment, weighted by its joint marginal.
        for (GraphicalModel.Factor factor : graphicalModel.factors) {
            Iterator<int[]> assignments = factor.featuresTable.fastPassByReferenceIterator();
            // As in the original, the first next() is called without checking hasNext().
            int[] assignment = assignments.next();
            while (true) {
                double probability = marginalResult.jointMarginals.get(factor).getAssignmentValue(assignment);
                if (probability > 0.0d) {
                    gradient.addVectorInPlace(factor.featuresTable.getAssignmentValue(assignment).get(), -probability);
                }
                if (!assignments.hasNext()) {
                    // Without this exit the loop would spin forever once the iterator is exhausted.
                    break;
                }
                // Re-assign rather than discard the result: presumably the iterator hands back the
                // same array mutated in place (per its name), but this is correct either way.
                assignment = assignments.next();
            }
        }

        for (GraphicalModel.Factor factor : graphicalModel.factors) {
            factor.featuresTable.releaseCache();
        }
        return new AbstractFunction.FunctionSummaryAtPoint(logLikelihood, gradient);
    }

    /**
     * Finds the single value a distribution puts all of its mass on, if there is one.
     *
     * @param dArr the per-value entries to inspect
     * @return the index of the only entry equal to exactly {@code 1.0} when every other entry is
     *     exactly {@code 0.0}; {@code -1} otherwise (including for an empty array)
     */
    private static int getDeterministicAssignment(double[] dArr) {
        int found = -1;
        for (int i = 0; i < dArr.length; i++) {
            if (dArr[i] == 1.0d) {
                if (found != -1) {
                    // A second 1.0 entry means no single value is singled out.
                    return -1;
                }
                found = i;
            } else if (dArr[i] != 0.0d) {
                // Any entry that is neither exactly 0.0 nor exactly 1.0 rules out determinism.
                return -1;
            }
        }
        return found;
    }
}
