/*
 * Decompiled with CFR 0.152.
 */
package ciir.umass.edu.learning;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.KeyValuePair;
import ciir.umass.edu.utilities.MergeSorter;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;
import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Linear ranker trained by coordinate ascent: each feature weight is optimized in turn (in a shuffled
 * order) against {@code scorer} on the training samples, the whole procedure is repeated for
 * {@link #nRestart} restarts, and the best weight vector found is kept (selected on validation data when
 * {@code validationSamples} is set).
 *
 * <p>While training, document scores are cached on each data point ({@code setCached}/{@code getCached})
 * and updated incrementally when a single weight changes, so each probe only re-reads one feature column.
 * The tuning parameters below are mutable statics shared by every instance.
 */
public class CoorAscent
extends Ranker {
    /** Number of restarts performed by {@link #learn()}; each restart starts again from uniform weights. */
    public static int nRestart = 5;
    /** Maximum number of steps probed in each search direction for a single feature. */
    public static int nMaxIteration = 25;
    /** When the default 0.001 step exceeds half the current weight's magnitude, the first step is {@code stepBase * |weight|}. */
    public static double stepBase = 0.05;
    /** Factor by which the step grows after every probe in a direction. */
    public static double stepScale = 2.0;
    /** A restart stops once a sweep ends with less than this gain over the restart's starting score. */
    public static double tolerance = 0.001;
    /** When true, {@code slack * distance(weight, initial weight)} is subtracted from the training score. */
    public static boolean regularized = false;
    /** Regularization coefficient; only used when {@link #regularized} is true. */
    public static double slack = 0.001;
    /** Model weights; {@code weight[i]} belongs to feature id {@code features[i]}. */
    public double[] weight = null;
    /**
     * -1 makes {@link #rank(RankList)} compute full scores; otherwise the index (into {@code features}) of the
     * feature whose weight is being changed, in which case cached scores are updated incrementally.
     */
    protected int current_feature = -1;
    /** Change applied to {@code weight[current_feature]} since the cached scores were last updated. */
    protected double weight_change = -1.0;

    public CoorAscent() {
    }

    public CoorAscent(List<RankList> samples, int[] features, MetricScorer scorer) {
        super(samples, features, scorer);
    }

    /** Allocates {@link #weight} and sets every feature to the uniform weight {@code 1 / features.length}. */
    @Override
    public void init() {
        this.PRINT("Initializing with " + this.features.length + " features... ");
        this.weight = new double[this.features.length];
        Arrays.fill(this.weight, 1.0 / (double)this.features.length);
        this.PRINTLN("[Done]");
    }

    /**
     * Fits {@link #weight} by coordinate ascent on {@code samples}.
     *
     * <p>Each restart resets the weights to uniform and sweeps over the features in a random order. For every
     * feature it probes, in turn, increasing the weight, decreasing it, and setting it to zero; the first
     * direction that raises the (optionally regularized) training score is kept and the weight vector is then
     * L1-normalized. A restart ends when {@code weight.length - 1} features in a row fail to improve (a
     * single-feature model stops after its first failure), or when a sweep ends with less than
     * {@link #tolerance} gain over the restart's starting score. The restart with the best score (validation
     * score when {@code validationSamples} is set) becomes the final model. Expects {@link #weight} to have
     * been allocated by {@link #init()}.
     */
    @Override
    public void learn() {
        // Reference point for the regularization penalty: the weights as set before training.
        double[] regVector = new double[this.weight.length];
        this.copy(this.weight, regVector);
        double[] bestModel = null;
        double bestModelScore = 0.0;
        // Search directions tried per feature: increase, decrease, then drop the feature (weight 0).
        int[] sign = new int[]{1, -1, 0};
        this.PRINTLN("---------------------------");
        this.PRINTLN("Training starts...");
        this.PRINTLN("---------------------------");
        for (int r = 0; r < nRestart; ++r) {
            this.PRINTLN("[+] Random restart #" + (r + 1) + "/" + nRestart + "...");
            int consecutive_fails = 0;
            // Same uniform starting point as init(); only the feature order differs between restarts.
            Arrays.fill(this.weight, 1.0 / (double)this.features.length);
            // current_feature == -1 makes rank() recompute every score from scratch and refill the cache.
            this.current_feature = -1;
            double startScore = this.scorer.score(this.rank(this.samples));
            double bestScore = startScore;
            double[] bestWeight = new double[this.weight.length];
            this.copy(this.weight, bestWeight);
            while (this.weight.length > 1 && consecutive_fails < this.weight.length - 1 || this.weight.length == 1 && consecutive_fails == 0) {
                this.PRINTLN("Shuffling features' order... [Done.]");
                this.PRINTLN("Optimizing weight vector... ");
                this.PRINTLN("------------------------------");
                this.PRINTLN(new int[]{7, 8, 7}, new String[]{"Feature", "weight", this.scorer.name()});
                this.PRINTLN("------------------------------");
                int[] fids = this.getShuffledFeatures();
                for (int i = 0; i < fids.length; ++i) {
                    int fid = fids[i];
                    // From here on rank() updates cached scores by weight_change * value of this feature.
                    this.current_feature = fid;
                    double origWeight = this.weight[fid];
                    double totalStep = 0.0;
                    double bestTotalStep = 0.0;
                    boolean succeeds = false;
                    for (int s = 0; s < sign.length; ++s) {
                        int dir = sign[s];
                        double step = 0.001 * (double)dir;
                        if (origWeight != 0.0 && Math.abs(step) > 0.5 * Math.abs(origWeight)) {
                            // Small weight: make the first step relative to it, keeping the search direction.
                            step = stepBase * Math.abs(origWeight) * (double)dir;
                        }
                        totalStep = step;
                        int numIter = nMaxIteration;
                        if (dir == 0) {
                            // Single probe with the feature switched off.
                            numIter = 1;
                            totalStep = -origWeight;
                        }
                        for (int j = 0; j < numIter; ++j) {
                            double w = origWeight + totalStep;
                            // rank() adds weight_change to the cache, so it must be the change since the previous
                            // probe: the whole first move when j == 0 (for dir == 0 that is -origWeight, whereas
                            // step is 0), and the most recent step afterwards.
                            this.weight_change = j == 0 ? totalStep : step;
                            this.weight[fid] = w;
                            double score = this.scorer.score(this.rank(this.samples));
                            if (regularized) {
                                double penalty = slack * this.getDistance(this.weight, regVector);
                                score -= penalty;
                            }
                            if (score > bestScore) {
                                bestScore = score;
                                bestTotalStep = totalStep;
                                succeeds = true;
                                String bw = (this.weight[fid] > 0.0 ? "+" : "") + SimpleMath.round(this.weight[fid], 4);
                                this.PRINTLN(new int[]{7, 8, 7}, new String[]{this.features[fid] + "", bw + "", SimpleMath.round(bestScore, 4) + ""});
                            }
                            if (j < nMaxIteration - 1) {
                                step *= stepScale;
                                totalStep += step;
                            }
                        }
                        if (succeeds) {
                            // No need to try the remaining directions.
                            break;
                        }
                        if (s < sign.length - 1) {
                            // Undo this direction (cache and weight) before trying the next one.
                            this.weight_change = -totalStep;
                            this.updateCached();
                            this.weight[fid] = origWeight;
                        }
                    }
                    if (succeeds) {
                        // Move the cache from the last probed weight to the best one found.
                        this.weight_change = bestTotalStep - totalStep;
                        this.updateCached();
                        this.weight[fid] = origWeight + bestTotalStep;
                        consecutive_fails = 0;
                        double sum = this.normalize(this.weight);
                        this.scaleCached(sum);
                        this.copy(this.weight, bestWeight);
                    } else {
                        ++consecutive_fails;
                        // Restore cache and weight to their state before this feature was probed.
                        this.weight_change = -totalStep;
                        this.updateCached();
                        this.weight[fid] = origWeight;
                    }
                }
                this.PRINTLN("------------------------------");
                // startScore is the score at the start of this restart; it is not updated between sweeps.
                if (bestScore - startScore < tolerance) {
                    break;
                }
            }
            if (this.validationSamples != null) {
                this.current_feature = -1;
                bestScore = this.scorer.score(this.rank(this.validationSamples));
            }
            if (bestModel == null || bestScore > bestModelScore) {
                bestModelScore = bestScore;
                bestModel = bestWeight;
            }
        }
        // With no restarts there is no trained model; keep the current weights in that case.
        if (bestModel != null) {
            this.copy(bestModel, this.weight);
        }
        this.current_feature = -1;
        this.scoreOnTrainingData = SimpleMath.round(this.scorer.score(this.rank(this.samples)), 4);
        this.PRINTLN("---------------------------------");
        this.PRINTLN("Finished sucessfully.");
        this.PRINTLN(this.scorer.name() + " on training data: " + this.scoreOnTrainingData);
        if (this.validationSamples != null) {
            this.bestScoreOnValidationData = this.scorer.score(this.rank(this.validationSamples));
            this.PRINTLN(this.scorer.name() + " on validation data: " + SimpleMath.round(this.bestScoreOnValidationData, 4));
        }
        this.PRINTLN("---------------------------------");
    }

    /**
     * Scores every document of {@code rl} and returns it sorted by descending score.
     *
     * <p>When {@link #current_feature} is -1 the full linear score is computed; otherwise the cached score is
     * shifted by {@link #weight_change} times the current feature's value. In both cases the new score is
     * written back to the document's cache.
     */
    @Override
    public RankList rank(RankList rl) {
        double[] score = new double[rl.size()];
        if (this.current_feature == -1) {
            for (int i = 0; i < rl.size(); ++i) {
                DataPoint p = rl.get(i);
                double s = 0.0;
                for (int j = 0; j < this.features.length; ++j) {
                    s += this.weight[j] * (double)p.getFeatureValue(this.features[j]);
                }
                score[i] = s;
                p.setCached(s);
            }
        } else {
            int fid = this.features[this.current_feature];
            for (int i = 0; i < rl.size(); ++i) {
                DataPoint p = rl.get(i);
                score[i] = p.getCached() + this.weight_change * (double)p.getFeatureValue(fid);
                p.setCached(score[i]);
            }
        }
        int[] idx = MergeSorter.sort(score, false);
        return new RankList(rl, idx);
    }

    /** Returns the linear score {@code sum_i weight[i] * value(features[i])} of {@code p}. */
    @Override
    public double eval(DataPoint p) {
        double score = 0.0;
        for (int i = 0; i < this.features.length; ++i) {
            score += this.weight[i] * (double)p.getFeatureValue(this.features[i]);
        }
        return score;
    }

    @Override
    public Ranker createNew() {
        return new CoorAscent();
    }

    /** Returns the weights as space-separated {@code featureId:weight} pairs. */
    @Override
    public String toString() {
        StringBuilder output = new StringBuilder();
        for (int i = 0; i < this.weight.length; ++i) {
            if (i > 0) {
                output.append(' ');
            }
            output.append(this.features[i]).append(':').append(this.weight[i]);
        }
        return output.toString();
    }

    /** Returns the serialized model: {@code ##}-prefixed parameter lines followed by {@link #toString()}. */
    @Override
    public String model() {
        StringBuilder output = new StringBuilder();
        output.append("## ").append(this.name()).append("\n");
        output.append("## Restart = ").append(nRestart).append("\n");
        output.append("## MaxIteration = ").append(nMaxIteration).append("\n");
        output.append("## StepBase = ").append(stepBase).append("\n");
        output.append("## StepScale = ").append(stepScale).append("\n");
        output.append("## Tolerance = ").append(tolerance).append("\n");
        output.append("## Regularized = ").append(regularized).append("\n");
        output.append("## Slack = ").append(slack).append("\n");
        output.append(this.toString());
        return output.toString();
    }

    /**
     * Loads features and weights from text produced by {@link #model()}: the first non-empty line that does not
     * start with {@code ##} is parsed as {@code featureId:weight} pairs.
     *
     * @throws RuntimeException (created by {@code RankLibError.create}) if no such line exists or it cannot be parsed
     */
    @Override
    public void loadFromString(String fullText) {
        try (BufferedReader in = new BufferedReader(new StringReader(fullText))) {
            KeyValuePair kvp = null;
            String content;
            while ((content = in.readLine()) != null) {
                content = content.trim();
                if (content.length() == 0 || content.startsWith("##")) {
                    continue;
                }
                kvp = new KeyValuePair(content);
                break;
            }
            if (kvp == null) {
                throw new IllegalArgumentException("no feature weights found in model text");
            }
            List<String> keys = kvp.keys();
            List<String> values = kvp.values();
            this.weight = new double[keys.size()];
            this.features = new int[keys.size()];
            for (int i = 0; i < keys.size(); ++i) {
                this.features[i] = Integer.parseInt(keys.get(i));
                this.weight[i] = Double.parseDouble(values.get(i));
            }
        }
        catch (Exception ex) {
            throw RankLibError.create("Error in CoorAscent::load(): ", ex);
        }
    }

    @Override
    public void printParameters() {
        this.PRINTLN("No. of random restarts: " + nRestart);
        this.PRINTLN("No. of iterations to search in each direction: " + nMaxIteration);
        this.PRINTLN("Tolerance: " + tolerance);
        if (regularized) {
            this.PRINTLN("Reg. param: " + slack);
        } else {
            this.PRINTLN("Regularization: No");
        }
    }

    @Override
    public String name() {
        return "Coordinate Ascent";
    }

    /** Adds {@code weight_change * value(current feature)} to the cached score of every training document. */
    private void updateCached() {
        int fid = this.features[this.current_feature];
        for (int j = 0; j < this.samples.size(); ++j) {
            RankList rl = (RankList)this.samples.get(j);
            for (int i = 0; i < rl.size(); ++i) {
                DataPoint p = rl.get(i);
                p.setCached(p.getCached() + this.weight_change * (double)p.getFeatureValue(fid));
            }
        }
    }

    /** Divides the cached score of every training document by {@code sum} (mirrors {@link #normalize}). */
    private void scaleCached(double sum) {
        for (int j = 0; j < this.samples.size(); ++j) {
            RankList rl = (RankList)this.samples.get(j);
            for (int i = 0; i < rl.size(); ++i) {
                rl.get(i).setCached(rl.get(i).getCached() / sum);
            }
        }
    }

    /** Returns the indices {@code 0..features.length-1} in a random order. */
    private int[] getShuffledFeatures() {
        ArrayList<Integer> order = new ArrayList<Integer>(this.features.length);
        for (int i = 0; i < this.features.length; ++i) {
            order.add(i);
        }
        Collections.shuffle(order);
        int[] fids = new int[order.size()];
        for (int i = 0; i < fids.length; ++i) {
            fids[i] = order.get(i);
        }
        return fids;
    }

    /**
     * Euclidean distance between {@code w1} and {@code w2} after each is divided by its own L1 norm.
     * The result is NaN if either vector is all zeros.
     */
    private double getDistance(double[] w1, double[] w2) {
        assert (w1.length == w2.length);
        double s1 = 0.0;
        double s2 = 0.0;
        for (int i = 0; i < w1.length; ++i) {
            s1 += Math.abs(w1[i]);
            s2 += Math.abs(w2[i]);
        }
        double dist = 0.0;
        for (int i = 0; i < w1.length; ++i) {
            double t = w1[i] / s1 - w2[i] / s2;
            dist += t * t;
        }
        return Math.sqrt(dist);
    }

    /**
     * Divides {@code weights} in place by their L1 norm and returns that norm. If every weight is zero the
     * vector is reset to uniform {@code 1 / length} and 1.0 is returned; callers that keep cached scores must
     * account for that reset themselves.
     */
    private double normalize(double[] weights) {
        double sum = 0.0;
        for (double w : weights) {
            sum += Math.abs(w);
        }
        if (sum > 0.0) {
            for (int j = 0; j < weights.length; ++j) {
                weights[j] /= sum;
            }
        } else {
            sum = 1.0;
            Arrays.fill(weights, 1.0 / (double)weights.length);
        }
        return sum;
    }

    /**
     * Copies {@code ranker}'s weights into this model.
     *
     * @throws IllegalArgumentException if {@code ranker} has a different number of weights than this model has features
     */
    public void copyModel(CoorAscent ranker) {
        if (ranker.weight.length != this.features.length) {
            // Validate before touching our own state; previously this printed and called System.exit(1).
            throw new IllegalArgumentException("These two models use different feature sets!!");
        }
        this.weight = new double[this.features.length];
        this.copy(ranker.weight, this.weight);
        this.PRINTLN("Model loaded.");
    }

    /** Returns the L1-normalized Euclidean distance between this model's weights and {@code ca}'s. */
    public double distance(CoorAscent ca) {
        return this.getDistance(this.weight, ca.weight);
    }
}

