/*
 * Decompiled with CFR 0.152.
 */
package ciir.umass.edu.learning.neuralnet;

import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.learning.neuralnet.Layer;
import ciir.umass.edu.learning.neuralnet.Neuron;
import ciir.umass.edu.learning.neuralnet.PropParameter;
import ciir.umass.edu.learning.neuralnet.RankNet;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.utilities.RankLibError;
import ciir.umass.edu.utilities.SimpleMath;
import java.io.BufferedReader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;

public class ListNet
extends RankNet {
    public static int nIteration = 1500;
    public static double learningRate = 1.0E-5;
    public static int nHiddenLayer = 0;

    public ListNet() {
    }

    public ListNet(List<RankList> samples, int[] features, MetricScorer scorer) {
        super(samples, features, scorer);
    }

    protected float[] feedForward(RankList rl) {
        float[] labels = new float[rl.size()];
        for (int i = 0; i < rl.size(); ++i) {
            this.addInput(rl.get(i));
            this.propagate(i);
            labels[i] = rl.get(i).getLabel();
        }
        return labels;
    }

    protected void backPropagate(float[] labels) {
        PropParameter p = new PropParameter(labels);
        this.outputLayer.computeDelta(p);
        this.outputLayer.updateWeight(p);
    }

    @Override
    protected void estimateLoss() {
        this.error = 0.0;
        double sumLabelExp = 0.0;
        double sumScoreExp = 0.0;
        for (int i = 0; i < this.samples.size(); ++i) {
            int j;
            RankList rl = (RankList)this.samples.get(i);
            double[] scores = new double[rl.size()];
            double err = 0.0;
            for (j = 0; j < rl.size(); ++j) {
                scores[j] = this.eval(rl.get(j));
                sumLabelExp += Math.exp(rl.get(j).getLabel());
                sumScoreExp += Math.exp(scores[j]);
            }
            for (j = 0; j < rl.size(); ++j) {
                double p1 = Math.exp(rl.get(j).getLabel()) / sumLabelExp;
                double p2 = Math.exp(scores[j]) / sumScoreExp;
                err += -p1 * SimpleMath.logBase2(p2);
            }
            this.error += err / (double)rl.size();
        }
        this.lastError = this.error;
    }

    @Override
    public void init() {
        this.PRINT("Initializing... ");
        this.setInputOutput(this.features.length, 1, 1);
        this.wire();
        if (this.validationSamples != null) {
            for (int i = 0; i < this.layers.size(); ++i) {
                this.bestModelOnValidation.add(new ArrayList());
            }
        }
        Neuron.learningRate = learningRate;
        this.PRINTLN("[Done]");
    }

    @Override
    public void learn() {
        this.PRINTLN("-----------------------------------------");
        this.PRINTLN("Training starts...");
        this.PRINTLN("--------------------------------------------------");
        this.PRINTLN(new int[]{7, 14, 9, 9}, new String[]{"#epoch", "C.E. Loss", this.scorer.name() + "-T", this.scorer.name() + "-V"});
        this.PRINTLN("--------------------------------------------------");
        for (int i = 1; i <= nIteration; ++i) {
            for (int j = 0; j < this.samples.size(); ++j) {
                float[] labels = this.feedForward((RankList)this.samples.get(j));
                this.backPropagate(labels);
                this.clearNeuronOutputs();
            }
            this.PRINT(new int[]{7, 14}, new String[]{i + "", SimpleMath.round(this.error, 6) + ""});
            if (i % 1 == 0) {
                this.scoreOnTrainingData = this.scorer.score(this.rank(this.samples));
                this.PRINT(new int[]{9}, new String[]{SimpleMath.round(this.scoreOnTrainingData, 4) + ""});
                if (this.validationSamples != null) {
                    double score = this.scorer.score(this.rank(this.validationSamples));
                    if (score > this.bestScoreOnValidationData) {
                        this.bestScoreOnValidationData = score;
                        this.saveBestModelOnValidation();
                    }
                    this.PRINT(new int[]{9}, new String[]{SimpleMath.round(score, 4) + ""});
                }
            }
            this.PRINTLN("");
        }
        if (this.validationSamples != null) {
            this.restoreBestModelOnValidation();
        }
        this.scoreOnTrainingData = SimpleMath.round(this.scorer.score(this.rank(this.samples)), 4);
        this.PRINTLN("--------------------------------------------------");
        this.PRINTLN("Finished sucessfully.");
        this.PRINTLN(this.scorer.name() + " on training data: " + this.scoreOnTrainingData);
        if (this.validationSamples != null) {
            this.bestScoreOnValidationData = this.scorer.score(this.rank(this.validationSamples));
            this.PRINTLN(this.scorer.name() + " on validation data: " + SimpleMath.round(this.bestScoreOnValidationData, 4));
        }
        this.PRINTLN("---------------------------------");
    }

    @Override
    public double eval(DataPoint p) {
        return super.eval(p);
    }

    @Override
    public Ranker createNew() {
        return new ListNet();
    }

    @Override
    public String toString() {
        return super.toString();
    }

    @Override
    public String model() {
        String output = "## " + this.name() + "\n";
        output = output + "## Epochs = " + nIteration + "\n";
        output = output + "## No. of features = " + this.features.length + "\n";
        for (int i = 0; i < this.features.length; ++i) {
            output = output + this.features[i] + (i == this.features.length - 1 ? "" : " ");
        }
        output = output + "\n";
        output = output + "0\n";
        output = output + this.toString();
        return output;
    }

    @Override
    public void loadFromString(String fullText) {
        try {
            int i;
            String content = "";
            BufferedReader in = new BufferedReader(new StringReader(fullText));
            ArrayList<String> l = new ArrayList<String>();
            while ((content = in.readLine()) != null) {
                if ((content = content.trim()).length() == 0 || content.indexOf("##") == 0) continue;
                l.add(content);
            }
            in.close();
            String[] tmp = ((String)l.get(0)).split(" ");
            this.features = new int[tmp.length];
            for (int i2 = 0; i2 < tmp.length; ++i2) {
                this.features[i2] = Integer.parseInt(tmp[i2]);
            }
            int nHiddenLayer = Integer.parseInt((String)l.get(1));
            int[] nn = new int[nHiddenLayer];
            for (i = 2; i < 2 + nHiddenLayer; ++i) {
                nn[i - 2] = Integer.parseInt((String)l.get(i));
            }
            this.setInputOutput(this.features.length, 1);
            for (int j = 0; j < nHiddenLayer; ++j) {
                this.addHiddenLayer(nn[j]);
            }
            this.wire();
            while (i < l.size()) {
                String[] s2 = ((String)l.get(i)).split(" ");
                int iLayer = Integer.parseInt(s2[0]);
                int iNeuron = Integer.parseInt(s2[1]);
                Neuron n = ((Layer)this.layers.get(iLayer)).get(iNeuron);
                for (int k = 0; k < n.getOutLinks().size(); ++k) {
                    n.getOutLinks().get(k).setWeight(Double.parseDouble(s2[k + 2]));
                }
                ++i;
            }
        }
        catch (Exception ex) {
            throw RankLibError.create("Error in ListNet::load(): ", ex);
        }
    }

    @Override
    public void printParameters() {
        this.PRINTLN("No. of epochs: " + nIteration);
        this.PRINTLN("Learning rate: " + learningRate);
    }

    @Override
    public String name() {
        return "ListNet";
    }
}

