/*
 * Decompiled with CFR 0.152.
 */
package cn.langpy.nlp2cron.core;

import cn.langpy.nlp2cron.core.CrondConfig;
import com.alibaba.fastjson.JSONObject;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.RawTensor;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.NdArray;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.types.TFloat32;

public class CrondModel {
    /** Fixed input sequence length the model was exported with; longer messages are truncated. */
    private static final int MAX_INPUT_LENGTH = 40;
    /** Width of each per-position score slice in the flattened model output. */
    private static final int OUTPUT_VOCAB_SIZE = 127;
    /** Name of the input placeholder fed by {@link #predict(String)}. */
    private static final String INPUT_OP = "serving_default_input_1:0";
    /** Name of the output tensor fetched by {@link #predict(String)}. */
    private static final String OUTPUT_OP = "StatefulPartitionedCall:0";

    /** Session of the currently loaded model, or {@code null} when no model is loaded. */
    static Session model = null;
    /** Vocabulary / path configuration; a default instance is created at class-load time. */
    static CrondConfig config = new CrondConfig();
    /** Bundle owning {@link #model}; kept so its graph can be released in {@link #close()}. */
    private static SavedModelBundle bundle = null;

    /**
     * Loads the model from {@code path}, keeping the current {@link #config}.
     *
     * @param path directory of the exported SavedModel
     */
    public static void init(String path) {
        CrondModel.loadModel(path);
    }

    /**
     * Replaces the configuration and loads the model from {@link CrondConfig#getModelPath()}.
     *
     * @param crondConfig configuration to install
     */
    public static void init(CrondConfig crondConfig) {
        config = crondConfig;
        CrondModel.loadModel(config.getModelPath());
    }

    /** Loads the model from the model path of the current {@link #config}. */
    public static void init() {
        CrondModel.loadModel(config.getModelPath());
    }

    /**
     * Loads the SavedModel tagged {@code "serve"} from {@code path}.
     *
     * <p>Best-effort: on failure the stack trace is printed and any previously loaded
     * model stays active. On success the previously loaded bundle (if any) is released.
     */
    private static void loadModel(String path) {
        try {
            SavedModelBundle loaded = SavedModelBundle.load(path, "serve");
            // Release the previous bundle (session + graph) before replacing it.
            close();
            bundle = loaded;
            model = loaded.session();
        }
        catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Runs the model on {@code message} and decodes its output into labels.
     *
     * <p>The flattened output is split into consecutive slices of
     * {@value #OUTPUT_VOCAB_SIZE} scores. Each slice is mapped to the label whose id is the
     * slice's argmax, looked up in {@link CrondConfig#getOutputId2Word()}. A trailing
     * incomplete slice is ignored.
     *
     * @param message input text
     * @return the decoded labels joined with {@code "#"}
     * @throws IllegalStateException if no model has been loaded
     */
    public static String predict(String message) {
        if (model == null) {
            throw new IllegalStateException("CrondModel is not initialized; call init(...) first");
        }
        JSONObject id2Word = config.getOutputId2Word();
        List<String> crons = new ArrayList<>();
        // Close both the input and the output tensor, even if decoding fails.
        try (Tensor input = CrondModel.toVec(message);
             Tensor output = (Tensor) model.runner().feed(INPUT_OP, input).fetch(OUTPUT_OP).run().get(0)) {
            FloatDataBuffer scores = output.asRawTensor().data().asFloats();
            List<Float> slice = new ArrayList<>(OUTPUT_VOCAB_SIZE);
            for (long i = 0L; i < scores.size(); ++i) {
                // Add first, then check: every element of the slice takes part in the argmax.
                slice.add(scores.getFloat(i));
                if ((i + 1L) % OUTPUT_VOCAB_SIZE == 0L) {
                    int maxIndex = CrondModel.argmax(slice);
                    crons.add(id2Word.getString(String.valueOf(maxIndex)));
                    slice.clear();
                }
            }
        }
        return String.join("#", crons);
    }

    /**
     * Returns the index of the largest score (the first one on ties), or 0 for an empty list.
     * Works for negative scores (e.g. raw logits), not only probabilities.
     */
    private static int argmax(List<Float> scores) {
        int bestIndex = 0;
        float bestScore = Float.NEGATIVE_INFINITY;
        for (int i = 0; i < scores.size(); ++i) {
            float score = scores.get(i);
            if (score > bestScore) {
                bestScore = score;
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    /**
     * Encodes {@code message} as a {@code [1, MAX_INPUT_LENGTH]} float tensor of character ids.
     *
     * <p>Each UTF-16 char is looked up in {@link CrondConfig#getInputWord2Id()}. Unknown
     * chars map to the {@code "<UNK>"} id. Unused positions stay 0, and characters past
     * {@value #MAX_INPUT_LENGTH} are dropped instead of overflowing the buffer.
     * The caller owns (and must close) the returned tensor.
     */
    private static Tensor toVec(String message) {
        float[][] vecShape = new float[1][MAX_INPUT_LENGTH];
        JSONObject word2id = config.getInputWord2Id();
        int length = Math.min(message.length(), MAX_INPUT_LENGTH);
        for (int i = 0; i < length; ++i) {
            String key = String.valueOf(message.charAt(i));
            Integer id = word2id.containsKey(key) ? word2id.getInteger(key) : word2id.getInteger("<UNK>");
            vecShape[0][i] = (float) id.intValue();
        }
        FloatNdArray ndArray = StdArrays.ndCopyOf(vecShape);
        return TFloat32.tensorOf(ndArray);
    }

    /**
     * Releases the loaded model, including its session and graph. Calling it when nothing is
     * loaded does nothing.
     */
    public static void close() {
        if (bundle != null) {
            // SavedModelBundle.close() closes both its session and its graph.
            bundle.close();
        } else if (model != null) {
            model.close();
        }
        bundle = null;
        model = null;
    }
}

