package com.linkedin.dagli.liblinear;

import com.jeffreypasternack.liblinear.Feature;
import com.jeffreypasternack.liblinear.FeatureNode;
import com.jeffreypasternack.liblinear.Linear;
import com.jeffreypasternack.liblinear.Model;
import com.jeffreypasternack.liblinear.Parameter;
import com.jeffreypasternack.liblinear.Problem;
import com.jeffreypasternack.liblinear.SolverType;
import com.linkedin.dagli.annotation.equality.DeepArrayValueEquality;
import com.linkedin.dagli.annotation.equality.ValueEquality;
import com.linkedin.dagli.input.DenseFeatureVectorInput;
import com.linkedin.dagli.liblinear.AbstractLiblinearTransformer;
import com.linkedin.dagli.math.distribution.ArrayDiscreteDistribution;
import com.linkedin.dagli.math.distribution.BinaryDistribution;
import com.linkedin.dagli.math.distribution.DiscreteDistribution;
import com.linkedin.dagli.math.vector.DenseDoubleArrayVector;
import com.linkedin.dagli.math.vector.DenseVector;
import com.linkedin.dagli.math.vector.Vector;
import com.linkedin.dagli.preparer.AbstractStreamPreparer3;
import com.linkedin.dagli.preparer.PreparerContext;
import com.linkedin.dagli.preparer.PreparerResult;
import com.linkedin.dagli.producer.Producer;
import com.linkedin.dagli.util.invariant.Arguments;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.NoSuchElementException;

/**
 * Trains a Liblinear logistic regression classifier that predicts a {@link DiscreteDistribution} over labels.
 *
 * <p>Dense feature vectors are converted to Liblinear's sparse, 1-based representation: element {@code i} of the
 * Dagli vector becomes Liblinear feature {@code i + 1}. When the configured bias is non-negative, an additional bias
 * feature (whose value is the bias) is appended right after the largest feature index observed during training.</p>
 *
 * <p>Only logistic regression solvers may be used (see {@link #withSolverType(SolverType)}) because predictions are
 * made with {@link Linear#predictProbability}.</p>
 *
 * @param <L> the type of the label
 */
@ValueEquality
public class LiblinearClassification<L> extends
        AbstractLiblinearTransformer<L, DiscreteDistribution<L>, LiblinearClassification.Prepared<L>, LiblinearClassification<L>> {
    private static final long serialVersionUID = 1L;

    /**
     * The trained classifier: a Liblinear {@link Model} plus the mapping from Liblinear class IDs back to labels.
     *
     * @param <L> the type of the label
     */
    @ValueEquality
    public static class Prepared<L> extends AbstractLiblinearTransformer.Prepared<L, DiscreteDistribution<L>, Prepared<L>> {
        private static final long serialVersionUID = 1L;

        // _labels[k] is the label whose training class ID was k (IDs were assigned in order of first appearance).
        // NOTE(review): label indices are passed straight to Liblinear as class indices, which assumes Liblinear also
        // orders its classes by first appearance -- confirm against the Liblinear version in use.
        @DeepArrayValueEquality
        private final L[] _labels;

        // True iff there are one or two labels and all of them are Booleans; predictions are then BinaryDistributions.
        private final boolean _isBinary;

        /**
         * Creates the prepared transformer.
         *
         * @param bias the bias feature value used during training (negative if no bias feature was used)
         * @param model the trained Liblinear model
         * @param labelIDMap maps each label to the dense class ID (0, 1, 2, ...) it was trained with
         * @param featureCount the largest 1-based Liblinear feature index seen during training (excluding the bias)
         * @throws IllegalStateException if the number of labels disagrees with the model's number of classes
         */
        @SuppressWarnings("unchecked")
        Prepared(double bias, Model model, Object2IntOpenHashMap<L> labelIDMap, int featureCount) {
            super(bias, model, featureCount);

            _labels = (L[]) new Object[labelIDMap.size()];
            for (L label : labelIDMap.keySet()) {
                _labels[labelIDMap.getInt(label)] = label;
            }
            _isBinary = areBooleanLabels(_labels);

            int liblinearClassCount = _model.getNrClass();
            boolean consistent =
                _labels.length > 2 ? liblinearClassCount == _labels.length : liblinearClassCount <= 2;
            if (!consistent) {
                throw new IllegalStateException(
                    "Dagli wrapper and Liblinear disagree on the number of classes in the model: Dagli thinks there are "
                        + _labels.length + " and Liblinear thinks there are " + liblinearClassCount);
            }
        }

        // True when there are one or two labels and every one of them is a Boolean.
        private static boolean areBooleanLabels(Object[] labels) {
            if (labels.length < 1 || labels.length > 2) {
                return false;
            }
            for (Object label : labels) {
                if (!(label instanceof Boolean)) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Gets the weight vector learned for the given label.
         *
         * @param label a label returned by {@link #getLabels()}
         * @return the per-feature weights for that label
         * @throws NoSuchElementException if the label is not predicted by this model
         */
        public Vector getWeightsForLabel(L label) {
            return getWeightsForLabelIndex(indexOfLabel(label));
        }

        /**
         * Gets the bias contribution learned for the given label.
         *
         * @param label a label returned by {@link #getLabels()}
         * @return the bias value multiplied by the label's bias weight, or 0 if no positive bias was used
         * @throws NoSuchElementException if the label is not predicted by this model
         */
        public double getBiasForLabel(L label) {
            return getBiasForLabelIndex(indexOfLabel(label));
        }

        // Position of the label within getLabels(), or NoSuchElementException if it is absent.
        private int indexOfLabel(L label) {
            int index = getLabels().indexOf(label);
            if (index < 0) {
                throw new NoSuchElementException("The specified label " + label + " is not predicted by this model");
            }
            return index;
        }

        /**
         * Gets the labels this model can predict, in label-index order.
         *
         * <p>If training saw only a single Boolean label, its complement is appended as the second label so that the
         * list always describes both outcomes of the binary problem.</p>
         *
         * @return an unmodifiable list of labels
         */
        @SuppressWarnings("unchecked")
        public List<L> getLabels() {
            if (_labels.length == 1 && _labels[0] instanceof Boolean) {
                List<L> labels = new ArrayList<>(2);
                labels.add(_labels[0]);
                labels.add((L) Boolean.valueOf(!((Boolean) _labels[0])));
                return Collections.unmodifiableList(labels);
            }
            // wrap so callers cannot write through the Arrays.asList view into the model's own label array
            return Collections.unmodifiableList(Arrays.asList(_labels));
        }

        // Coefficient of a 1-based Liblinear feature index for the given label index.
        private double getWeightForLabel(int featureIndex, int labelIndex) {
            // A model trained on a single class stores one weight vector; the implied second label of the binary
            // problem is represented by the negation of those weights.
            if (labelIndex == 1 && _model.getNrClass() == 1) {
                return -_model.getDecfunCoef(featureIndex, 0);
            }
            return _model.getDecfunCoef(featureIndex, labelIndex);
        }

        // Validates a label index: models with at most two labels accept indices 0 and 1; others accept
        // [0, number of labels).
        private void checkIndex(int labelIndex) {
            if (_labels.length <= 2) {
                if (labelIndex < 0 || labelIndex >= 2) {
                    throw new IndexOutOfBoundsException(
                        "Requested weights for label index " + labelIndex + " in a binary model");
                }
            } else if (labelIndex < 0 || labelIndex >= _labels.length) {
                throw new IndexOutOfBoundsException("Requested index for non-existant label: " + labelIndex);
            }
        }

        /**
         * Gets the weight vector learned for the label at the given index.
         *
         * @param labelIndex the label's position within {@link #getLabels()}
         * @return a dense vector whose element {@code i} is the weight of Dagli feature {@code i}
         * @throws IndexOutOfBoundsException if the index does not correspond to a label
         */
        public DenseVector getWeightsForLabelIndex(int labelIndex) {
            checkIndex(labelIndex);
            double[] weights = new double[_featureCount];
            for (int i = 0; i < _featureCount; i++) {
                weights[i] = getWeightForLabel(i + 1, labelIndex); // Liblinear feature indices are 1-based
            }
            return DenseDoubleArrayVector.wrap(weights);
        }

        /**
         * Gets the bias contribution for the label at the given index.
         *
         * @param labelIndex the label's position within {@link #getLabels()}
         * @return the bias value times the learned bias weight, or 0 if the bias is not positive
         * @throws IndexOutOfBoundsException if the index does not correspond to a label
         */
        public double getBiasForLabelIndex(int labelIndex) {
            checkIndex(labelIndex);
            if (_bias <= 0.0d) {
                return 0.0d;
            }
            // the bias feature occupies the index immediately after the last real feature
            return _bias * getWeightForLabel(_featureCount + 1, labelIndex);
        }

        @Override
        public Model getModel() {
            return _model;
        }

        /**
         * Predicts the label distribution for one example.
         *
         * <p>The weight and label inputs are not used for inference. Features whose 1-based index exceeds the largest
         * index seen during training have no learned weight and are skipped.</p>
         *
         * @param weight the example weight (ignored)
         * @param label the example label (ignored)
         * @param features the example's features
         * @return a {@link BinaryDistribution} for Boolean label sets, otherwise a distribution over all labels
         */
        @Override
        @SuppressWarnings("unchecked")
        public DiscreteDistribution<L> apply(Number weight, L label, DenseVector features) {
            boolean useBias = _bias >= 0.0d;
            List<Feature> liblinearFeatures =
                new ArrayList<>(Math.toIntExact(features.size64() + (useBias ? 1 : 0)));
            features.forEach((index, value) -> {
                long liblinearIndex = index + 1; // Liblinear feature indices are 1-based
                if (liblinearIndex <= _featureCount) {
                    liblinearFeatures.add(new FeatureNode(Math.toIntExact(liblinearIndex), value));
                }
            });
            if (useBias) {
                liblinearFeatures.add(new FeatureNode(_featureCount + 1, _bias));
            }

            double[] probabilities = new double[_labels.length];
            Linear.predictProbability(_model, liblinearFeatures.toArray(new Feature[0]), probabilities);

            if (_isBinary) {
                // probabilities[0] belongs to _labels[0]; convert it to the probability of TRUE
                double probabilityOfTrue =
                    Boolean.TRUE.equals(_labels[0]) ? probabilities[0] : 1.0d - probabilities[0];
                return (DiscreteDistribution<L>) (DiscreteDistribution<?>) new BinaryDistribution(probabilityOfTrue);
            }
            return new ArrayDiscreteDistribution<>(_labels, probabilities);
        }
    }

    /**
     * Buffers training examples in Liblinear's sparse format and trains the model once all examples have been seen.
     *
     * @param <L> the type of the label
     */
    public static class Preparer<L>
            extends AbstractStreamPreparer3<Number, L, DenseVector, DiscreteDistribution<L>, Prepared<L>> {
        // Occupies the last slot of each example when a bias is used; finish() replaces it with the real bias feature
        // once the largest feature index (and therefore the bias feature's index) is known.
        private static final FeatureNode BIAS_PLACEHOLDER_FEATURE = new FeatureNode(Integer.MAX_VALUE, 0.0d);

        private final LiblinearClassification<L> _owner;
        private final List<Feature[]> _exampleFeatures;
        private final DoubleList _exampleLabels;
        private final Problem _problem = new Problem();
        private int _maxFeatureIndex = 0;
        // Assigns each distinct label a dense class ID in order of first appearance; -1 means "not seen yet".
        private final Object2IntOpenHashMap<L> _labelIDMap = new Object2IntOpenHashMap<>();

        /**
         * Creates a preparer.
         *
         * @param context the preparer context, used only to presize the example buffers
         * @param owner the transformer whose configuration (bias, solver, etc.) is used for training
         */
        public Preparer(PreparerContext context, LiblinearClassification<L> owner) {
            _owner = owner;
            _labelIDMap.defaultReturnValue(-1);
            int expectedExamples = (int) Math.min(Integer.MAX_VALUE, context.getEstimatedExampleCount());
            _exampleFeatures = new ArrayList<>(expectedExamples);
            _exampleLabels = new DoubleArrayList(expectedExamples);
            // the bias is supplied as an explicit extra feature, so Liblinear's own bias mechanism is disabled
            _problem.bias = -1.0d;
        }

        /**
         * Records one training example. The example weight is not used.
         */
        @Override
        public void process(Number weight, L label, DenseVector features) {
            long size = features.size64();
            boolean useBias = _owner.getBias() >= 0.0d;
            Feature[] exampleFeatures = new Feature[Math.toIntExact(size + (useBias ? 1 : 0))];
            if (useBias) {
                exampleFeatures[exampleFeatures.length - 1] = BIAS_PLACEHOLDER_FEATURE;
            }

            int[] nextSlot = {0};
            features.forEach((index, value) ->
                exampleFeatures[nextSlot[0]++] = new FeatureNode(Math.toIntExact(index + 1), value));

            if (size > 0) {
                // assumes the vector is iterated in ascending index order (as Liblinear requires), so the last feature
                // written carries this example's largest index
                _maxFeatureIndex = Math.max(_maxFeatureIndex, exampleFeatures[Math.toIntExact(size - 1)].getIndex());
            }

            int labelID = _labelIDMap.getInt(label);
            if (labelID < 0) {
                labelID = _labelIDMap.size();
                _labelIDMap.put(label, labelID);
            }
            _exampleFeatures.add(exampleFeatures);
            _exampleLabels.add(labelID);
        }

        /**
         * Trains the Liblinear model on all buffered examples.
         *
         * <p>Note: this sets Liblinear's global debug output stream (to {@code null} when silent, else
         * {@code System.err}) and does not restore it.</p>
         */
        @Override
        public PreparerResult<Prepared<L>> finish() {
            boolean useBias = _owner.getBias() >= 0.0d;
            _problem.l = _exampleFeatures.size();
            _problem.n = _maxFeatureIndex + (useBias ? 1 : 0);
            _problem.x = _exampleFeatures.toArray(new Feature[0][]);
            _problem.y = _exampleLabels.toDoubleArray();

            if (useBias) {
                Feature biasFeature = new FeatureNode(_maxFeatureIndex + 1, _owner.getBias());
                for (Feature[] example : _problem.x) {
                    example[example.length - 1] = biasFeature;
                }
            }

            Parameter parameter = new Parameter(_owner.getSolverType(),
                _owner.getLikelihoodVersusRegularizationLossMultiplier(), _owner.getEpsilon(),
                _owner.getSVREpsilonLoss());
            parameter.setThreadCount(_owner.getThreadCount());
            Linear.setDebugOutput(_owner._silent ? null : System.err);

            Model model = Linear.train(_problem, parameter);
            return new PreparerResult<>(new Prepared<>(_owner._bias, model, _labelIDMap, _maxFeatureIndex));
        }
    }

    /**
     * Returns a copy of this transformer that trains with the given Liblinear solver.
     *
     * @param solverType the solver; it must be a logistic regression solver (checked via {@link Arguments#check})
     *                   because predictions use {@link Linear#predictProbability}
     * @return a copy of this instance using the specified solver
     */
    @Override
    public LiblinearClassification<L> withSolverType(SolverType solverType) {
        Arguments.check(solverType.isLogisticRegressionSolver());
        return clone(copy -> copy._solverType = solverType);
    }

    @Override
    public Preparer<L> getPreparer(PreparerContext context) {
        return new Preparer<>(context, this);
    }

    // The inherited with*() configuration methods are not re-declared here: javac emits any bridge methods they need.
}
