/*
 * Decompiled with CFR 0.152.
 */
package ciir.umass.edu.learning.tree;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.learning.tree.Ensemble;
import ciir.umass.edu.learning.tree.FeatureHistogram;
import ciir.umass.edu.learning.tree.RegressionTree;
import ciir.umass.edu.learning.tree.Split;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.MergeSorter;
import ciir.umass.edu.utilities.MyThreadPool;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;
import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * LambdaMART ranker: boosted regression trees fitted to pairwise "lambda" pseudo-responses.
 *
 * <p>Each boosting round does the following:
 * <ul>
 *   <li>computes per-sample pseudo-responses and weights from the current model scores and the
 *       scorer's swap changes;</li>
 *   <li>fits a {@link RegressionTree} to the pseudo-responses;</li>
 *   <li>sets each leaf's output to (sum of pseudo-responses) / (sum of weights);</li>
 *   <li>adds the tree to the {@link Ensemble} with shrinkage {@link #learningRate}.</li>
 * </ul>
 *
 * <p>When validation data is present, training stops early and the ensemble is truncated to the
 * best-scoring iteration.
 *
 * <p>Training parameters are mutable static fields shared by all instances.
 */
public class LambdaMART extends Ranker {
    /** Maximum number of boosting rounds (trees). */
    public static int nTrees = 1000;
    /** Shrinkage factor applied to every tree's output. */
    public static float learningRate = 0.1f;
    /**
     * Maximum number of candidate split thresholds per feature. A value of -1 means every distinct
     * feature value is a candidate.
     */
    public static int nThreshold = 256;
    /**
     * Training stops once this many rounds have passed since the best validation score. Only takes
     * effect when validation data is supplied.
     */
    public static int nRoundToStopEarly = 100;
    /** Number of leaves passed to each {@link RegressionTree}. */
    public static int nTreeLeaves = 10;
    /** Minimum leaf support passed to each {@link RegressionTree}. */
    public static int minLeafSupport = 1;
    /** {@code System.gc()} is requested every {@code gcCycle} rounds; 0 disables the request. */
    public static int gcCycle = 100;

    /** Candidate split thresholds per feature; every row ends with {@code Float.MAX_VALUE}. */
    protected float[][] thresholds = null;
    /** The learned tree ensemble. */
    protected Ensemble ensemble = null;
    /** Current model score of every training point, in {@link #martSamples} order. */
    protected double[] modelScores = null;
    /** Current model score of every validation point, indexed [rank list][position]. */
    protected double[][] modelScoresOnValidation = null;
    /**
     * Index of the tree with the best validation score. The large initial value means early stopping
     * and the final truncation never fire when no validation data is supplied.
     */
    protected int bestModelOnValidation = Integer.MAX_VALUE - 2;
    /** All training points from every rank list, flattened in list order. */
    protected DataPoint[] martSamples = null;
    /**
     * Per feature, the sample indices sorted by feature value (presumably ascending; the threshold
     * scan relies on it). Only held during {@link #init()}.
     */
    protected int[][] sortedIdx = null;
    /** Feature histogram built in {@link #init()} and refreshed with the pseudo-responses each round. */
    protected FeatureHistogram hist = null;
    /** Per-sample lambda pseudo-responses for the current round. */
    protected double[] pseudoResponses = null;
    /** Per-sample weights accumulated alongside {@link #pseudoResponses}, used to scale leaf outputs. */
    protected double[] weights = null;

    public LambdaMART() {
    }

    public LambdaMART(List<RankList> samples, int[] features, MetricScorer scorer) {
        super(samples, features, scorer);
    }

    /**
     * Prepares training state. It flattens the samples, sorts them per feature (in parallel when the
     * thread pool has more than one thread) and builds split thresholds and the feature histogram.
     * It also allocates per-list validation score buffers when validation data is present.
     */
    @Override
    public void init() {
        PRINT("Initializing... ");
        flattenSamples();
        sortedIdx = new int[features.length][];
        MyThreadPool p = MyThreadPool.getInstance();
        if (p.size() == 1) {
            sortSamplesByFeature(0, features.length - 1);
        } else {
            int[] partition = p.partition(features.length);
            for (int i = 0; i < partition.length - 1; ++i) {
                p.execute(new SortWorker(this, partition[i], partition[i + 1] - 1));
            }
            p.await();
        }
        thresholds = new float[features.length][];
        for (int f = 0; f < features.length; ++f) {
            thresholds[f] = buildThresholds(f);
        }
        if (validationSamples != null) {
            modelScoresOnValidation = new double[validationSamples.size()][];
            for (int i = 0; i < validationSamples.size(); ++i) {
                // Freshly allocated double arrays are already zero-filled.
                modelScoresOnValidation[i] = new double[((RankList) validationSamples.get(i)).size()];
            }
        }
        hist = new FeatureHistogram();
        hist.construct(martSamples, pseudoResponses, sortedIdx, features, thresholds);
        // The sorted indices are only needed to build the histogram; release them.
        sortedIdx = null;
        System.gc();
        PRINTLN("[Done]");
    }

    /**
     * Copies every training point into {@link #martSamples} in rank-list order. It also allocates the
     * zero-initialized per-sample score, pseudo-response and weight arrays.
     */
    private void flattenSamples() {
        int dpCount = 0;
        for (int i = 0; i < samples.size(); ++i) {
            dpCount += ((RankList) samples.get(i)).size();
        }
        martSamples = new DataPoint[dpCount];
        modelScores = new double[dpCount];
        pseudoResponses = new double[dpCount];
        weights = new double[dpCount];
        int current = 0;
        for (int i = 0; i < samples.size(); ++i) {
            RankList rl = (RankList) samples.get(i);
            for (int j = 0; j < rl.size(); ++j) {
                martSamples[current + j] = rl.get(j);
            }
            current += rl.size();
        }
    }

    /**
     * Builds the split-threshold candidates for feature index {@code f} using {@code sortedIdx[f]}.
     *
     * <p>If the feature has at most {@link #nThreshold} distinct values, or {@code nThreshold == -1},
     * every distinct value is a candidate. Otherwise {@code nThreshold} evenly spaced values starting
     * at the minimum are used. The last entry is always {@code Float.MAX_VALUE}.
     */
    private float[] buildThresholds(int f) {
        int fid = features[f];
        int[] order = sortedIdx[f];
        int n = martSamples.length;
        List<Float> values = new ArrayList<Float>();
        float fmax = Float.NEGATIVE_INFINITY;
        float fmin = Float.MAX_VALUE;
        int i = 0;
        while (i < n) {
            float fv = martSamples[order[i]].getFeatureValue(fid);
            values.add(fv);
            if (fmax < fv) {
                fmax = fv;
            }
            if (fmin > fv) {
                fmin = fv;
            }
            // Skip following samples whose value is not greater than fv (duplicates, given the sort).
            int j = i + 1;
            while (j < n && !(martSamples[order[j]].getFeatureValue(fid) > fv)) {
                ++j;
            }
            i = j;
        }
        if (values.size() <= nThreshold || nThreshold == -1) {
            float[] all = new float[values.size() + 1];
            for (int k = 0; k < values.size(); ++k) {
                all[k] = values.get(k);
            }
            all[values.size()] = Float.MAX_VALUE;
            return all;
        }
        float step = Math.abs(fmax - fmin) / (float) nThreshold;
        float[] spaced = new float[nThreshold + 1];
        spaced[0] = fmin;
        for (int k = 1; k < nThreshold; ++k) {
            spaced[k] = spaced[k - 1] + step;
        }
        spaced[nThreshold] = Float.MAX_VALUE;
        return spaced;
    }

    /**
     * Runs up to {@link #nTrees} boosting rounds and prints per-round training/validation scores.
     * If validation data is present, it stops early and trims the ensemble back to the best
     * validation iteration.
     */
    @Override
    public void learn() {
        ensemble = new Ensemble();
        PRINTLN("---------------------------------");
        PRINTLN("Training starts...");
        PRINTLN("---------------------------------");
        PRINTLN(new int[]{7, 9, 9}, new String[]{"#iter", scorer.name() + "-T", scorer.name() + "-V"});
        PRINTLN("---------------------------------");
        for (int m = 0; m < nTrees; ++m) {
            PRINT(new int[]{7}, new String[]{m + 1 + ""});
            computePseudoResponses();
            hist.update(pseudoResponses);
            RegressionTree rt = new RegressionTree(nTreeLeaves, martSamples, pseudoResponses, hist, minLeafSupport);
            rt.fit();
            ensemble.add(rt, learningRate);
            updateTreeOutput(rt);
            addTreeToTrainingScores(rt);
            rt.clearSamples();
            // Guard against gcCycle == 0, which would otherwise throw ArithmeticException (x % 0).
            if (gcCycle != 0 && m % gcCycle == 0) {
                System.gc();
            }
            scoreOnTrainingData = computeModelScoreOnTraining();
            PRINT(new int[]{9}, new String[]{SimpleMath.round(scoreOnTrainingData, 4) + ""});
            if (validationSamples != null) {
                addTreeToValidationScores(rt);
                double score = computeModelScoreOnValidation();
                PRINT(new int[]{9}, new String[]{SimpleMath.round(score, 4) + ""});
                if (score > bestScoreOnValidationData) {
                    bestScoreOnValidationData = score;
                    bestModelOnValidation = ensemble.treeCount() - 1;
                }
            }
            PRINTLN("");
            if (m - bestModelOnValidation > nRoundToStopEarly) {
                break;
            }
        }
        // Drop trees added after the best validation iteration (no-op without validation data).
        while (ensemble.treeCount() > bestModelOnValidation + 1) {
            ensemble.remove(ensemble.treeCount() - 1);
        }
        scoreOnTrainingData = scorer.score(rank(samples));
        PRINTLN("---------------------------------");
        PRINTLN("Finished sucessfully.");
        PRINTLN(scorer.name() + " on training data: " + SimpleMath.round(scoreOnTrainingData, 4));
        if (validationSamples != null) {
            bestScoreOnValidationData = scorer.score(rank(validationSamples));
            PRINTLN(scorer.name() + " on validation data: " + SimpleMath.round(bestScoreOnValidationData, 4));
        }
        PRINTLN("---------------------------------");
    }

    /**
     * Adds {@code learningRate * leafOutput} to the model score of every training sample in each
     * leaf of {@code rt}.
     */
    private void addTreeToTrainingScores(RegressionTree rt) {
        List<Split> leaves = rt.leaves();
        for (Split leaf : leaves) {
            int[] idx = leaf.getSamples();
            for (int j = 0; j < idx.length; ++j) {
                modelScores[idx[j]] += (double) learningRate * leaf.getOutput();
            }
        }
    }

    /** Adds {@code learningRate * rt.eval(point)} to the running score of every validation point. */
    private void addTreeToValidationScores(RegressionTree rt) {
        for (int i = 0; i < modelScoresOnValidation.length; ++i) {
            RankList rl = (RankList) validationSamples.get(i);
            double[] scores = modelScoresOnValidation[i];
            for (int j = 0; j < scores.length; ++j) {
                scores[j] += (double) learningRate * rt.eval(rl.get(j));
            }
        }
    }

    /** Scores a data point with the learned ensemble. */
    @Override
    public double eval(DataPoint dp) {
        return ensemble.eval(dp);
    }

    @Override
    public Ranker createNew() {
        return new LambdaMART();
    }

    @Override
    public String toString() {
        return ensemble.toString();
    }

    /**
     * Serializes the model. Output is a block of "##"-prefixed parameter header lines, a blank
     * line, then the ensemble text.
     */
    @Override
    public String model() {
        StringBuilder output = new StringBuilder();
        output.append("## ").append(name()).append("\n");
        output.append("## No. of trees = ").append(nTrees).append("\n");
        output.append("## No. of leaves = ").append(nTreeLeaves).append("\n");
        output.append("## No. of threshold candidates = ").append(nThreshold).append("\n");
        output.append("## Learning rate = ").append(learningRate).append("\n");
        output.append("## Stop early = ").append(nRoundToStopEarly).append("\n");
        output.append("\n");
        output.append(toString());
        return output.toString();
    }

    /**
     * Loads a model produced by {@link #model()}. Blank lines and "##" header lines are skipped;
     * the remaining trimmed lines are concatenated and parsed as an {@link Ensemble}.
     * Any failure is wrapped via {@code RankLibError.create}.
     */
    @Override
    public void loadFromString(String fullText) {
        try {
            StringBuilder model = new StringBuilder();
            BufferedReader in = new BufferedReader(new StringReader(fullText));
            String content;
            while ((content = in.readLine()) != null) {
                content = content.trim();
                if (content.length() == 0 || content.startsWith("##")) {
                    continue;
                }
                model.append(content);
            }
            in.close();
            ensemble = new Ensemble(model.toString());
            features = ensemble.getFeatures();
        } catch (Exception ex) {
            throw RankLibError.create("Error in LambdaMART::load(): ", ex);
        }
    }

    @Override
    public void printParameters() {
        PRINTLN("No. of trees: " + nTrees);
        PRINTLN("No. of leaves: " + nTreeLeaves);
        PRINTLN("No. of threshold candidates: " + nThreshold);
        PRINTLN("Min leaf support: " + minLeafSupport);
        PRINTLN("Learning rate: " + learningRate);
        PRINTLN("Stop early: " + nRoundToStopEarly + " rounds without performance gain on validation data");
    }

    @Override
    public String name() {
        return "LambdaMART";
    }

    public Ensemble getEnsemble() {
        return ensemble;
    }

    /**
     * Resets and recomputes {@link #pseudoResponses} and {@link #weights} for all training lists.
     * When the thread pool has more than one thread, the lists are partitioned across workers. Each
     * worker writes only the flattened range belonging to its own lists.
     */
    protected void computePseudoResponses() {
        Arrays.fill(pseudoResponses, 0.0);
        Arrays.fill(weights, 0.0);
        MyThreadPool p = MyThreadPool.getInstance();
        if (p.size() == 1) {
            computePseudoResponses(0, samples.size() - 1, 0);
            return;
        }
        int[] partition = p.partition(samples.size());
        int current = 0;
        for (int i = 0; i < partition.length - 1; ++i) {
            p.execute(new LambdaComputationWorker(this, partition[i], partition[i + 1] - 1, current));
            // Advance the flattened offset past this partition's lists (not needed after the last one).
            if (i < partition.length - 2) {
                for (int j = partition[i]; j <= partition[i + 1] - 1; ++j) {
                    current += ((RankList) samples.get(j)).size();
                }
            }
        }
        p.await();
    }

    /**
     * Accumulates lambda pseudo-responses and weights for rank lists {@code start..end}.
     *
     * <p>For each ordered pair (j, k) of ranked positions where j has the higher label and a non-zero
     * swap change, the following quantities are applied:
     * <ul>
     *   <li>{@code rho = 1 / (1 + exp(s_j - s_k))};</li>
     *   <li>{@code rho * |change|} is added to j's pseudo-response and subtracted from k's;</li>
     *   <li>{@code rho * (1 - rho) * |change|} is added to both weights.</li>
     * </ul>
     *
     * @param current flattened index in {@link #modelScores} of the first point of list {@code start}
     */
    protected void computePseudoResponses(int start, int end, int current) {
        int cutoff = scorer.getK();
        for (int i = start; i <= end; ++i) {
            RankList orig = (RankList) samples.get(i);
            // Absolute modelScores indices of this list's points, sorted by current score (asc = false).
            int[] idx = MergeSorter.sort(modelScores, current, current + orig.size() - 1, false);
            RankList rl = new RankList(orig, idx, current);
            double[][] changes = scorer.swapChange(rl);
            for (int j = 0; j < rl.size(); ++j) {
                DataPoint p1 = rl.get(j);
                int mj = idx[j];
                for (int k = 0; k < rl.size(); ++k) {
                    // Pairs with both positions beyond the scorer's cutoff are not evaluated.
                    if (j > cutoff && k > cutoff) {
                        break;
                    }
                    DataPoint p2 = rl.get(k);
                    int mk = idx[k];
                    if (!(p1.getLabel() > p2.getLabel())) {
                        continue;
                    }
                    double deltaNDCG = Math.abs(changes[j][k]);
                    if (!(deltaNDCG > 0.0)) {
                        continue;
                    }
                    double rho = 1.0 / (1.0 + Math.exp(modelScores[mj] - modelScores[mk]));
                    double lambda = rho * deltaNDCG;
                    pseudoResponses[mj] += lambda;
                    pseudoResponses[mk] -= lambda;
                    double delta = rho * (1.0 - rho) * deltaNDCG;
                    weights[mj] += delta;
                    weights[mk] += delta;
                }
            }
            current += orig.size();
        }
    }

    /**
     * Sets each leaf's output to (sum of its samples' pseudo-responses) / (sum of their weights), or
     * 0 when the weight sum is 0.
     */
    protected void updateTreeOutput(RegressionTree rt) {
        List<Split> leaves = rt.leaves();
        for (Split leaf : leaves) {
            // NOTE(review): sums accumulate in float; switching to double would reduce rounding error
            // but would change trained models, so it is kept as-is.
            float s1 = 0.0f;
            float s2 = 0.0f;
            int[] idx = leaf.getSamples();
            for (int j = 0; j < idx.length; ++j) {
                s1 += pseudoResponses[idx[j]];
                s2 += weights[idx[j]];
            }
            leaf.setOutput(s2 == 0.0f ? 0.0f : s1 / s2);
        }
    }

    /** Returns the indices of {@code samples} sorted by feature {@code fid} (asc = true). */
    protected int[] sortSamplesByFeature(DataPoint[] samples, int fid) {
        double[] score = new double[samples.length];
        for (int i = 0; i < samples.length; ++i) {
            score[i] = samples[i].getFeatureValue(fid);
        }
        return MergeSorter.sort(score, true);
    }

    /**
     * Ranks training list {@code rankListIndex} by the current model scores. The list's scores start
     * at flattened offset {@code current} in {@link #modelScores}.
     */
    protected RankList rank(int rankListIndex, int current) {
        RankList orig = (RankList) samples.get(rankListIndex);
        double[] scores = new double[orig.size()];
        System.arraycopy(modelScores, current, scores, 0, scores.length);
        int[] idx = MergeSorter.sort(scores, false);
        return new RankList(orig, idx);
    }

    /** Mean scorer value over all training lists under the current model scores. */
    protected float computeModelScoreOnTraining() {
        float total = computeModelScoreOnTraining(0, samples.size() - 1, 0);
        return total / (float) samples.size();
    }

    /** Sum of scorer values for training lists {@code start..end}; {@code current} is the flattened offset of {@code start}. */
    protected float computeModelScoreOnTraining(int start, int end, int current) {
        float total = 0.0f;
        int offset = current;
        for (int i = start; i <= end; ++i) {
            total += scorer.score(rank(i, offset));
            offset += ((RankList) samples.get(i)).size();
        }
        return total;
    }

    /** Mean scorer value over all validation lists under the current validation scores. */
    protected float computeModelScoreOnValidation() {
        float score = computeModelScoreOnValidation(0, validationSamples.size() - 1);
        return score / (float) validationSamples.size();
    }

    /** Sum of scorer values for validation lists {@code start..end}. */
    protected float computeModelScoreOnValidation(int start, int end) {
        float score = 0.0f;
        for (int i = start; i <= end; ++i) {
            int[] idx = MergeSorter.sort(modelScoresOnValidation[i], false);
            score += scorer.score(new RankList((RankList) validationSamples.get(i), idx));
        }
        return score;
    }

    /** Fills {@code sortedIdx[i]} for feature indices {@code fStart..fEnd}. */
    protected void sortSamplesByFeature(int fStart, int fEnd) {
        for (int i = fStart; i <= fEnd; ++i) {
            sortedIdx[i] = sortSamplesByFeature(martSamples, features[i]);
        }
    }

    /**
     * Runnable that scores a range of rank lists. Type 4 (4-arg constructor) computes the training
     * score over [rlStart, rlEnd] from flattened offset martStart; type 3 (3-arg constructor) computes
     * the validation score. Not instantiated elsewhere in this class.
     */
    class Worker implements Runnable {
        LambdaMART ranker = null;
        int rlStart = -1;
        int rlEnd = -1;
        int martStart = -1;
        int type = -1;
        float score = 0.0f;

        Worker(LambdaMART ranker, int rlStart, int rlEnd) {
            this.type = 3;
            this.ranker = ranker;
            this.rlStart = rlStart;
            this.rlEnd = rlEnd;
        }

        Worker(LambdaMART ranker, int rlStart, int rlEnd, int martStart) {
            this.type = 4;
            this.ranker = ranker;
            this.rlStart = rlStart;
            this.rlEnd = rlEnd;
            this.martStart = martStart;
        }

        @Override
        public void run() {
            if (type == 4) {
                score = ranker.computeModelScoreOnTraining(rlStart, rlEnd, martStart);
            } else if (type == 3) {
                score = ranker.computeModelScoreOnValidation(rlStart, rlEnd);
            }
        }
    }

    /** Runnable that computes pseudo-responses for training lists [rlStart, rlEnd]. */
    class LambdaComputationWorker implements Runnable {
        LambdaMART ranker = null;
        int rlStart = -1;
        int rlEnd = -1;
        int martStart = -1;

        LambdaComputationWorker(LambdaMART ranker, int rlStart, int rlEnd, int martStart) {
            this.ranker = ranker;
            this.rlStart = rlStart;
            this.rlEnd = rlEnd;
            this.martStart = martStart;
        }

        @Override
        public void run() {
            ranker.computePseudoResponses(rlStart, rlEnd, martStart);
        }
    }

    /** Runnable that sorts samples for feature indices [start, end]. */
    class SortWorker implements Runnable {
        LambdaMART ranker = null;
        int start = -1;
        int end = -1;

        SortWorker(LambdaMART ranker, int start, int end) {
            this.ranker = ranker;
            this.start = start;
            this.end = end;
        }

        @Override
        public void run() {
            ranker.sortSamplesByFeature(start, end);
        }
    }
}

