/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.geneexpbase.classification.svm;

import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import de.julielab.geneexpbase.classification.FeatureUtils;
import de.julielab.geneexpbase.classification.MinMaxScalingStats;
import de.julielab.geneexpbase.classification.StandardizationStats;
import de.julielab.geneexpbase.classification.svm.SVMModel;
import de.julielab.geneexpbase.classification.svm.SVMTrainOptions;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.net.URL;
import java.util.Arrays;
import java.util.Map;
import java.util.stream.IntStream;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import libsvm.svm;
import libsvm.svm_model;
import libsvm.svm_node;
import libsvm.svm_parameter;
import libsvm.svm_problem;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class SVM {
    private static final Logger log = LoggerFactory.getLogger(SVM.class);

    public static svm_problem getSvmProblem(double[] labels, double[][] featureMatrix) {
        svm_node[][] nodeMatrix = new svm_node[featureMatrix.length][];
        for (int i = 0; i < featureMatrix.length; ++i) {
            double[] features = featureMatrix[i];
            svm_node[] nodeVector = new svm_node[features.length];
            for (int j = 0; j < features.length; ++j) {
                double featureValue = features[j];
                svm_node node = new svm_node();
                node.index = j + 1;
                node.value = featureValue;
                nodeVector[j] = node;
            }
            nodeMatrix[i] = nodeVector;
        }
        svm_problem prob = new svm_problem();
        prob.l = featureMatrix.length;
        prob.x = nodeMatrix;
        prob.y = labels;
        return prob;
    }

    public static svm_problem getSvmProblem(InstanceList instances) {
        svm_node[][] nodeMatrix = new svm_node[instances.size()][];
        double[] labels = new double[instances.size()];
        for (int i = 0; i < instances.size(); ++i) {
            Instance instance = (Instance)instances.get(i);
            FeatureVector fv = (FeatureVector)instance.getData();
            svm_node[] nodeVector = new svm_node[fv.numLocations()];
            int svmVectorIndex = 0;
            int numFilledIndices = 0;
            int[] indices = fv.getIndices();
            for (int j = 0; j < fv.numLocations(); ++j) {
                double featureValue;
                int index = indices != null ? indices[j] : j;
                double d = featureValue = fv.isBinary() ? 1.0 : fv.getValues()[j];
                if (featureValue == 0.0) continue;
                svm_node node = new svm_node();
                node.index = index + 1;
                node.value = featureValue;
                nodeVector[svmVectorIndex++] = node;
                ++numFilledIndices;
            }
            if (numFilledIndices < nodeVector.length) {
                svm_node[] newVector = new svm_node[numFilledIndices];
                System.arraycopy(nodeVector, 0, newVector, 0, numFilledIndices);
                nodeVector = newVector;
            }
            nodeMatrix[i] = nodeVector;
            assert (instance.getTarget() != null) : "For training, all instances must have their target label set but an instance without a target occurred.";
            labels[i] = ((Label)instance.getTarget()).getIndex();
        }
        svm_problem prob = new svm_problem();
        prob.l = instances.size();
        prob.x = nodeMatrix;
        prob.y = labels;
        return prob;
    }

    public static double[] predict(double[] features, svm_model model) {
        svm_node[] featureVector = new svm_node[features.length];
        for (int i = 0; i < features.length; ++i) {
            double featureValue = features[i];
            svm_node node = new svm_node();
            node.index = i + 1;
            node.value = featureValue;
            featureVector[i] = node;
        }
        double[] prob_estimates = new double[model.nr_class];
        svm.svm_predict_probability(model, featureVector, prob_estimates);
        return prob_estimates;
    }

    public static double[] predictProbability(double[] features, svm_model model) {
        svm_node[] featureVector = new svm_node[features.length];
        for (int i = 0; i < features.length; ++i) {
            double featureValue = features[i];
            svm_node node = new svm_node();
            node.index = i + 1;
            node.value = featureValue;
            featureVector[i] = node;
        }
        double[] prob_estimates = new double[2];
        svm.svm_predict_probability(model, featureVector, prob_estimates);
        return prob_estimates;
    }

    public static SVMModel train(InstanceList instances, SVMTrainOptions options) {
        SVMModel model = new SVMModel(options);
        InstanceList instancesToUse = instances;
        if (options.copyData) {
            instancesToUse = new InstanceList(instances.getPipe());
            for (Instance instance : instances) {
                Instance copy = instance.shallowCopy();
                copy.unLock();
                copy.setData(((FeatureVector)instance.getData()).cloneMatrix());
                copy.lock();
                instancesToUse.add(copy);
            }
        }
        if (options.rangeScaleFeatures) {
            MinMaxScalingStats maxFeatureValues;
            model.minMaxScalingStats = maxFeatureValues = FeatureUtils.scaleFeatures(instancesToUse);
            model.featuresRangeScaled = options.rangeScaleFeatures;
        }
        if (options.centerFeatures && !options.standardizeFeatures) {
            double[] means = FeatureUtils.centerFeatures(instancesToUse);
            model.featureMeans = means;
            model.featuresCentered = options.centerFeatures;
        }
        if (options.standardizeFeatures) {
            StandardizationStats standardizationStats = FeatureUtils.standardizeFeatures(instancesToUse);
            model.featureMeans = standardizationStats.means;
            model.featureStdDeviations = standardizationStats.stdDeviations;
            model.featuresStandardized = options.standardizeFeatures;
        }
        svm_problem prob = SVM.getSvmProblem(instancesToUse);
        svm_parameter param = SVM.getSvmParameter(options);
        SVM.doTraining(model, prob, param);
        return model;
    }

    public static SVMModel train(double[] labels, double[][] originalFeatureMatrix, SVMTrainOptions options) {
        if (originalFeatureMatrix.length == 0) {
            return SVMModel.EMPTY;
        }
        Object featureMatrix = originalFeatureMatrix;
        SVMModel model = new SVMModel(options);
        if (options.copyData) {
            featureMatrix = new double[originalFeatureMatrix.length][];
            for (int i = 0; i < originalFeatureMatrix.length; ++i) {
                double[] features = originalFeatureMatrix[i];
                featureMatrix[i] = Arrays.copyOf(features, features.length);
            }
        }
        if (options.rangeScaleFeatures) {
            MinMaxScalingStats maxFeatureValues;
            model.minMaxScalingStats = maxFeatureValues = FeatureUtils.scaleFeatures(featureMatrix);
            model.featuresRangeScaled = options.rangeScaleFeatures;
        }
        if (options.centerFeatures && !options.standardizeFeatures) {
            double[] means = FeatureUtils.centerFeatures(featureMatrix);
            model.featureMeans = means;
            model.featuresCentered = options.centerFeatures;
        }
        if (options.standardizeFeatures) {
            StandardizationStats standardizationStats = FeatureUtils.standardizeFeatures(featureMatrix);
            model.featureMeans = standardizationStats.means;
            model.featureStdDeviations = standardizationStats.stdDeviations;
            model.featuresStandardized = options.standardizeFeatures;
        }
        svm_problem prob = SVM.getSvmProblem(labels, featureMatrix);
        svm_parameter param = SVM.getSvmParameter(options);
        SVM.doTraining(model, prob, param);
        return model;
    }

    private static void doTraining(SVMModel model, svm_problem prob, svm_parameter param) {
        String errorMsg = svm.svm_check_parameter(prob, param);
        if (errorMsg != null) {
            log.error("Error in the SVM parameters: {}", (Object)errorMsg);
        } else {
            String kType = "";
            switch (param.kernel_type) {
                case 0: {
                    kType = "Linear";
                    break;
                }
                case 1: {
                    kType = "Polynomial";
                    break;
                }
                case 2: {
                    kType = "RBF";
                    break;
                }
                case 3: {
                    kType = "Sigmoid";
                }
            }
            log.info("Starting SVM training with settings:\nKernel type: {}\nC: {}\nGamma: {}\nDegree: {}\nr (coef0): {}", kType, param.C, param.gamma, param.degree, param.coef0);
            svm_model svmModel = svm.svm_train(prob, param);
            log.info("SVM training done");
            model.svmModel = svmModel;
        }
    }

    public static svm_parameter getSvmParameter(SVMTrainOptions options) {
        svm_parameter param = new svm_parameter();
        param.svm_type = options.svmType;
        param.C = options.C;
        param.kernel_type = options.kernelType;
        param.gamma = options.svmGamma;
        param.coef0 = options.coef0;
        param.degree = options.svmDegree;
        param.cache_size = options.cacheSize;
        param.eps = options.eps;
        param.shrinking = options.shrinking ? 1 : 0;
        param.probability = options.probability ? 1 : 0;
        Map<Integer, Double> classWeights = options.classWeights;
        if (classWeights != null) {
            param.nr_weight = classWeights.size();
            param.weight_label = new int[param.nr_weight];
            param.weight = new double[param.nr_weight];
            int i = 0;
            for (Integer classNr : classWeights.keySet()) {
                param.weight_label[i] = classNr;
                param.weight[i] = classWeights.get(classNr);
                ++i;
            }
        }
        return param;
    }

    public static double[] predict(double[] features, SVMModel model) {
        if (model.featuresRangeScaled) {
            FeatureUtils.rangeScaleFeatures(features, model.minMaxScalingStats);
        }
        if (model.featuresCentered && !model.featuresStandardized) {
            FeatureUtils.centerFeatures(features, model.featureMeans);
        }
        if (model.featuresStandardized) {
            FeatureUtils.standardizeFeatures(features, model.featureMeans, model.featureStdDeviations);
        }
        return SVM.predict(features, model.svmModel);
    }

    public static double[] predict(Instance instance, SVMModel model) {
        double[] class_scores;
        assert (instance.getData() instanceof FeatureVector);
        if (model.featuresRangeScaled) {
            FeatureUtils.rangeScaleFeatures(instance, model.minMaxScalingStats);
        }
        if (model.featuresCentered && !model.featuresStandardized) {
            FeatureUtils.centerFeatures(instance, model.featureMeans);
        }
        if (model.featuresStandardized) {
            FeatureUtils.standardizeFeatures(instance, model.featureMeans, model.featureStdDeviations);
        }
        FeatureVector fv = (FeatureVector)instance.getData();
        svm_node[] svmVector = new svm_node[fv.numLocations()];
        int[] indices = fv.getIndices();
        for (int i = 0; i < fv.numLocations(); ++i) {
            int index = indices != null ? indices[i] : i;
            svm_node svm_node2 = new svm_node();
            svm_node2.index = index + 1;
            svm_node2.value = fv.isBinary() ? 1.0 : fv.getValues()[i];
            svmVector[i] = svm_node2;
        }
        int nr_class = model.svmModel.nr_class;
        double[] dArray = class_scores = model.trainOptions.probability ? new double[nr_class] : new double[nr_class * (nr_class - 1) / 2];
        if (model.trainOptions.probability) {
            double predictedClass = svm.svm_predict_probability(model.svmModel, svmVector, class_scores);
        } else {
            double predictedClass = svm.svm_predict_values(model.svmModel, svmVector, class_scores);
        }
        return class_scores;
    }

    public static void storeModel(File destination, SVMModel model) throws IOException {
        try (ObjectOutputStream os = new ObjectOutputStream(new GZIPOutputStream(new FileOutputStream(destination)));){
            os.writeObject(model);
        }
    }

    public static SVMModel readModel(String source) throws FileNotFoundException, ClassNotFoundException, IOException {
        if (source.startsWith("classpath:")) {
            URL resource = SVM.class.getClassLoader().getResource(source.substring(10));
            if (null == resource) {
                throw new IllegalArgumentException("The classpath resource " + source + " could not be found.");
            }
            return SVM.readModel(resource);
        }
        return SVM.readModel(new File(source));
    }

    public static SVMModel readModel(File origin) throws ClassNotFoundException, IOException {
        return SVM.readModel(origin.toURI().toURL());
    }

    public static SVMModel readModel(URL origin) throws IOException, ClassNotFoundException {
        try (InputStream modelStream = origin.openStream();){
            SVMModel sVMModel;
            if (modelStream == null) {
                throw new IllegalArgumentException("No model could be found at location " + origin);
            }
            try (ObjectInputStream is = new ObjectInputStream(new GZIPInputStream(modelStream));){
                sVMModel = (SVMModel)is.readObject();
            }
            return sVMModel;
        }
    }

    public static double getBestLabel(double[] predictedValues, SVMModel svmModel) {
        return SVM.getBestLabel(predictedValues, svmModel.svmModel);
    }

    public static double getBestLabel(double[] predictedValues, svm_model svmModel) {
        assert (predictedValues.length > 0);
        int[] labels = svmModel.label;
        int maxValueIndex = IntStream.range(0, predictedValues.length).reduce((a, b) -> predictedValues[a] > predictedValues[b] ? a : b).getAsInt();
        return labels[maxValueIndex];
    }
}

