/*
 * Decompiled with CFR 0.152.
 */
package xcsf.listener;

import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import xcsf.MatchSet;
import xcsf.Population;
import xcsf.StateDescriptor;
import xcsf.XCSFConstants;
import xcsf.XCSFListener;
import xcsf.XCSFUtils;
import xcsf.classifier.Classifier;

/**
 * Listener that plots the fitness-weighted average prediction error of the
 * population over a regular grid on the unit square, using a gnuplot console.
 *
 * <p>Plotting only happens for two-dimensional condition inputs and only on
 * iterations that are a multiple of {@code XCSFConstants.averageExploitTrials}.
 * The grid samples are written to a temporary file which gnuplot then reads
 * via {@code splot}.
 */
public class PredictionErrorPlot
implements XCSFListener {
    /** Grid coordinates used for both axes (0.0 to 1.0 in steps of 0.05). */
    private static final double[] TICS = new double[]{0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0};
    /** Commands sent to gnuplot once, when the listener is constructed. */
    static final String[] GNUPLOT_CMD = new String[]{"set grid", "set xrange[0:1]", "set yrange[0:1]", "set logscale z", "set xlabel 'x'", "set ylabel 'y'", "set zlabel 'p.err.'", "set style data lines", "set contour", "set surface", "set hidden3d", "set dgrid3d " + TICS.length + "," + TICS.length};
    /** Scratch match set, re-matched for every grid point. */
    private final MatchSet ms = new MatchSet(false);
    /** Last computed samples as {@code [xIndex][yIndex] -> {x, y, error}}. */
    private final double[][][] samples = new double[TICS.length][TICS.length][3];
    private final XCSFUtils.Gnuplot console = new XCSFUtils.Gnuplot();
    /** Absolute path of the temporary data file handed to gnuplot. */
    private final String tmpFilename;

    /**
     * Sends the static gnuplot setup commands and creates the temporary data
     * file (registered for deletion on JVM exit).
     *
     * @throws IOException if the temporary file cannot be created
     */
    public PredictionErrorPlot() throws IOException {
        for (String cmd : GNUPLOT_CMD) {
            this.console.execute(cmd);
        }
        File tmpfile = File.createTempFile(this.getClass().getSimpleName(), "gnuplot");
        tmpfile.deleteOnExit();
        this.tmpFilename = tmpfile.getAbsolutePath();
    }

    /** No-op: this listener keeps no per-experiment state. */
    public void nextExperiment(int experiment, String functionName) {
    }

    /**
     * Re-samples the population and refreshes the plot.
     *
     * <p>Does nothing unless {@code iteration} is a multiple of
     * {@code XCSFConstants.averageExploitTrials} and the condition input of
     * {@code state} has exactly two dimensions. If the sample file could not
     * be written, the plot command is skipped so gnuplot is not pointed at a
     * missing or truncated file.
     */
    public void stateChanged(int iteration, Population population, MatchSet matchSet, StateDescriptor state, double[][] performance) {
        if (iteration % XCSFConstants.averageExploitTrials != 0 || state.getConditionInput().length != 2) {
            return;
        }
        if (!this.createIsoSamples(population)) {
            return;
        }
        this.console.execute("set title 'iteration " + iteration + "'");
        this.console.execute("splot '" + this.tmpFilename + "' using 1:2:3 title 'avg prediction error'");
    }

    /**
     * Evaluates the average prediction error at every grid point, stores the
     * result in {@link #samples} and writes one {@code "x y error"} line per
     * point to the temporary file.
     *
     * @return {@code true} if the file was opened and written without the
     *         stream reporting an error, {@code false} otherwise
     */
    private boolean createIsoSamples(Population population) {
        int n = TICS.length;
        double[] output = new double[1];
        // try-with-resources: the stream is closed even when matching throws.
        try (PrintStream ps = new PrintStream(this.tmpFilename)) {
            for (int x = 0; x < n; ++x) {
                for (int y = 0; y < n; ++y) {
                    double[] input = new double[]{TICS[x], TICS[y]};
                    this.matchWithFallback(new StateDescriptor(input, output), population);
                    double[] sample = this.samples[x][y];
                    sample[0] = input[0];
                    sample[1] = input[1];
                    sample[2] = this.getAvgPredictionError();
                    ps.println(sample[0] + " " + sample[1] + " " + sample[2]);
                }
            }
            // PrintStream swallows write errors; checkError() flushes and reports them.
            return !ps.checkError();
        }
        catch (IOException e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Matches {@code state} against {@code population}; when nothing matches,
     * retries once with the closest-matching mode of the match set enabled.
     * The mode is switched off again even if the retry throws, so the shared
     * match set is never left in fallback mode.
     */
    private void matchWithFallback(StateDescriptor state, Population population) {
        this.ms.match(state, population);
        if (this.ms.size() != 0) {
            return;
        }
        this.ms.setNumClosestMatching(true);
        try {
            this.ms.match(state, population);
        }
        finally {
            this.ms.setNumClosestMatching(false);
        }
    }

    /**
     * Returns the fitness-weighted mean prediction error of the classifiers
     * currently in the match set.
     *
     * <p>If the fitness values sum to zero (for example, an empty match set)
     * the division yields {@code NaN}, which is passed on to the data file.
     */
    private double getAvgPredictionError() {
        double weightedErrorSum = 0.0;
        double fitnessSum = 0.0;
        for (Classifier cl : this.ms) {
            double fitness = cl.getFitness();
            fitnessSum += fitness;
            weightedErrorSum += cl.getPredictionError() * fitness;
        }
        return weightedErrorSum / fitnessSum;
    }

    /**
     * Closes the gnuplot console when this listener is garbage collected.
     * The superclass finalizer runs even if closing the console fails.
     */
    @Override
    protected void finalize() throws Throwable {
        try {
            this.console.close();
        }
        finally {
            super.finalize();
        }
    }
}

