package hex.api;

import hex.schemas.Word2VecSynonymsV3;
import hex.schemas.Word2VecTransformV3;
import hex.word2vec.Word2VecModel;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import water.DKV;
import water.api.Handler;
import water.api.schemas3.KeyV3;
import water.fvec.Frame;

/* loaded from: input_file:hex/api/Word2VecHandler.class */
/**
 * Handler exposing two Word2Vec model operations: synonym lookup and word-to-vector transformation.
 *
 * <p>Each endpoint resolves the referenced model (and, for {@code transform}, the words frame) with
 * {@code DKV.getGet}. It writes its results into the incoming schema object and returns that same
 * object.
 */
public class Word2VecHandler extends Handler {

    /**
     * Looks up synonyms of a word with a Word2Vec model and stores them best-first.
     *
     * @param version API version passed by the caller (not used here)
     * @param args request schema; {@code model}, {@code word} and {@code count} are read, and
     *     {@code synonyms} / {@code scores} are overwritten with parallel arrays
     * @return {@code args}, where {@code synonyms[k]} has score {@code scores[k]} and scores are in
     *     descending order
     * @throws IllegalArgumentException if no model key is given or the model cannot be found
     */
    public Word2VecSynonymsV3 findSynonyms(int version, Word2VecSynonymsV3 args) {
        Word2VecModel model = null;
        if (args.model != null) {
            model = DKV.getGet(args.model.key());
        }
        if (model == null) {
            throw new IllegalArgumentException("missing source model " + args.model);
        }

        List<Map.Entry<String, Float>> ranked =
                sortByDescendingScore(model.findSynonyms(args.word, args.count));

        args.synonyms = new String[ranked.size()];
        args.scores = new double[ranked.size()];
        for (int k = 0; k < ranked.size(); k++) {
            Map.Entry<String, Float> entry = ranked.get(k);
            args.synonyms[k] = entry.getKey();
            // Float is unboxed, then widened to double.
            args.scores[k] = entry.getValue();
        }
        return args;
    }

    /** Returns the map's entries as a new list sorted by value, highest first. */
    private static List<Map.Entry<String, Float>> sortByDescendingScore(Map<String, Float> scores) {
        List<Map.Entry<String, Float>> entries =
                new ArrayList<Map.Entry<String, Float>>(scores.entrySet());
        Collections.sort(entries, new Comparator<Map.Entry<String, Float>>() {
            @Override
            public int compare(Map.Entry<String, Float> a, Map.Entry<String, Float> b) {
                // Arguments swapped on purpose to get descending order.
                return b.getValue().compareTo(a.getValue());
            }
        });
        return entries;
    }

    /**
     * Converts the words in a single-column frame into vectors with a Word2Vec model.
     *
     * @param version API version passed by the caller (not used here)
     * @param args request schema; {@code model} and {@code words_frame} are read, and
     *     {@code vectors_frame} is set to the key of the frame returned by the model's
     *     {@code transform}
     * @return {@code args} with {@code vectors_frame} populated
     * @throws IllegalArgumentException if the model or words frame is missing, or the words frame
     *     does not have exactly one column
     */
    public Word2VecTransformV3 transform(int version, Word2VecTransformV3 args) {
        Word2VecModel model = null;
        if (args.model != null) {
            model = DKV.getGet(args.model.key());
        }
        if (model == null) {
            throw new IllegalArgumentException("missing source model " + args.model);
        }

        Frame words = null;
        if (args.words_frame != null) {
            words = DKV.getGet(args.words_frame.key());
        }
        if (words == null) {
            throw new IllegalArgumentException("missing words frame " + args.words_frame);
        }
        // Only the column count is checked here; the column type is left to model.transform.
        // NOTE(review): confirm transform rejects non-string columns.
        if (words.numCols() != 1) {
            throw new IllegalArgumentException(
                    "words frame is expected to have a single string column, got " + words.numCols());
        }

        args.vectors_frame = new KeyV3.FrameKeyV3(model.transform(words.vec(0))._key);
        return args;
    }
}
