package com.linkedin.dagli.fasttext.anonymized;

import com.linkedin.dagli.embedding.classification.FastTextInternal;
import com.linkedin.dagli.fasttext.anonymized.Args;
import com.linkedin.dagli.fasttext.anonymized.io.BufferedLineReader;
import com.linkedin.dagli.fasttext.anonymized.io.LineReader;
import com.linkedin.dagli.math.vector.DenseFloatArrayVector;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintStream;
import java.lang.Thread;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Trains supervised fastText models on a text file using multiple threads and exports the result as a
 * {@link FastTextInternal.Model}.
 *
 * <p>Only the supervised model type ({@code Args.model_name.sup}) is supported; other model types are rejected with an
 * {@link IllegalArgumentException}. Training state (arguments, dictionary, matrices, counters) is kept in instance
 * fields, so one instance should not run more than one {@link #train} call at a time.
 */
public class FastText {
    private Args args_;
    private Dictionary dict_;
    private Matrix input_;
    private Matrix output_;
    private Model model_;
    // Tokens processed so far across all training threads; drives the linear learning-rate decay.
    private AtomicLong tokenCount_;
    // Wall-clock time (ms) at which training started; used for progress reporting.
    private long start_;
    // Training threads that have not yet exited; each thread decrements it under the owning FastText's monitor.
    int threadCount;
    // Number of examples in the input file; used to pick each thread's starting line.
    long threadFileSize;
    // First exception raised by a training thread during the current train() call, or null.
    // Written under the owning FastText's monitor; read by train() after joining all threads.
    private Exception trainFailure_;
    private String charsetName_ = "UTF-8";
    private Class<? extends LineReader> lineReaderClass_ = BufferedLineReader.class;

    /** Handler installed on every training thread; prints the throwable that terminated the thread. */
    protected Thread.UncaughtExceptionHandler trainThreadExceptionHandler =
        (thread, throwable) -> throwable.printStackTrace();

    /**
     * A training worker. Each worker opens its own {@link LineReader} on the input file, skips to a starting line
     * proportional to its id, and updates its own {@link Model} (built over the shared input/output matrices) until the
     * shared token budget of {@code epoch * total tokens} is consumed or the learning rate reaches zero.
     */
    public class TrainThread extends Thread {
        final FastText _fastText;
        int _threadId;
        final CyclicBarrier _startBarrier;

        /**
         * @param fastText the trainer whose {@code threadCount} this worker decrements when it exits
         * @param i this worker's id, in {@code [0, args.thread)}
         * @param cyclicBarrier barrier all workers wait on before training starts, or {@code null} for no synchronization
         */
        public TrainThread(FastText fastText, int i, CyclicBarrier cyclicBarrier) {
            super("FT-TrainThread-" + i);
            this._fastText = fastText;
            this._threadId = i;
            this._startBarrier = cyclicBarrier;
        }

        /**
         * Runs this worker's share of training.
         *
         * <p>The line reader is always closed. If this worker fails before reaching the start barrier, it still arrives
         * there so that its peers are not left waiting forever. On exit the owning trainer's {@code threadCount} is
         * decremented. If training failed, the failure is also recorded on the trainer and rethrown wrapped in a
         * {@link RuntimeException}, so it reaches the thread's uncaught-exception handler.
         */
        @Override
        public void run() {
            if (args_.verbose > 2) {
                System.out.println("thread: " + _threadId + " RUNNING!");
            }
            Exception failure = null;
            LineReader lineReader = null;
            boolean reachedStartBarrier = false;
            try {
                lineReader = lineReaderClass_.getConstructor(String.class, String.class)
                    .newInstance(args_.input, charsetName_);
                lineReader.skipLine((_threadId * threadFileSize) / args_.thread);
                Model model = new Model(input_, output_, args_, _threadId);
                if (args_.model != Args.model_name.sup) {
                    throw new IllegalArgumentException();
                }
                model.setTargetCounts(dict_.getLabelCounts());

                // Kept as a long local so the multiplication below is done in long arithmetic.
                long totalTokens = dict_.getTotalTokensRead();
                long tokenBudget = args_.epoch * totalTokens;
                long localTokenCount = 0;
                LongArrayList hashes = new LongArrayList(16);
                IntArrayList words = new IntArrayList(16);
                IntArrayList labels = new IntArrayList(16);

                if (_startBarrier != null) {
                    // Set before awaiting so a failing await() is not followed by a second arrival in finally.
                    reachedStartBarrier = true;
                    _startBarrier.await();
                }

                while (tokenCount_.get() < tokenBudget) {
                    String[] tokens = lineReader.readLineTokens();
                    if (tokens == null) {
                        // End of input: start over from the beginning of the file.
                        try {
                            lineReader.rewind();
                            if (args_.verbose > 2) {
                                System.out.println("Input file reloaded!");
                            }
                        } catch (Exception e) {
                            e.printStackTrace();
                        }
                        tokens = lineReader.readLineTokens();
                    }
                    float progress = ((float) tokenCount_.get()) / ((float) tokenBudget);
                    float lr = (float) (args_.lr * (1.0d - progress));
                    if (lr <= 0.0f) {
                        break;
                    }
                    localTokenCount += dict_.getLine(tokens, hashes, words, labels);
                    if (!labels.isEmpty() && !words.isEmpty()) {
                        supervised(model, lr, words,
                            dict_.getNgramRowIDs(hashes, args_.wordNgrams, dict_.distinctWordCount(), args_.bucket),
                            labels);
                        // Local counts are published to the shared counter only once they exceed lrUpdateRate.
                        if (localTokenCount > args_.lrUpdateRate) {
                            tokenCount_.addAndGet(localTokenCount);
                            localTokenCount = 0;
                            if (_threadId == 0 && args_.verbose > 1
                                && (System.currentTimeMillis() - start_) % 1000 == 0) {
                                printInfo(progress, model.getLoss());
                            }
                        }
                    }
                }

                if (_threadId == 0 && args_.verbose > 1) {
                    printInfo(1.0f, model.getLoss());
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                failure = e;
            } catch (Exception e) {
                failure = e;
            } finally {
                if (lineReader != null) {
                    try {
                        lineReader.close();
                    } catch (IOException e) {
                        e.printStackTrace();
                    }
                }
                if (_startBarrier != null && !reachedStartBarrier) {
                    arriveAtStartBarrier();
                }
            }

            synchronized (_fastText) {
                if (args_.verbose > 2) {
                    System.out.println("\nthread: " + _threadId + " EXIT!");
                }
                _fastText.threadCount--;
                if (failure != null && _fastText.trainFailure_ == null) {
                    _fastText.trainFailure_ = failure;
                }
                _fastText.notify();
                if (failure != null) {
                    throw new RuntimeException(failure);
                }
            }
        }

        /** Arrives at the start barrier on behalf of a worker that failed before reaching it. */
        private void arriveAtStartBarrier() {
            try {
                _startBarrier.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            } catch (Exception ignored) {
                // BrokenBarrierException: the barrier is already broken, so its waiters have already been released.
            }
        }
    }

    /**
     * Prints a one-line progress report to standard output. The report starts with a carriage return and has no
     * trailing newline.
     *
     * @param f fraction of the token budget processed (expected to be in {@code [0, 1]})
     * @param f2 current loss reported by the model
     */
    public void printInfo(float f, float f2) {
        float elapsedSeconds = (System.currentTimeMillis() - this.start_) / 1000.0f;
        int etaSeconds = (int) ((elapsedSeconds / f) * (1.0f - f));
        int etaHours = etaSeconds / 3600;
        int etaMinutes = (etaSeconds % 3600) / 60;
        float wordsPerSecond = this.tokenCount_.get() / elapsedSeconds;
        System.out.printf(
            "\rProgress: %.1f%% words/sec: %d words/sec/thread: %d lr: %.6f loss: %.6f eta: %d h %d m",
            100.0f * f,
            (int) wordsPerSecond,
            (int) (wordsPerSecond / this.args_.thread),
            (float) (this.args_.lr * (1.0f - f)),
            f2,
            etaHours,
            etaMinutes);
    }

    /**
     * Applies one supervised update to {@code model} for a single example. The update target is one label picked via
     * {@code Utils.randomInt} (presumably uniformly) from {@code intArrayList3}. Does nothing if the example has no
     * words or no labels.
     *
     * @param model the worker's model
     * @param f learning rate for this update
     * @param intArrayList word IDs of the example
     * @param intArrayList2 n-gram row IDs of the example
     * @param intArrayList3 label IDs of the example
     */
    public void supervised(Model model, float f, IntArrayList intArrayList, IntArrayList intArrayList2,
        IntArrayList intArrayList3) {
        if (intArrayList3.isEmpty() || intArrayList.isEmpty()) {
            return;
        }
        int label = intArrayList3.getInt(Utils.randomInt(model.rng, 1, intArrayList3.size()) - 1);
        model.update(intArrayList, intArrayList2, label, f);
    }

    /**
     * Loads pretrained word vectors and initializes the input matrix from them.
     *
     * <p>The file (read as UTF-8) starts with a {@code "<count> <dim>"} header. Each of the following {@code count} lines
     * holds a word and its {@code dim} space-separated components. Every word is added to the dictionary. The input
     * matrix ({@code distinctWordCount + bucket} rows) is initialized uniformly, and then each loaded word's row is
     * overwritten with its pretrained vector.
     *
     * @param str path of the vectors file
     * @throws IOException if the file cannot be opened or read, or is truncated or malformed
     * @throws IllegalArgumentException if the file's dimension differs from {@code args.dim}
     */
    public void loadVectors(String str) throws IOException {
        BufferedReader reader;
        try {
            reader = new BufferedReader(new InputStreamReader(new FileInputStream(str), "UTF-8"));
        } catch (IOException e) {
            throw new IOException("Pretrained vectors file cannot be opened!", e);
        }

        try (BufferedReader in = reader) {
            String header = in.readLine();
            String[] headerTokens = header == null ? new String[0] : header.split(" ");
            if (headerTokens.length < 2) {
                throw new IOException("Pretrained vectors file " + str + " is missing its \"<count> <dim>\" header");
            }
            int vectorCount = Integer.parseInt(headerTokens[0]);
            int vectorDim = Integer.parseInt(headerTokens[1]);
            if (vectorDim != this.args_.dim) {
                throw new IllegalArgumentException(
                    "Dimension of pretrained vectors does not match args -dim option, pretrain dim is " + vectorDim
                        + ", args dim is " + this.args_.dim);
            }
            if (vectorCount < 0) {
                throw new IOException("Pretrained vectors file " + str + " declares a negative vector count");
            }

            String[] vectorWords = new String[vectorCount];
            Matrix pretrained = new Matrix(vectorCount, vectorDim);
            for (int row = 0; row < vectorCount; row++) {
                String line = in.readLine();
                if (line == null) {
                    throw new IOException(
                        "Pretrained vectors file " + str + " ended after " + row + " of " + vectorCount + " vectors");
                }
                String[] tokens = line.split(" ");
                if (tokens.length <= vectorDim) {
                    throw new IOException("Pretrained vector on line " + (row + 2) + " of " + str + " has "
                        + (tokens.length - 1) + " components, expected " + vectorDim);
                }
                for (int col = 0; col < vectorDim; col++) {
                    pretrained.data_[row][col] = Float.parseFloat(tokens[col + 1]);
                }
                vectorWords[row] = tokens[0];
                this.dict_.addWord(tokens[0]);
            }

            this.input_ = new Matrix(this.dict_.distinctWordCount() + this.args_.bucket, this.args_.dim);
            this.input_.uniform(1.0f / this.args_.dim);
            for (int row = 0; row < vectorCount; row++) {
                int wordID = this.dict_.getWordID(vectorWords[row]);
                if (wordID >= 0 && wordID < this.dict_.distinctWordCount()) {
                    System.arraycopy(pretrained.data_[row], 0, this.input_.data_[wordID], 0, vectorDim);
                }
            }
        }
    }

    /**
     * Returns the input-embedding row for a word ID. Word rows occupy the start of the input matrix, so this is the
     * identity mapping.
     */
    public int embeddingRowIndexForWordID(int i) {
        return i;
    }

    /**
     * Maps an n-gram hash to its bucket row in the input embedding matrix.
     *
     * <p>The result is {@code i + |j| mod i2}. The remainder is taken before the absolute value so that
     * {@link Long#MIN_VALUE}, whose absolute value overflows, still maps into range. Every other hash maps exactly as
     * {@code Math.abs(j) % i2 + i} does.
     *
     * @param j the n-gram hash
     * @param i index of the first bucket row (presumably the distinct word count)
     * @param i2 number of bucket rows; must be positive
     * @return a row index in {@code [i, i + i2)}
     */
    public static int embeddingRowIndexForNgramHash(long j, int i, int i2) {
        return (int) (Math.abs(j % i2) + i);
    }

    /**
     * Trains a supervised model on {@code args.input} with {@code args.thread} worker threads.
     *
     * @param fastTextOptions training options (arguments, example count, synchronized start, multilabel flag)
     * @return the trained model's labels, label embeddings, per-word embeddings and bucket embeddings
     * @throws IOException if the input is stdin ({@code "-"}) or not a readable file, or pretrained vectors cannot be
     *         read
     * @throws InterruptedException if interrupted while waiting for the training threads
     * @throws IllegalArgumentException if the model type is not supervised
     * @throws RuntimeException if any training thread did not complete successfully
     */
    public FastTextInternal.Model<String> train(FastTextOptions fastTextOptions)
        throws IOException, InterruptedException {
        this.args_ = fastTextOptions.getArgs();
        this.dict_ = new Dictionary(this.args_);
        this.dict_.setCharsetName(this.charsetName_);
        this.dict_.setLineReaderClass(this.lineReaderClass_);
        if ("-".equals(this.args_.input)) {
            throw new IOException("Cannot use stdin for training!");
        }
        File inputFile = new File(this.args_.input);
        if (!inputFile.exists() || !inputFile.isFile() || !inputFile.canRead()) {
            throw new IOException("Input file cannot be opened! " + this.args_.input);
        }
        if (this.args_.verbose > 0) {
            System.err.println(
                "Reading " + fastTextOptions.getExampleCount() + " examples from file " + this.args_.input);
            if (fastTextOptions.getSynchronizedStart()) {
                System.err.println("Synchronized start has been selected.  There will be a delay after the first pass over the data while the threads find their start positions in the input data file.");
            }
        }

        this.dict_.readFromFile(this.args_.input);
        this.threadFileSize = fastTextOptions.getExampleCount();
        if (Utils.isEmpty(this.args_.pretrainedVectors)) {
            this.input_ = new Matrix(this.dict_.distinctWordCount() + this.args_.bucket, this.args_.dim);
            this.input_.uniform(1.0f / this.args_.dim);
        } else {
            loadVectors(this.args_.pretrainedVectors);
        }
        if (this.args_.model != Args.model_name.sup) {
            throw new IllegalArgumentException();
        }
        this.output_ = new Matrix(this.dict_.distinctLabelCount(), this.args_.dim);
        this.output_.zero();

        this.start_ = System.currentTimeMillis();
        this.tokenCount_ = new AtomicLong(0L);
        this.threadCount = this.args_.thread;
        this.trainFailure_ = null;
        CyclicBarrier startBarrier =
            fastTextOptions.getSynchronizedStart() ? new CyclicBarrier(this.args_.thread) : null;
        ArrayList<TrainThread> workers = new ArrayList<>(this.args_.thread);
        for (int threadId = 0; threadId < this.args_.thread; threadId++) {
            TrainThread worker = new TrainThread(this, threadId, startBarrier);
            worker.setUncaughtExceptionHandler(this.trainThreadExceptionHandler);
            worker.start();
            workers.add(worker);
        }
        for (TrainThread worker : workers) {
            worker.join();
        }
        // join() makes every worker's writes (threadCount, trainFailure_) visible here.
        if (this.threadCount != 0 || this.trainFailure_ != null) {
            throw new RuntimeException("Not all training threads completed successfully", this.trainFailure_);
        }

        this.model_ = new Model(this.input_, this.output_, this.args_, 0);
        if (this.args_.verbose > 1) {
            System.out.printf("\nTrain time used: %d sec\n", (System.currentTimeMillis() - this.start_) / 1000);
        }

        Long2ObjectOpenHashMap<DenseFloatArrayVector> wordEmbeddings =
            new Long2ObjectOpenHashMap<>(this.dict_.distinctWordCount());
        this.dict_.getWordIDMap().long2IntEntrySet().fastForEach(entry -> wordEmbeddings.put(entry.getLongKey(),
            DenseFloatArrayVector.wrap(this.model_.getInputEmbedding(embeddingRowIndexForWordID(entry.getIntValue())))));
        DenseFloatArrayVector[] labelEmbeddings = Arrays.stream(this.model_.getLabelEmbeddings())
            .map(DenseFloatArrayVector::wrap)
            .toArray(DenseFloatArrayVector[]::new);
        DenseFloatArrayVector[] bucketEmbeddings =
            Arrays.stream(this.model_.getInputEmbeddingsStartingAtRow(this.dict_.distinctWordCount()))
                .map(DenseFloatArrayVector::wrap)
                .toArray(DenseFloatArrayVector[]::new);
        return new FastTextInternal.Model<>(this.dict_.getLabels(), labelEmbeddings, wordEmbeddings, bucketEmbeddings,
            fastTextOptions.getMultilabel(), this.args_.wordNgrams);
    }

    /** @return the arguments of the most recent training run (or as set via {@link #setArgs}) */
    public Args getArgs() {
        return this.args_;
    }

    /** @return the dictionary built by the most recent training run (or as set via {@link #setDict}) */
    public Dictionary getDict() {
        return this.dict_;
    }

    /** @return the label with the given label ID in the current dictionary */
    public String getLabel(int i) {
        return this.dict_.getLabel(i);
    }

    /** @return the input (word and bucket) embedding matrix */
    public Matrix getInput() {
        return this.input_;
    }

    /** @return the output (label) embedding matrix */
    public Matrix getOutput() {
        return this.output_;
    }

    /** @return the model built after the most recent training run (or as set via {@link #setModel}) */
    public Model getModel() {
        return this.model_;
    }

    /** Replaces the training arguments. */
    public void setArgs(Args args) {
        this.args_ = args;
    }

    /** Replaces the dictionary. */
    public void setDict(Dictionary dictionary) {
        this.dict_ = dictionary;
    }

    /** Replaces the input embedding matrix. */
    public void setInput(Matrix matrix) {
        this.input_ = matrix;
    }

    /** Replaces the output embedding matrix. */
    public void setOutput(Matrix matrix) {
        this.output_ = matrix;
    }

    /** Replaces the model. */
    public void setModel(Model model) {
        this.model_ = model;
    }

    /** @return the charset name used to read the training input (default {@code "UTF-8"}) */
    public String getCharsetName() {
        return this.charsetName_;
    }

    /** @return the {@link LineReader} implementation used to read the training input */
    public Class<? extends LineReader> getLineReaderClass() {
        return this.lineReaderClass_;
    }

    /** Sets the charset name used to read the training input. */
    public void setCharsetName(String str) {
        this.charsetName_ = str;
    }

    /**
     * Sets the {@link LineReader} implementation used to read the training input. The class must have a public
     * {@code (String path, String charsetName)} constructor.
     */
    public void setLineReaderClass(Class<? extends LineReader> cls) {
        this.lineReaderClass_ = cls;
    }
}
