package com.linkedin.dagli.fasttext;

import com.linkedin.dagli.embedding.classification.FastTextInternal;
import com.linkedin.dagli.fasttext.AbstractFastTextModel;
import com.linkedin.dagli.fasttext.anonymized.Args;
import com.linkedin.dagli.fasttext.anonymized.FastText;
import com.linkedin.dagli.fasttext.anonymized.FastTextOptions;
import com.linkedin.dagli.function.FunctionResult1;
import com.linkedin.dagli.producer.Producer;
import com.linkedin.dagli.transformer.AbstractPreparableTransformer2;
import com.linkedin.dagli.transformer.PreparedTransformer2;
import com.linkedin.dagli.util.cryptography.Cryptography;
import com.linkedin.dagli.util.environment.DagliSystemProperties;
import com.linkedin.dagli.util.io.SerializableTempFile;
import com.linkedin.migz.MiGzOutputStream;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Serializable;
import java.io.UncheckedIOException;
import java.lang.invoke.SerializedLambda;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.security.NoSuchAlgorithmException;
import java.util.Collections;
import java.util.Iterator;

/* loaded from: input_file:com/linkedin/dagli/fasttext/AbstractFastTextModel.class */
/**
 * Base class for preparable FastText transformers. It stores the training hyperparameters, exposes
 * them through getters and copy-on-write {@code withXxx(...)} methods, and provides a {@link Trainer}
 * that spools training examples to a temporary file and then runs FastText training over that file.
 *
 * @param <L> the label type
 * @param <R> the result type of the transformer
 * @param <N> the prepared transformer type
 * @param <S> the concrete subclass type; this is the type returned by the {@code withXxx(...)} methods
 */
abstract class AbstractFastTextModel<L extends Serializable, R, N extends PreparedTransformer2<Iterable<? extends L>, Iterable<? extends CharSequence>, R>, S extends AbstractFastTextModel<L, R, N, S>> extends AbstractPreparableTransformer2<Iterable<? extends L>, Iterable<? extends CharSequence>, R, N, S> {
    private static final long serialVersionUID = 1;

    /** Value reported by {@link #getMaxPredictionCount()}. */
    protected int _topK = 10;
    /** Passed to FastText as {@code -verbose}. */
    protected int _verbosity = 2;
    /** Passed to FastText as {@code -minCount}. */
    protected int _minTokenCount = 5;
    /** Passed to FastText as {@code -minCountLabel}. */
    protected int _minLabelCount = 0;
    /** Passed to FastText as {@code -wordNgrams}. */
    protected int _maxWordNgramLength = 1;
    /** Passed to FastText as {@code -bucket}. */
    protected int _bucketCount = 2000000;
    /** Passed to FastText as {@code -t}; this class declares no getter or wither for it. */
    protected double _samplingThreshold = 1.0E-4d;
    /** Passed to FastText as {@code -lr}. */
    protected double _learningRate = 0.05d;
    /** Passed to FastText as {@code -lrUpdateRate}. */
    protected int _learningRateUpdateRate = 100;
    /** Passed to FastText as {@code -dim}. */
    protected int _embeddingLength = 100;
    /** Passed to FastText as {@code -epoch}. */
    protected int _epochs = 5;
    /** Passed to FastText as {@code -neg}. */
    protected int _sampledNegatives = 5;
    /** Loss whose argument name is passed to FastText as {@code -loss}. */
    protected FastTextLoss _loss = FastTextLoss.NEGATIVE_SAMPLING;
    /** Requested thread count; any value below 1 means "use all available processors". */
    protected int _threads = -1;
    /** Forwarded to the FastText options as the multilabel flag. */
    protected boolean _isMultilabel = false;
    /** Forwarded to the FastText options as the synchronized-start flag. */
    protected boolean _synchronizedTrainingStart = false;
    /** Optional pretrained vectors file passed as {@code -pretrainedVectors}; {@code null} passes an empty path. */
    protected SerializableTempFile _pretrainedEmbeddings = null;
    /** Controls whether the spooled training data is encrypted and/or compressed, and which line reader is used. */
    protected FastTextDataSerializationMode _dataSerializationMode = FastTextDataSerializationMode.DEFAULT;

    /**
     * Accumulates training examples in a temporary file (one line per example) and, once
     * {@link #finish()} is called, trains a FastText model from that file.
     *
     * @param <T> the label type
     */
    protected static class Trainer<T extends Serializable> {
        private final OutputStreamWriter _exampleWriter;
        private final Path _tempFile;
        private final AbstractFastTextModel<T, ?, ?, ?> _owner;
        private long _exampleCount = 0;
        // Maps each distinct label to a dense integer index assigned in first-seen order.
        private final Object2IntOpenHashMap<T> _labelMap = new Object2IntOpenHashMap<>();

        /**
         * Creates the temporary example file (registered for deletion on JVM exit) and opens a writer
         * on it, layering encryption and/or compression as requested by the owner's serialization mode.
         *
         * @param abstractFastTextModel the model whose configuration drives serialization and training
         * @throws UncheckedIOException if the temporary file cannot be created or opened
         * @throws UnsupportedOperationException if the encryption algorithm is unavailable on this JVM
         */
        public Trainer(AbstractFastTextModel<T, ?, ?, ?> abstractFastTextModel) {
            this._owner = abstractFastTextModel;
            try {
                this._tempFile = Files.createTempFile(Paths.get(DagliSystemProperties.getTempDirectory()), "FastText", ".dat");
                this._tempFile.toFile().deleteOnExit();
                // NOTE(review): the writer uses the platform default charset; confirm that the line
                // reader chosen by the serialization mode decodes the file with the same charset.
                this._exampleWriter = new OutputStreamWriter(openExampleStream(this._tempFile, abstractFastTextModel));
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            } catch (NoSuchAlgorithmException e2) {
                throw new UnsupportedOperationException("The requisite encryption algorithm is not supported by this JVM", e2);
            }
        }

        /**
         * Opens the output stream for the example file, wrapping it per the owner's serialization mode.
         * If any wrapper fails to construct, the outermost stream created so far is closed before the
         * failure propagates, so the file handle is not leaked.
         */
        private static OutputStream openExampleStream(Path path, AbstractFastTextModel<?, ?, ?, ?> owner)
                throws IOException, NoSuchAlgorithmException {
            OutputStream stream = Files.newOutputStream(path);
            try {
                if (owner.getDataSerializationMode().isEncrypted()) {
                    stream = Cryptography.getOutputStream(stream);
                }
                if (owner.getDataSerializationMode().isCompressed()) {
                    // 524288 == 512 * 1024
                    stream = new MiGzOutputStream(stream, effectiveThreadCount(owner), 524288);
                }
                return stream;
            } catch (Exception e) {
                // Cleanup-and-rethrow: "e" is effectively final, so the original exception type is preserved.
                try {
                    stream.close();
                } catch (IOException closeFailure) {
                    e.addSuppressed(closeFailure);
                }
                throw e;
            }
        }

        /** Resolves the owner's thread setting: values below 1 fall back to the available processor count. */
        private static int effectiveThreadCount(AbstractFastTextModel<?, ?, ?, ?> owner) {
            return owner._threads >= 1 ? owner._threads : Runtime.getRuntime().availableProcessors();
        }

        /** Returns the index of the given label, assigning the next unused index (the current map size) to a new label. */
        private int getLabelIndex(T t) {
            return ((Integer) this._labelMap.computeIfAbsent(t, unseenLabel -> {
                return Integer.valueOf(this._labelMap.size());
            })).intValue();
        }

        /**
         * Closes the example file, trains a supervised FastText model over it using the owner's
         * hyperparameters, and maps the integer label strings used in the file back to the original labels.
         *
         * @return the trained model, keyed by the original label objects
         * @throws UncheckedIOException if an {@link IOException} occurs while closing the example file or preparing training
         * @throws RuntimeException wrapping any exception thrown while training
         */
        public FastTextInternal.Model<T> finish() {
            try {
                this._exampleWriter.close();
                Args args = new Args();
                args.hashBits = 63;
                String[] arguments = {
                    "supervised",
                    "-input", this._tempFile.toAbsolutePath().toString(),
                    "-verbose", String.valueOf(this._owner._verbosity),
                    "-minCount", String.valueOf(this._owner._minTokenCount),
                    "-minCountLabel", String.valueOf(this._owner._minLabelCount),
                    "-wordNgrams", String.valueOf(this._owner._maxWordNgramLength),
                    "-bucket", String.valueOf(this._owner._bucketCount),
                    "-t", String.valueOf(this._owner._samplingThreshold),
                    "-lr", String.valueOf(this._owner._learningRate),
                    "-lrUpdateRate", String.valueOf(this._owner._learningRateUpdateRate),
                    "-dim", String.valueOf(this._owner._embeddingLength),
                    "-epoch", String.valueOf(this._owner._epochs),
                    "-neg", String.valueOf(this._owner._sampledNegatives),
                    "-loss", this._owner._loss.getArgumentName(),
                    "-pretrainedVectors", this._owner._pretrainedEmbeddings == null ? "" : this._owner._pretrainedEmbeddings.getFile().getAbsolutePath(),
                    "-thread", String.valueOf(effectiveThreadCount(this._owner))
                };
                args.parseArgs(arguments);

                // Invert the label -> index map so that index i holds the label written to the file as i.
                @SuppressWarnings("unchecked") // safe: T's erasure is Serializable, so the runtime array type matches
                final T[] labelsByIndex = (T[]) new Serializable[this._labelMap.size()];
                this._labelMap.object2IntEntrySet().fastForEach(entry -> {
                    labelsByIndex[entry.getIntValue()] = entry.getKey();
                });
                try {
                    FastText fastText = new FastText();
                    fastText.setLineReaderClass(this._owner.getDataSerializationMode().getLineReaderClass());
                    return fastText.train(FastTextOptions.Builder.setArgs(args).setMultilabel(this._owner._isMultilabel).setExampleCount(this._exampleCount).setSynchronizedStart(this._owner._synchronizedTrainingStart).build()).remapLabels(labelString -> {
                        return labelsByIndex[FastTextInternal.Util.integerFromLabelString(labelString)];
                    });
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            } catch (IOException e2) {
                throw new UncheckedIOException(e2);
            }
        }

        /**
         * Appends one training example to the example file as a single line: each label (as its
         * formatted integer index) followed by each formatted token, every item followed by a space.
         *
         * @param iterable the labels of the example
         * @param iterable2 the tokens of the example
         * @throws UncheckedIOException if writing to the example file fails
         */
        public void process(Iterable<? extends T> iterable, Iterable<? extends CharSequence> iterable2) {
            this._exampleCount++;
            try {
                for (T label : iterable) {
                    this._exampleWriter.write(FastTextInternal.Util.formatIntegerLabel(getLabelIndex(label)) + " ");
                }
                for (CharSequence token : iterable2) {
                    this._exampleWriter.write(FastTextInternal.Util.formatToken(token) + " ");
                }
                this._exampleWriter.write('\n');
            } catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }
    }

    /** @return the synchronized-training-start flag forwarded to the FastText options */
    public boolean getSynchronizedTrainingStart() {
        return this._synchronizedTrainingStart;
    }

    /**
     * @param z the new synchronized-training-start flag
     * @return the instance produced by {@code clone(...)} with the flag set
     */
    public S withSynchronizedTrainingStart(boolean z) {
        return clone(copy -> {
            copy._synchronizedTrainingStart = z;
        });
    }

    /** @return the configured maximum prediction count */
    public int getMaxPredictionCount() {
        return this._topK;
    }

    /** @return the verbosity passed to FastText as {@code -verbose} */
    public int getVerbosity() {
        return this._verbosity;
    }

    /** @return the minimum token count passed to FastText as {@code -minCount} */
    public int getMinTokenCount() {
        return this._minTokenCount;
    }

    /** @return the minimum label count passed to FastText as {@code -minCountLabel} */
    public int getMinLabelCount() {
        return this._minLabelCount;
    }

    /** @return the maximum word n-gram length passed to FastText as {@code -wordNgrams} */
    public int getMaxWordNgramLength() {
        return this._maxWordNgramLength;
    }

    /** @return the bucket count passed to FastText as {@code -bucket} */
    public int getBucketCount() {
        return this._bucketCount;
    }

    /** @return the learning rate passed to FastText as {@code -lr} */
    public double getLearningRate() {
        return this._learningRate;
    }

    /** @return the learning rate update rate passed to FastText as {@code -lrUpdateRate} */
    public int getLearningRateUpdateRate() {
        return this._learningRateUpdateRate;
    }

    /** @return the embedding length passed to FastText as {@code -dim} */
    public int getEmbeddingLength() {
        return this._embeddingLength;
    }

    /** @return the epoch count passed to FastText as {@code -epoch} */
    public int getEpochCount() {
        return this._epochs;
    }

    /** @return the sampled negative count passed to FastText as {@code -neg} */
    public int getSampledNegativeCount() {
        return this._sampledNegatives;
    }

    /** @return the loss whose argument name is passed to FastText as {@code -loss} */
    public FastTextLoss getLossType() {
        return this._loss;
    }

    /** @return the requested thread count; values below 1 mean all available processors are used */
    public int getThreadCount() {
        return this._threads;
    }

    /** @return the multilabel flag forwarded to the FastText options */
    public boolean isMultilabel() {
        return this._isMultilabel;
    }

    /** @return the mode controlling how the spooled training data is serialized */
    public FastTextDataSerializationMode getDataSerializationMode() {
        return this._dataSerializationMode;
    }

    /**
     * @param producer the producer supplying the labels of each example
     * @return the instance produced by {@code clone(...)} with its first input set
     */
    public S withLabelsInput(Producer<? extends Iterable<? extends L>> producer) {
        return clone(copy -> {
            copy._input1 = producer;
        });
    }

    /**
     * Convenience form of {@link #withLabelsInput(Producer)} for a single label per example: the
     * label is wrapped with {@link Collections#singleton(Object)}.
     *
     * @param producer the producer supplying the single label of each example
     * @return the instance produced by {@link #withLabelsInput(Producer)}
     */
    public S withLabelInput(Producer<? extends L> producer) {
        // Written as a method reference: the serialized form handled by $deserializeLambda$ names
        // java.util.Collections.singleton itself as the implementation method.
        return withLabelsInput(new FunctionResult1().withFunction(Collections::singleton).withInput(producer));
    }

    /**
     * @param producer the producer supplying the tokens of each example
     * @return the instance produced by {@code clone(...)} with its second input set
     */
    public S withTokensInput(Producer<? extends Iterable<? extends CharSequence>> producer) {
        return clone(copy -> {
            copy._input2 = producer;
        });
    }

    /**
     * @param file the pretrained vectors file, or {@code null} to use none
     * @return the instance produced by {@code clone(...)} with the pretrained embeddings set
     */
    public S withPretrainedEmbeddings(File file) {
        return clone(copy -> {
            copy._pretrainedEmbeddings = file == null ? null : new SerializableTempFile(file, "FastTextPretrained", ".vec");
        });
    }

    /**
     * @param i the new maximum prediction count
     * @return the instance produced by {@code clone(...)} with the value set
     */
    public S withMaxPredictionCount(int i) {
        return clone(copy -> {
            copy._topK = i;
        });
    }

    /**
     * @param i the new verbosity
     * @return the instance produced by {@code clone(...)} with the value set
     */
    public S withVerbosity(int i) {
        return clone(copy -> {
            copy._verbosity = i;
        });
    }

    /**
     * @param i the new minimum token count
     * @return the instance produced by {@code clone(...)} with the value set
     */
    public S withMinTokenCount(int i) {
        return clone(copy -> {
            copy._minTokenCount = i;
        });
    }

    /**
     * @param i the new minimum label count
     * @return the instance produced by {@code clone(...)} with the value set
     */
    public S withMinLabelCount(int i) {
        return clone(copy -> {
            copy._minLabelCount = i;
        });
    }

    /**
     * @param i the new maximum word n-gram length
     * @return the instance produced by {@code clone(...)} with the value set
     */
    public S withMaxWordNgramLength(int i) {
        return clone(copy -> {
            copy._maxWordNgramLength = i;
        });
    }

    /**
     * @param i the new bucket count
     * @return the instance produced by {@code clone(...)} with the value set
     */
    public S withBucketCount(int i) {
        return clone(copy -> {
            copy._bucketCount = i;
        });
    }

    /**
     * @param d the new learning rate
     * @return the instance produced by {@code clone(...)} with the value set
     */
    public S withLearningRate(double d) {
        return clone(copy -> {
            copy._learningRate = d;
        });
    }

    /**
     * @param i the new learning rate update rate
     * @return the instance produced by {@code clone(...)} with the value set
     */
    public S withLearningRateUpdateRate(int i) {
        return clone(copy -> {
            copy._learningRateUpdateRate = i;
        });
    }

    /**
     * @param i the new embedding length
     * @return the instance produced by {@code clone(...)} with the value set
     */
    public S withEmbeddingLength(int i) {
        return clone(copy -> {
            copy._embeddingLength = i;
        });
    }

    /**
     * @param i the new epoch count
     * @return the instance produced by {@code clone(...)} with the value set
     */
    public S withEpochCount(int i) {
        return clone(copy -> {
            copy._epochs = i;
        });
    }

    /**
     * @param i the new sampled negative count
     * @return the instance produced by {@code clone(...)} with the value set
     */
    public S withSampledNegativeCount(int i) {
        return clone(copy -> {
            copy._sampledNegatives = i;
        });
    }

    /**
     * @param fastTextLoss the new loss type
     * @return the instance produced by {@code clone(...)} with the value set
     */
    public S withLossType(FastTextLoss fastTextLoss) {
        return clone(copy -> {
            copy._loss = fastTextLoss;
        });
    }

    /**
     * @param i the new thread count; values below 1 mean all available processors are used
     * @return the instance produced by {@code clone(...)} with the value set
     */
    public S withThreadCount(int i) {
        return clone(copy -> {
            copy._threads = i;
        });
    }

    /**
     * @param z the new multilabel flag
     * @return the instance produced by {@code clone(...)} with the flag set
     */
    public S withMultilabel(boolean z) {
        return clone(copy -> {
            copy._isMultilabel = z;
        });
    }

    /**
     * @param fastTextDataSerializationMode the new serialization mode for spooled training data
     * @return the instance produced by {@code clone(...)} with the value set
     */
    public S withDataSerializationMode(FastTextDataSerializationMode fastTextDataSerializationMode) {
        return clone(copy -> {
            copy._dataSerializationMode = fastTextDataSerializationMode;
        });
    }

    /** @return the producer currently wired as the labels (first) input */
    protected Producer<? extends Iterable<? extends L>> getLabelsInput() {
        return this._input1;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /** @return the producer currently wired as the tokens (second) input */
    public Producer<? extends Iterable<? extends CharSequence>> getTokensInput() {
        return this._input2;
    }

    /**
     * Lambda deserialization hook, as recovered by the decompiler. It recreates the serializable
     * {@code Collections.singleton} function used by {@link #withLabelInput(Producer)} when the
     * serialized description matches it, and rejects anything else.
     *
     * <p>NOTE(review): a method with this name is normally generated by the compiler for
     * serializable lambdas; confirm whether this explicit copy should remain in hand-maintained
     * source. The returned lambda has no explicit target type here.
     *
     * @param serializedLambda the serialized form of the lambda to recreate
     * @return the recreated function
     * @throws IllegalArgumentException if the serialized form does not describe a known lambda
     */
    private static /* synthetic */ Object $deserializeLambda$(SerializedLambda serializedLambda) {
        // 6 == MethodHandleInfo.REF_invokeStatic
        if (serializedLambda.getImplMethodName().equals("singleton")
                && serializedLambda.getImplMethodKind() == 6
                && serializedLambda.getFunctionalInterfaceClass().equals("com/linkedin/dagli/util/function/Function1$Serializable")
                && serializedLambda.getFunctionalInterfaceMethodName().equals("apply")
                && serializedLambda.getFunctionalInterfaceMethodSignature().equals("(Ljava/lang/Object;)Ljava/lang/Object;")
                && serializedLambda.getImplClass().equals("java/util/Collections")
                && serializedLambda.getImplMethodSignature().equals("(Ljava/lang/Object;)Ljava/util/Set;")) {
            return (v0) -> {
                return Collections.singleton(v0);
            };
        }
        throw new IllegalArgumentException("Invalid lambda deserialization");
    }
}
