/*
 * Decompiled with CFR 0.152.
 */
package jasima.core.experiment;

import jasima.core.experiment.Experiment;
import jasima.core.experiment.FullFactorialExperiment;
import jasima.core.experiment.MultipleReplicationExperiment;
import jasima.core.statistics.SummaryStat;
import jasima.core.util.Util;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Map;
import org.apache.commons.math3.distribution.NormalDistribution;

/**
 * Selects the best configuration of a full-factorial design using Optimal
 * Computing Budget Allocation (OCBA).
 * <p>
 * Every configuration created by the full-factorial design is wrapped in a
 * {@link MultipleReplicationExperiment}. After an initial round of
 * replications for every configuration, additional replications are handed out
 * iteratively by {@link #ocba(int)}. This continues until the total budget
 * ({@code numConfigurations * numReplicationsPerConfiguration}, if positive) is
 * used up, or until the estimated probability of correct selection (PCS)
 * exceeds {@link #getPcsLevel()} (if that level is positive).
 * <p>
 * The objective value is read from each run's results under the key
 * {@link #getObjective()}, falling back to {@code objective + ".mean"}. It may
 * be a {@link Number} or a {@link SummaryStat}.
 */
public class OCBAExperiment extends FullFactorialExperiment {
    private static final long serialVersionUID = 621315272493464195L;

    /** Name of the result value to optimize. */
    private String objective;
    /** Whether the objective is minimized or maximized. */
    private ProblemType problemType = ProblemType.MINIMIZE;
    /** Initial replications per configuration; -1 selects max(3, #processors). */
    private int minReplicationsPerConfiguration = -1;
    /** Average replications per configuration; determines the total budget. */
    private int numReplicationsPerConfiguration = 10;
    /** PCS threshold for early termination; 0 disables the PCS stopping rule. */
    private double pcsLevel = 0.0;
    /** Whether per-configuration details are added to the results. */
    private boolean detailedResults = true;

    // ---- run-time state ----
    private int totalBudget;
    private int iterationBudget;
    private int budgetUsed;
    private ArrayList<MultipleReplicationExperiment> configurations;
    private SummaryStat[] stats;
    private double finalPCS;
    private int currBest;

    /** Creates a new OCBA experiment that does not produce averaged results. */
    public OCBAExperiment() {
        this.setProduceAveragedResults(false);
    }

    /**
     * Creates the configurations and the initial replication counts.
     * <p>
     * The work is done only on the first call, when no task has been executed
     * yet. Later calls (one per OCBA iteration) keep the existing configurations
     * and statistics.
     */
    @Override
    protected void createExperiments() {
        if (this.getNumTasksExecuted() != 0) {
            return;
        }
        super.createExperiments();
        this.stats = Util.initializedArray(this.experiments.size(), SummaryStat.class);
        this.configurations = new ArrayList<MultipleReplicationExperiment>();

        // -1 means "automatic": at least 3 runs, or one per available processor.
        final int initialReps = this.getMinReplicationsPerConfiguration() == -1
                ? Math.max(3, Runtime.getRuntime().availableProcessors())
                : this.getMinReplicationsPerConfiguration();

        for (Experiment e : this.experiments) {
            MultipleReplicationExperiment mre = (MultipleReplicationExperiment) e;
            mre.setMaxReplications(initialReps);
            this.configurations.add(mre);
        }

        this.totalBudget = this.configurations.size() * this.getNumReplicationsPerConfiguration();
        // Each OCBA iteration distributes ~10% of the number of configurations,
        // but never less than the initial replications per configuration.
        this.iterationBudget = Math.max(Math.round(0.1f * (float) this.configurations.size()), initialReps);
        this.budgetUsed = 0;
    }

    /**
     * Wraps the experiment created for {@code conf} in a
     * {@link MultipleReplicationExperiment}.
     *
     * @param conf the factor settings of one configuration
     * @return the wrapping experiment, or {@code null} if the superclass created none
     */
    @Override
    protected Experiment createExperimentForConf(Map<String, Object> conf) {
        Experiment e = super.createExperimentForConf(conf);
        if (e == null) {
            return null;
        }
        e.setName(this.getBaseExperiment().getName());

        MultipleReplicationExperiment mre = new MultipleReplicationExperiment();
        mre.setAllowParallelExecution(this.isAllowParallelExecution());
        mre.setBaseExperiment(e);
        this.configureRunExperiment(mre);
        return mre;
    }

    /** Finishes the experiment and records the final PCS estimate. */
    @Override
    protected void done() {
        super.done();
        this.finalPCS = this.calcPCS();
    }

    /**
     * Adds the OCBA outcome to the result map.
     * <p>
     * The keys {@code bestConfiguration}, {@code bestIndex},
     * {@code bestPerformance}, {@code numEvaluations} and {@code pcs} are always
     * added. If {@link #isDetailedResults()} is set, the keys
     * {@code allocationVector}, {@code meansVector}, {@code configurations},
     * {@code probBestBetter} and {@code rank} are added as well.
     */
    @Override
    public void produceResults() {
        super.produceResults();
        this.resultMap.put("bestConfiguration", this.configurations.get(this.currBest).getBaseExperiment());
        this.resultMap.put("bestIndex", this.currBest);
        this.resultMap.put("bestPerformance", this.stats[this.currBest].mean());
        this.resultMap.put("numEvaluations", this.budgetUsed);
        this.resultMap.put("pcs", this.finalPCS);

        if (!this.isDetailedResults()) {
            return;
        }
        final int numConfs = this.configurations.size();
        int[] numRuns = new int[numConfs];
        double[] means = new double[this.stats.length];
        Experiment[] exps = new Experiment[numConfs];
        for (int i = 0; i < numConfs; ++i) {
            exps[i] = this.configurations.get(i).getBaseExperiment();
            SummaryStat vs = this.stats[i];
            numRuns[i] = vs.numObs();
            means[i] = vs.mean();
        }
        this.resultMap.put("allocationVector", numRuns);
        this.resultMap.put("meansVector", means);
        this.resultMap.put("configurations", exps);
        this.resultMap.put("probBestBetter", this.calcPCSPriosPerConfiguration());
        this.resultMap.put("rank", this.findRank(means));
    }

    /**
     * Computes 1-based ranks of {@code means}, where rank 1 is the best value
     * with respect to {@link #getProblemType()}. Ties keep their index order,
     * because {@link Arrays#sort(Object[], Comparator)} is stable.
     */
    private int[] findRank(final double[] means) {
        final int sign = this.getProblemType() == ProblemType.MAXIMIZE ? -1 : 1;

        Integer[] order = new Integer[means.length];
        for (int i = 0; i < order.length; ++i) {
            order[i] = i;
        }
        Arrays.sort(order, new Comparator<Integer>() {
            @Override
            public int compare(Integer i1, Integer i2) {
                return sign * Double.compare(means[i1], means[i2]);
            }
        });

        int[] ranks = new int[order.length];
        for (int pos = 0; pos < order.length; ++pos) {
            ranks[order[pos].intValue()] = pos + 1;
        }
        return ranks;
    }

    /**
     * Updates the current best configuration, checks the stopping rules and,
     * if running continues, schedules the next batch of replications.
     *
     * @return {@code false} once the total budget is exhausted or the PCS level
     *         is exceeded; {@code true} if new replications were scheduled
     */
    @Override
    protected boolean hasMoreTasks() {
        this.updateCurrBest();
        this.experiments.clear();

        if (this.totalBudget > 0 && this.budgetUsed >= this.totalBudget) {
            return false;
        }
        if (this.getPcsLevel() > 0.0 && this.calcPCS() > this.getPcsLevel()) {
            return false;
        }

        int iter = this.iterationBudget;
        if (this.totalBudget > 0) {
            iter = Math.min(iter, this.totalBudget - this.budgetUsed);
        }

        int[] newRuns = this.ocba(iter);
        for (int i = 0; i < newRuns.length; ++i) {
            if (newRuns[i] <= 0) {
                continue;
            }
            MultipleReplicationExperiment mre = this.configurations.get(i);
            // Re-arm the configuration so it can run again with the new count.
            mre.state().set(Experiment.ExperimentState.INITIAL);
            mre.setMaxReplications(newRuns[i]);
            this.experiments.add(mre);
        }
        return true;
    }

    /**
     * Sets {@link #currBest} to the index with the best sample mean.
     * Ties go to the lowest index.
     */
    private void updateCurrBest() {
        final boolean maximize = this.getProblemType() == ProblemType.MAXIMIZE;
        this.currBest = 0;
        double bestValue = maximize ? this.stats[0].mean() : -this.stats[0].mean();
        for (int i = 1; i < this.stats.length; ++i) {
            double value = maximize ? this.stats[i].mean() : -this.stats[i].mean();
            if (value > bestValue) {
                bestValue = value;
                this.currBest = i;
            }
        }
    }

    /**
     * Adds the objective value of a finished configuration run to that
     * configuration's statistics and charges its replications to the budget.
     *
     * @throws IllegalStateException if {@code e} is not one of this experiment's
     *         configurations
     * @throws RuntimeException if the objective value is missing or has an
     *         unsupported type
     */
    @Override
    protected void storeRunResults(Experiment e, Map<String, Object> r) {
        super.storeRunResults(e, r);

        int i = this.configurations.indexOf(e);
        if (i < 0) {
            // Previously only an assert; with assertions disabled this failed later in get(-1).
            throw new IllegalStateException("Results received for an experiment that is not an OCBA configuration.");
        }

        Object o = r.get(this.getObjective());
        if (o == null) {
            o = r.get(this.getObjective() + ".mean");
        }
        if (o == null) {
            throw new RuntimeException("Can't find result value for objective '" + this.getObjective() + "'.");
        }

        this.budgetUsed += this.configurations.get(i).getMaxReplications();

        SummaryStat vs = this.stats[i];
        if (o instanceof Number) {
            vs.value(((Number) o).doubleValue());
        } else if (o instanceof SummaryStat) {
            vs.combine((SummaryStat) o);
        } else {
            throw new RuntimeException("Don't know how to handle result '" + String.valueOf(o) + "'.");
        }
    }

    /**
     * Approximates the probability of correct selection. It is computed as the
     * product, over all non-best configurations, of the probability that the
     * current best is better than that configuration.
     *
     * @return the approximate PCS
     */
    protected double calcPCS() {
        double[] prodTerms = this.calcPCSPriosPerConfiguration();
        double res = 1.0;
        for (int i = 0; i < prodTerms.length; ++i) {
            if (i != this.currBest) {
                res *= prodTerms[i];
            }
        }
        return res;
    }

    /**
     * Computes, for every configuration, the normal-approximation probability
     * that the current best configuration is better than it.
     * <p>
     * The probability is based on the difference of the sample means divided by
     * {@code sqrt(var_best/n_best + var_i/n_i)}. If both the mean difference and
     * this standard error are exactly zero (identical, noise-free results), the
     * comparison is treated as a tie with probability 0.5 instead of producing
     * NaN.
     *
     * @return one probability per configuration; the entry of the current best
     *         configuration is left at 0
     */
    protected double[] calcPCSPriosPerConfiguration() {
        final SummaryStat best = this.stats[this.currBest];
        final double bestMean = best.mean();
        final double bestNormVariance = best.variance() / (double) best.numObs();
        final boolean minimize = this.getProblemType() == ProblemType.MINIMIZE;
        final NormalDistribution normalDist = new NormalDistribution();

        double[] prodTerms = new double[this.stats.length];
        for (int i = 0; i < this.stats.length; ++i) {
            if (i == this.currBest) {
                continue;
            }
            SummaryStat vs = this.stats[i];
            double diff = bestMean - vs.mean();
            double stdErr = Math.sqrt(bestNormVariance + vs.variance() / (double) vs.numObs());
            // 0/0 would yield NaN and poison the PCS product; a noise-free tie carries no evidence.
            double z = (diff == 0.0 && stdErr == 0.0) ? 0.0 : diff / stdErr;
            double p = normalDist.cumulativeProbability(z);
            prodTerms[i] = minimize ? 1.0 - p : p;
        }
        return prodTerms;
    }

    /**
     * Distributes {@code add_budget} additional replications among the
     * configurations using the OCBA allocation rule. The structure appears to
     * follow the reference OCBA procedure of Chen et al.
     * <p>
     * The procedure works on minimization means (means are negated when
     * maximizing). It computes allocation ratios from the gap to the best and
     * second-best means and from the sample variances. It then splits the new
     * total of {@code add_budget + (runs so far)} proportionally. A configuration
     * whose proportional share is below its current number of runs is fixed at
     * that number and the split is repeated. Any rounding remainder goes to the
     * current best configuration.
     *
     * @param add_budget number of additional replications to distribute
     * @return additional replications per configuration (indexed like the
     *         configurations); the entries sum to {@code add_budget}
     */
    protected int[] ocba(int add_budget) {
        final int nd = this.stats.length;
        if (nd == 1) {
            return new int[]{add_budget};
        }

        // Snapshot the statistics; OCBA is formulated for minimization.
        final boolean maximize = this.getProblemType() == ProblemType.MAXIMIZE;
        final double[] mean = new double[nd];
        final double[] variance = new double[nd];
        final int[] numObs = new int[nd];
        int totalRuns = add_budget;
        for (int i = 0; i < nd; ++i) {
            SummaryStat vs = this.stats[i];
            mean[i] = maximize ? -vs.mean() : vs.mean();
            variance[i] = vs.variance();
            numObs[i] = vs.numObs();
            totalRuns += numObs[i];
        }

        final int b = this.currBest;
        final int s = secondBest(mean, b);

        // Allocation ratios relative to the second-best configuration.
        final double[] ratio = new double[nd];
        ratio[s] = 1.0;
        for (int i = 0; i < nd; ++i) {
            if (i == s || i == b) {
                continue;
            }
            double gapRatio = (mean[b] - mean[s]) / (mean[b] - mean[i]);
            ratio[i] = gapRatio * gapRatio * variance[i] / variance[s];
        }
        double weightedSum = 0.0;
        for (int i = 0; i < nd; ++i) {
            if (i != b) {
                weightedSum += ratio[i] * ratio[i] / variance[i];
            }
        }
        ratio[b] = Math.sqrt(variance[b] * weightedSum);

        // Proportional split; configurations whose share is below their existing
        // run count are pinned at that count and the rest is re-split.
        final boolean[] adjustable = new boolean[nd];
        Arrays.fill(adjustable, true);
        final int[] an = new int[nd];
        int distributable = totalRuns;
        boolean pinnedAny;
        do {
            pinnedAny = false;
            double ratioSum = 0.0;
            for (int i = 0; i < nd; ++i) {
                if (adjustable[i]) {
                    ratioSum += ratio[i];
                }
            }
            for (int i = 0; i < nd; ++i) {
                if (!adjustable[i]) {
                    continue;
                }
                an[i] = (int) ((double) distributable / ratioSum * ratio[i]);
                if (an[i] < numObs[i]) {
                    an[i] = numObs[i];
                    adjustable[i] = false;
                    pinnedAny = true;
                }
            }
            if (pinnedAny) {
                distributable = totalRuns;
                for (int i = 0; i < nd; ++i) {
                    if (!adjustable[i]) {
                        distributable -= an[i];
                    }
                }
            }
        } while (pinnedAny);

        // Integer truncation leaves a remainder; give it to the current best.
        int allocated = 0;
        for (int i = 0; i < nd; ++i) {
            allocated += an[i];
        }
        an[b] += totalRuns - allocated;

        // Convert target totals into additional runs.
        for (int i = 0; i < nd; ++i) {
            an[i] -= numObs[i];
        }
        return an;
    }

    /**
     * Returns the index of the smallest value in {@code means}, ignoring index
     * {@code best}. On ties the lowest such index wins. Requires at least two
     * entries.
     */
    private static int secondBest(double[] means, int best) {
        int second = best == 0 ? 1 : 0;
        for (int i = 0; i < means.length; ++i) {
            if (i != best && means[i] < means[second]) {
                second = i;
            }
        }
        return second;
    }

    /**
     * Sets the number of replications every configuration receives in the
     * first round. It also acts as a lower bound for the per-iteration budget.
     *
     * @param minReps a value {@code >= 3}, or {@code -1} to use
     *        {@code max(3, availableProcessors)}
     * @throws IllegalArgumentException if {@code minReps} is neither -1 nor {@code >= 3}
     */
    public void setMinReplicationsPerConfiguration(int minReps) {
        if (minReps != -1 && minReps < 3) {
            throw new IllegalArgumentException("Minimum number of replications has to be >=3 or -1.");
        }
        this.minReplicationsPerConfiguration = minReps;
    }

    /** @return the initial replications per configuration, or -1 for automatic */
    public int getMinReplicationsPerConfiguration() {
        return this.minReplicationsPerConfiguration;
    }

    /**
     * Sets the name of the result value to optimize. Its value is looked up
     * under this name or, as a fallback, under {@code objective + ".mean"}.
     */
    public void setObjective(String objective) {
        this.objective = objective;
    }

    /** @return the name of the result value to optimize */
    public String getObjective() {
        return this.objective;
    }

    /**
     * Sets the PCS level above which the experiment stops early. A value of 0
     * disables this stopping rule.
     *
     * @param pcsLevel a probability in [0, 1]
     * @throws IllegalArgumentException if {@code pcsLevel} is outside [0, 1] or NaN
     */
    public void setPcsLevel(double pcsLevel) {
        // Written so that NaN is rejected too.
        if (!(pcsLevel >= 0.0 && pcsLevel <= 1.0)) {
            throw new IllegalArgumentException("Invalid probability: " + pcsLevel);
        }
        this.pcsLevel = pcsLevel;
    }

    /** @return the PCS stopping level; 0 means disabled */
    public double getPcsLevel() {
        return this.pcsLevel;
    }

    /** Sets whether per-configuration details are included in the results. */
    public void setDetailedResults(boolean detailedResults) {
        this.detailedResults = detailedResults;
    }

    /** @return whether per-configuration details are included in the results */
    public boolean isDetailedResults() {
        return this.detailedResults;
    }

    /**
     * Sets the average number of replications per configuration. The total
     * budget is this value times the number of configurations. A non-positive
     * total budget disables the budget stopping rule.
     */
    public void setNumReplicationsPerConfiguration(int numReplications) {
        this.numReplicationsPerConfiguration = numReplications;
    }

    /** @return the average number of replications per configuration */
    public int getNumReplicationsPerConfiguration() {
        return this.numReplicationsPerConfiguration;
    }

    /** @return whether the objective is minimized or maximized */
    public ProblemType getProblemType() {
        return this.problemType;
    }

    /** Sets whether the objective is minimized or maximized. */
    public void setProblemType(ProblemType problemType) {
        this.problemType = problemType;
    }

    /** Direction of optimization for the objective value. */
    public static enum ProblemType {
        /** Smaller objective values are better. */
        MINIMIZE,
        /** Larger objective values are better. */
        MAXIMIZE;
    }
}

