package net.haesleinhuepf.clijx.plugins;

import ij.measure.ResultsTable;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import net.haesleinhuepf.clij.clearcl.ClearCLBuffer;
import net.haesleinhuepf.clij.macro.CLIJMacroPlugin;
import net.haesleinhuepf.clij.macro.CLIJOpenCLProcessor;
import net.haesleinhuepf.clij.macro.documentation.OffersDocumentation;
import net.haesleinhuepf.clij2.AbstractCLIJ2Plugin;
import net.haesleinhuepf.clij2.CLIJ2;
import net.haesleinhuepf.clij2.utilities.HasClassifiedInputOutput;
import net.haesleinhuepf.clij2.utilities.IsCategorized;
import net.haesleinhuepf.clijx.weka.GenerateLabelFeatureImage;
import org.apache.commons.math3.ml.clustering.CentroidCluster;
import org.apache.commons.math3.ml.clustering.DoublePoint;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
import org.scijava.plugin.Plugin;

@Plugin(type = CLIJMacroPlugin.class, name = "CLIJx_kMeansLabelClusterer")
/**
 * Clusters the labels of a label map with K-Means (k-means++ initialisation, Apache Commons Math).
 *
 * <p>A feature image (one column per label, one row per feature) is generated for the input and label map.
 * Each label is assigned to its nearest centroid, and the label map is re-coloured with the resulting
 * class number. Class numbers are 1..k; 0 is reserved for background.
 *
 * <p>Centroids are persisted as a table file ({@code modelfilename}) and cached in static fields.
 * The cache is shared by all calls and is not synchronized.
 */
public class KMeansLabelClusterer extends AbstractCLIJ2Plugin implements CLIJMacroPlugin, CLIJOpenCLProcessor, OffersDocumentation, IsCategorized, HasClassifiedInputOutput {
    /** Model file whose centroids are held in {@link #centroids}; "" means the cache is invalid. */
    private static String last_loaded_filename = "";
    /** Centroids of the most recently loaded or trained model. */
    private static List<CentroidCluster<DoublePoint>> centroids = null;

    public String getParameterHelpText() {
        return "Image input, Image label_map, ByRef Image destination, String features, String modelfilename, Number number_of_classes, Number neighbor_radius, Boolean train";
    }

    public Object[] getDefaultValues() {
        return new Object[]{null, null, null, GenerateLabelFeatureImage.defaultFeatures(), "kmeans_clusterer.model.csv", 2, 0, true};
    }

    /** Macro entry point: unpacks {@code args} in the order given by {@link #getParameterHelpText()}. */
    public boolean executeCL() {
        return kMeansLabelClusterer(getCLIJ2(), (ClearCLBuffer) this.args[0], (ClearCLBuffer) this.args[1], (ClearCLBuffer) this.args[2], (String) this.args[3], (String) this.args[4], Integer.valueOf(asInteger(this.args[5]).intValue()), Integer.valueOf(asInteger(this.args[6]).intValue()), Boolean.valueOf(asBoolean(this.args[7]).booleanValue()));
    }

    /**
     * Clusters the labels of {@code labelMap} and writes the class of every label into {@code destination}.
     *
     * <p>A new model is trained, and saved to {@code modelFilename}, when training is requested or the model
     * file does not exist. It is also retrained when the stored model has a different number of classes or
     * features than requested. If the model file exists but yields no centroids, a model is trained for this
     * call but not written back.
     *
     * @param clij2           CLIJ2 instance used for all GPU operations
     * @param input           intensity image the features are measured on
     * @param labelMap        label map; NOTE(review): assumes labels are numbered consecutively 1..n
     * @param destination     output image receiving the class number (1..k, background 0) per label
     * @param features        feature list handed to {@link GenerateLabelFeatureImage}
     * @param modelFilename   file the centroids are read from and saved to
     * @param numberOfClasses k, the number of clusters
     * @param neighborRadius  if &gt; 0, replace each label's class by the mode among neighbors up to this many
     *                        touch-steps away
     * @param train           force (re)training and saving the model; {@code null} is treated as false
     * @return always {@code true}
     */
    public static boolean kMeansLabelClusterer(CLIJ2 clij2, ClearCLBuffer input, ClearCLBuffer labelMap, ClearCLBuffer destination, String features, String modelFilename, Integer numberOfClasses, Integer neighborRadius, Boolean train) {
        boolean doTrain = train != null && train.booleanValue();
        if (!new File(modelFilename).exists()) {
            clij2.set(destination, 0.0d);
            System.out.println("Model " + modelFilename + " not found. Will train new KMeansLabelClusterer.");
            doTrain = true;
        }

        // Measure the label features and move them into a table: one row per label, one column per feature.
        ClearCLBuffer featureImage = GenerateLabelFeatureImage.generateLabelFeatureImage(clij2, input, labelMap, features);
        ResultsTable table = new ResultsTable();
        long numberOfFeatures;
        try {
            clij2.pullToResultsTable(featureImage, table);
            // Read the size before the buffer is released.
            numberOfFeatures = featureImage.getHeight();
        } finally {
            featureImage.close();
        }

        if (centroids == null || !last_loaded_filename.equals(modelFilename)) {
            centroids = centroidsFromDisc(modelFilename);
            last_loaded_filename = modelFilename;
            System.out.println("Load model");
        }
        // Validate against the model that belongs to modelFilename, not a stale cache from another file.
        if (!centroids.isEmpty()) {
            if (numberOfClasses.intValue() != centroids.size()) {
                System.out.println("Number of classes doesn't match to trained model. Will train new KMeansLabelClusterer.");
                doTrain = true;
            }
            if (numberOfFeatures != centroids.get(0).getCenter().getPoint().length) {
                System.out.println("Number of features doesn't match. Will train new KMeansLabelClusterer.");
                doTrain = true;
            }
        }
        if (centroids.isEmpty() || doTrain) {
            System.out.println("Train model");
            centroids = trainKMeansClustering(table, numberOfClasses.intValue());
        }

        predictKMeansClustering(table, centroids, "CLASS");
        ClearCLBuffer classVector = clij2.create(table.size(), 1L, 1L);
        // Lookup table indexed by label: entry 0 stays 0 (background), entry i holds the class of label i.
        ClearCLBuffer classLookup = clij2.create(table.size() + 1, 1L, 1L);
        try {
            clij2.pushResultsTableColumn(classVector, table, "CLASS");
            clij2.set(classLookup, 0.0d);
            clij2.paste(classVector, classLookup, 1.0d, 0.0d, 0.0d);
            if (neighborRadius.intValue() > 0) {
                applyNeighborhoodMode(clij2, labelMap, classLookup, neighborRadius.intValue());
            }
            clij2.replaceIntensities(labelMap, classLookup, destination);
        } finally {
            classVector.close();
            classLookup.close();
        }

        if (doTrain) {
            centroidsToDisc(centroids, modelFilename);
            System.out.println("Saved model to " + modelFilename);
        }
        return true;
    }

    /**
     * Replaces every label's class in {@code classLookup} (in place) by the mode of its neighbors' classes.
     * The neighborhood reaches {@code neighborRadius} touch-steps away; the background entry is kept at 0.
     */
    private static void applyNeighborhoodMode(CLIJ2 clij2, ClearCLBuffer labelMap, ClearCLBuffer classLookup, int neighborRadius) {
        int maximumLabel = (int) clij2.maximumOfAllPixels(labelMap);
        ClearCLBuffer touchMatrix = clij2.create(maximumLabel + 1, maximumLabel + 1);
        ClearCLBuffer neighborMatrix = clij2.create(touchMatrix);
        ClearCLBuffer smoothedLookup = clij2.create(classLookup);
        try {
            clij2.generateTouchMatrix(labelMap, touchMatrix);
            // Ignore touches with the background (label 0).
            clij2.setColumn(touchMatrix, 0.0d, 0.0d);
            clij2.copy(touchMatrix, neighborMatrix);
            // Each iteration widens the neighborhood by one touch-step.
            for (int step = 1; step < neighborRadius; step++) {
                clij2.neighborsOfNeighbors(touchMatrix, neighborMatrix);
                clij2.copy(neighborMatrix, touchMatrix);
            }
            ModeOfTouchingNeighbors.modeOfTouchingNeighbors(clij2, classLookup, touchMatrix, smoothedLookup);
            clij2.copy(smoothedLookup, classLookup);
            clij2.setColumn(classLookup, 0.0d, 0.0d);
        } finally {
            smoothedLookup.close();
            neighborMatrix.close();
            touchMatrix.close();
        }
    }

    /**
     * Saves centroids as a table with one row per centroid and one column per feature.
     *
     * @param list centroids to save, in class order (row r holds class r + 1)
     * @param str  target file path
     */
    public static void centroidsToDisc(List<CentroidCluster<DoublePoint>> list, String str) {
        ResultsTable table = new ResultsTable();
        int row = 0;
        for (CentroidCluster<DoublePoint> cluster : list) {
            double[] center = cluster.getCenter().getPoint();
            for (int column = 0; column < center.length; column++) {
                table.setValue(column, row, center[column]);
            }
            row++;
        }
        table.save(str);
    }

    /**
     * Loads centroids written by {@link #centroidsToDisc}.
     *
     * @param str model file path
     * @return the centroids, or an empty list if the file cannot be opened (the caller then trains a new model)
     */
    public static List<CentroidCluster<DoublePoint>> centroidsFromDisc(String str) {
        List<CentroidCluster<DoublePoint>> loaded = new ArrayList<CentroidCluster<DoublePoint>>();
        ResultsTable table;
        try {
            table = ResultsTable.open(str);
        } catch (IOException e) {
            // Deliberately not fatal: a missing or unreadable model just means "train a new one".
            return loaded;
        }
        if (table == null) {
            return loaded;
        }
        for (DoublePoint center : tableToList(table)) {
            loaded.add(new CentroidCluster<DoublePoint>(center));
        }
        return loaded;
    }

    /** Runs k-means++ on the table rows and returns k centroid clusters. */
    private static List<CentroidCluster<DoublePoint>> trainKMeansClustering(ResultsTable resultsTable, int numberOfClasses) {
        return new KMeansPlusPlusClusterer<DoublePoint>(numberOfClasses).cluster(tableToList(resultsTable));
    }

    /** Predicts the class of every table row and stores it in the column named {@code str}. */
    private static void predictKMeansClustering(ResultsTable resultsTable, List<CentroidCluster<DoublePoint>> list, String str) {
        double[] classes = predictKMeansClustering(resultsTable, list);
        for (int row = 0; row < classes.length; row++) {
            resultsTable.setValue(str, row, classes[row]);
        }
    }

    /**
     * Assigns each table row to its nearest centroid (Euclidean distance).
     *
     * @return per-row class number, 1-based (centroid index + 1); 0 if no centroid was closer than
     *         {@code Double.MAX_VALUE}, e.g. for an empty centroid list or NaN features
     */
    private static double[] predictKMeansClustering(ResultsTable resultsTable, List<CentroidCluster<DoublePoint>> list) {
        List<DoublePoint> points = tableToList(resultsTable);
        double[] classes = new double[points.size()];
        EuclideanDistance distance = new EuclideanDistance();
        for (int p = 0; p < points.size(); p++) {
            double[] features = points.get(p).getPoint();
            int bestClass = 0;
            double bestDistance = Double.MAX_VALUE;
            for (int c = 0; c < list.size(); c++) {
                double d = distance.compute(list.get(c).getCenter().getPoint(), features);
                if (d < bestDistance) {
                    bestDistance = d;
                    bestClass = c + 1;
                }
            }
            classes[p] = bestClass;
        }
        return classes;
    }

    /**
     * Converts every table row into a point made of the row's values, taken from all existing columns in
     * column order.
     *
     * <p>Values are read numerically. Formatting rows as strings and re-parsing them would round the values
     * to the table's display precision and could pick up non-numeric row labels.
     */
    public static List<DoublePoint> tableToList(ResultsTable resultsTable) {
        int lastColumn = resultsTable.getLastColumn();
        int numberOfColumns = 0;
        for (int column = 0; column <= lastColumn; column++) {
            if (resultsTable.columnExists(column)) {
                numberOfColumns++;
            }
        }
        List<DoublePoint> points = new ArrayList<DoublePoint>(resultsTable.size());
        for (int row = 0; row < resultsTable.size(); row++) {
            double[] values = new double[numberOfColumns];
            int k = 0;
            for (int column = 0; column <= lastColumn; column++) {
                if (resultsTable.columnExists(column)) {
                    values[k++] = resultsTable.getValueAsDouble(column, row);
                }
            }
            points.add(new DoublePoint(values));
        }
        return points;
    }

    public String getDescription() {
        return "Applies K-Means clustering to an image and a corresponding label map. \n\nSee also: https://commons.apache.org/proper/commons-math/javadocs/api-3.6/org/apache/commons/math3/ml/clustering/KMeansPlusPlusClusterer.html\nMake sure that the handed over feature list is the same used while training the model.\nThe neighbor_radius specifies a correction step which allows to use a region where the mode of \nclassification results (the most popular class) will be determined after clustering.";
    }

    public String getAvailableForDimensions() {
        return "2D, 3D";
    }

    public String getCategories() {
        return "Label,Segmentation";
    }

    /** Forces the next call to reload the model from disk. */
    public static void invalidateCache() {
        last_loaded_filename = "";
    }

    public String getInputType() {
        return "Label Image";
    }

    public String getOutputType() {
        return "Label Image";
    }
}
