/*
 * Decompiled with CFR 0.152.
 */
package de.julielab.ml;

import cc.mallet.pipe.Pipe;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AlphabetCarrying;
import cc.mallet.types.FeatureVector;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.Label;
import cc.mallet.types.LabelAlphabet;
import ciir.umass.edu.features.FeatureManager;
import ciir.umass.edu.features.LinearNormalizer;
import ciir.umass.edu.features.Normalizer;
import ciir.umass.edu.features.SumNormalizor;
import ciir.umass.edu.features.ZScoreNormalizor;
import ciir.umass.edu.learning.DataPoint;
import ciir.umass.edu.learning.RANKER_TYPE;
import ciir.umass.edu.learning.RankList;
import ciir.umass.edu.learning.Ranker;
import ciir.umass.edu.learning.RankerFactory;
import ciir.umass.edu.learning.RankerTrainer;
import ciir.umass.edu.learning.SparseDataPoint;
import ciir.umass.edu.metric.METRIC;
import ciir.umass.edu.metric.MetricScorer;
import ciir.umass.edu.metric.MetricScorerFactory;
import de.julielab.java.utilities.FileUtilities;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Trains and applies RankLib learning-to-rank models on MALLET {@link InstanceList}s.
 *
 * <p>Every {@link Instance} is converted into a RankLib {@link SparseDataPoint}:
 * <ul>
 *   <li>the instance <em>name</em> becomes the query ID;</li>
 *   <li>the entry of the target {@link Label} becomes the relevance value;</li>
 *   <li>the instance <em>source</em> is kept in the data point description.</li>
 * </ul>
 * Instances with the same query ID form one {@link RankList}. MALLET feature indices are
 * shifted by one, because RankLib feature IDs are 1-based.
 *
 * <p>Serialization is hand-written ({@code writeObject}/{@code readObject}). It stores:
 * <ul>
 *   <li>the ranker type, the training metric and {@code k};</li>
 *   <li>the normalizer name;</li>
 *   <li>either the alphabets or the instance pipe;</li>
 *   <li>the model string and the feature IDs.</li>
 * </ul>
 * No explicit {@code serialVersionUID} is declared. The JVM therefore derives it from, among
 * other things, the non-transient field declarations and the non-private members of this class.
 * Changing those breaks loading of previously serialized rankers.
 */
public class RankLibRanker
implements AlphabetCarrying,
Serializable {
    private static final Logger log = LoggerFactory.getLogger(RankLibRanker.class);
    /** Instance property that {@link #rank(InstanceList)} writes the ranker score to. */
    private static final String SCORE_PROPERTY = "score";
    /**
     * Score that {@link #rank(InstanceList)} assigns to documents it could not score (no features).
     * This is the lowest finite double, so these documents sort last. {@code Double.MIN_VALUE}
     * would not work: it is the smallest <em>positive</em> double and would rank them above every
     * negatively scored document.
     */
    private static final double UNSCORED = -Double.MAX_VALUE;
    // Not restored by readObject (deserialization bypasses the constructors); access via scorerFactory().
    private final MetricScorerFactory metricScorerFactory;
    // The trained or loaded RankLib model; null until a model is trained or loaded.
    private Ranker ranker;
    private RANKER_TYPE rType;
    // RankLib (1-based) feature IDs used for training; derived from the training data when null.
    private int[] features;
    private METRIC trainMetric;
    // Cut-off passed together with trainMetric to the metric scorer factory.
    private int k;
    // Optional per-rank-list feature normalization applied before training and ranking.
    private Normalizer featureNormalizer;
    private Alphabet dataAlphabet;
    private Alphabet targetAlphabet;
    private Pipe instancePipe;

    /**
     * Creates an untrained ranker.
     *
     * @param rType       the RankLib algorithm to train
     * @param features    the RankLib feature IDs to train on, or {@code null} to derive them from the
     *                    training data
     * @param trainMetric the metric optimized during training
     * @param k           the cut-off used with {@code trainMetric}
     * @param normalizer  {@code "sum"}, {@code "zscore"} or {@code "linear"} (case-insensitive), or
     *                    {@code null} for no feature normalization
     * @throws IllegalArgumentException if {@code normalizer} is not one of the supported names
     */
    public RankLibRanker(RANKER_TYPE rType, int[] features, METRIC trainMetric, int k, String normalizer) {
        this.rType = rType;
        this.features = features;
        this.trainMetric = trainMetric;
        this.k = k;
        this.metricScorerFactory = new MetricScorerFactory();
        // Static RankLib flag, i.e. a JVM-wide side effect. Presumably makes missing features
        // count as 0 — TODO confirm against RankLib.
        DataPoint.missingZero = true;
        this.initFeatureNormalizer(normalizer);
    }

    /** Creates an empty ranker, e.g. to be filled via {@link #load(File)} or {@link #loadFromString(String)}. */
    public RankLibRanker() {
        this.metricScorerFactory = new MetricScorerFactory();
    }

    /**
     * Reads ranking data in SVMlight/RankLib text format into a new {@link InstanceList}.
     *
     * <p>Each data line has the form {@code <relevance> <queryId> <index>:<value> ... [# <comment>]}.
     * <ul>
     *   <li>The relevance becomes a {@link Float} label.</li>
     *   <li>The second token is used verbatim as the instance name (query ID), including any
     *       {@code qid:} prefix.</li>
     *   <li>Feature index {@code i} is mapped to the data alphabet entry {@code "f" + i}.</li>
     *   <li>The last whitespace-separated token of the comment becomes the instance source
     *       (document ID). Without such a token, {@code "doc" + n} is used, where {@code n} counts
     *       the data lines from 0.</li>
     *   <li>Blank and comment-only lines are skipped.</li>
     * </ul>
     *
     * @param dataFile the file to read
     * @return the loaded instances, sharing a newly created data alphabet and label alphabet
     * @throws IllegalArgumentException if a line is malformed (including unparsable numbers)
     */
    public static InstanceList loadSvmLightData(File dataFile) throws Exception {
        Alphabet dataAlphabet = new Alphabet();
        LabelAlphabet targetAlphabet = new LabelAlphabet();
        InstanceList ret = new InstanceList(dataAlphabet, targetAlphabet);
        try (BufferedReader br = FileUtilities.getReaderFromFile(dataFile)) {
            int documentId = 0;
            String line;
            while ((line = br.readLine()) != null) {
                // Everything after the first '#' is a comment, whether or not a space follows it.
                int commentStart = line.indexOf('#');
                String data = (commentStart >= 0 ? line.substring(0, commentStart) : line).trim();
                if (data.isEmpty()) {
                    continue;
                }
                String[] tokens = data.split("\\s+");
                if (tokens.length < 2) {
                    throw new IllegalArgumentException("Malformed line in " + dataFile + " (expected relevance and query ID): " + line);
                }
                Float relevance = Float.valueOf(tokens[0]);
                String queryId = tokens[1];
                // Size the arrays by the actual number of features on this line.
                int numFeatures = tokens.length - 2;
                int[] indices = new int[numFeatures];
                double[] values = new double[numFeatures];
                for (int i = 0; i < numFeatures; ++i) {
                    String[] indexAndValue = tokens[i + 2].split(":");
                    if (indexAndValue.length < 2) {
                        throw new IllegalArgumentException("Malformed feature '" + tokens[i + 2] + "' in " + dataFile + ": " + line);
                    }
                    indices[i] = dataAlphabet.lookupIndex("f" + indexAndValue[0]);
                    values[i] = Double.parseDouble(indexAndValue[1]);
                }
                String documentName = "doc" + documentId;
                if (commentStart >= 0) {
                    String[] commentTokens = line.substring(commentStart + 1).trim().split("\\s+");
                    String lastToken = commentTokens[commentTokens.length - 1];
                    if (!lastToken.isEmpty()) {
                        documentName = lastToken;
                    }
                }
                FeatureVector fv = new FeatureVector(dataAlphabet, indices, values);
                Label label = targetAlphabet.lookupLabel(relevance);
                ret.add(new Instance(fv, label, queryId, documentName));
                ++documentId;
            }
        }
        return ret;
    }

    /**
     * Writes the ranker in a custom layout that {@code readObject} reads back.
     *
     * <p>The layout is: an object count, then {@code rType}, {@code trainMetric}, {@code k},
     * [normalizer name], then either (dataAlphabet, targetAlphabet) or instancePipe, then the
     * model string and the feature IDs. The count must match exactly what is written. Otherwise
     * {@code readObject} stops early and later objects (e.g. the features) are lost.
     */
    private void writeObject(ObjectOutputStream output) throws IOException {
        boolean writeAlphabets = this.instancePipe == null && this.dataAlphabet != null;
        int numObjects = 3 // rType, trainMetric, k
                + (this.featureNormalizer != null ? 1 : 0)
                + (writeAlphabets ? 2 : 1) // both alphabets, or the (possibly null) pipe
                + 2; // model string (possibly null) and features (possibly null)
        output.writeInt(numObjects);
        output.writeObject(this.rType);
        output.writeObject(this.trainMetric);
        output.writeObject(this.k);
        if (this.featureNormalizer != null) {
            output.writeObject(this.featureNormalizer.name());
        }
        if (writeAlphabets) {
            output.writeObject(this.dataAlphabet);
            output.writeObject(this.targetAlphabet);
        } else {
            output.writeObject(this.instancePipe);
        }
        // An untrained ranker has no model; a null entry is ignored when reading back.
        output.writeObject(this.ranker != null ? this.getModelAsString() : null);
        output.writeObject(this.features);
    }

    /** Reads the object count written by {@code writeObject} and assigns each object by its type. */
    private void readObject(ObjectInputStream input) throws IOException, ClassNotFoundException {
        int numWritten = input.readInt();
        for (int i = 0; i < numWritten; ++i) {
            this.assignLoadedObject(input.readObject());
        }
    }

    /**
     * Assigns one deserialized object to the matching field, dispatching on its runtime type.
     * A string is a normalizer name if it is one, and a serialized model otherwise. The first
     * alphabet read is the data alphabet, the second one the target alphabet. {@code null} is
     * ignored.
     */
    private void assignLoadedObject(Object o) {
        if (o instanceof String) {
            String s = (String) o;
            Normalizer normalizer = createNormalizer(s);
            if (normalizer != null) {
                this.featureNormalizer = normalizer;
            } else {
                this.loadFromString(s);
            }
        } else if (o instanceof Alphabet) {
            if (this.dataAlphabet == null) {
                this.dataAlphabet = (Alphabet) o;
            } else {
                this.targetAlphabet = (Alphabet) o;
            }
        } else if (o instanceof Pipe) {
            this.setInstancePipe((Pipe) o);
        } else if (o instanceof RANKER_TYPE) {
            this.rType = (RANKER_TYPE) o;
        } else if (o instanceof METRIC) {
            this.trainMetric = (METRIC) o;
        } else if (o instanceof Integer) {
            this.k = (Integer) o;
        } else if (o instanceof int[]) {
            this.features = (int[]) o;
        }
    }

    public Pipe getInstancePipe() {
        return this.instancePipe;
    }

    /**
     * Sets the instance pipe. Its alphabets are adopted for any alphabet this ranker does not have
     * yet.
     *
     * @param instancePipe the pipe, or {@code null} to clear it
     * @throws IllegalArgumentException if an alphabet of the pipe differs from the existing
     *                                  alphabet of this ranker
     */
    public void setInstancePipe(Pipe instancePipe) {
        if (instancePipe == null) {
            this.instancePipe = null;
            return;
        }
        Alphabet pipeDataAlphabet = instancePipe.getAlphabet();
        Alphabet pipeTargetAlphabet = instancePipe.getTargetAlphabet();
        if (this.dataAlphabet != null && pipeDataAlphabet != null && !pipeDataAlphabet.equals(this.dataAlphabet)) {
            throw new IllegalArgumentException("The already existing data alphabet of the ranker and the data alphabet of the passed instance pipe do not match.");
        }
        if (this.targetAlphabet != null && pipeTargetAlphabet != null && !pipeTargetAlphabet.equals(this.targetAlphabet)) {
            throw new IllegalArgumentException("The already existing target alphabet of the ranker and the target alphabet of the passed instance pipe do not match.");
        }
        if (this.dataAlphabet == null) {
            this.dataAlphabet = pipeDataAlphabet;
        }
        if (this.targetAlphabet == null) {
            this.targetAlphabet = pipeTargetAlphabet;
        }
        this.instancePipe = instancePipe;
    }

    /** Sets the feature normalizer by name; {@code null} leaves it unchanged. Throws IAE for unknown names. */
    private void initFeatureNormalizer(String normalizer) {
        if (normalizer == null) {
            return;
        }
        Normalizer created = createNormalizer(normalizer);
        if (created == null) {
            throw new IllegalArgumentException("Unknown normalizer: " + normalizer);
        }
        this.featureNormalizer = created;
    }

    /** Returns a new RankLib normalizer for a case-insensitive name, or {@code null} if the name is unknown. */
    private static Normalizer createNormalizer(String name) {
        if ("sum".equalsIgnoreCase(name)) {
            return new SumNormalizor();
        }
        if ("zscore".equalsIgnoreCase(name)) {
            return new ZScoreNormalizor();
        }
        if ("linear".equalsIgnoreCase(name)) {
            return new LinearNormalizer();
        }
        return null;
    }

    /** Returns the scorer factory. It is created on demand because the final field is null after deserialization. */
    private MetricScorerFactory scorerFactory() {
        return this.metricScorerFactory != null ? this.metricScorerFactory : new MetricScorerFactory();
    }

    /** Throws {@link IllegalStateException} if no model has been trained or loaded yet. */
    private void requireRanker() {
        if (this.ranker == null) {
            throw new IllegalStateException("No ranking model available; train or load a model first.");
        }
    }

    /**
     * Evaluates the document order of {@code documentList} <em>as given</em> with a RankLib metric.
     * The model of this ranker is not used. Documents without features are dropped by the
     * conversion and do not contribute.
     *
     * @param documentList  documents grouped into queries by instance name, in their ranked order
     * @param scoringMetric the metric to compute
     * @param k             the metric cut-off
     * @return the metric value over all queries, as computed by RankLib
     */
    public double score(InstanceList documentList, METRIC scoringMetric, int k) {
        MetricScorer scorer = this.scorerFactory().createScorer(scoringMetric, k);
        return scorer.score(new ArrayList<RankList>(this.convertToRankList(documentList).values()));
    }

    public Alphabet getDataAlphabet() {
        return this.dataAlphabet;
    }

    public Alphabet getTargetAlphabet() {
        return this.targetAlphabet;
    }

    /**
     * Trains on all {@code documents} without a validation set.
     *
     * <p>The alphabets of {@code documents} replace any alphabets this ranker already has.
     * Features are normalized exactly as in {@link #rank(InstanceList)}.
     *
     * @param documents labeled training documents
     * @throws IllegalArgumentException if no document has a non-empty feature vector
     */
    public void train(InstanceList documents) {
        this.dataAlphabet = documents.getDataAlphabet();
        this.targetAlphabet = documents.getTargetAlphabet();
        this.train(documents, false, 1.0f, 0);
    }

    /**
     * Trains a new model, optionally holding out a random part of the queries for validation.
     *
     * @param documents    labeled training documents
     * @param doValidation whether to split the queries into a training and a validation set
     * @param fraction     share of the queries (rank lists) used for training, in (0, 1); only used
     *                     if {@code doValidation}
     * @param randomSeed   seed for the random query split; only used if {@code doValidation}
     * @throws IllegalArgumentException if no document has features, if {@code fraction} is out of
     *                                  range, or if the pipe's alphabets conflict with this ranker's
     */
    public void train(InstanceList documents, boolean doValidation, float fraction, int randomSeed) {
        this.setInstancePipe(documents.getPipe());
        if (doValidation) {
            log.info("Training on {} documents where a fraction of {} is used for training and the rest for validation. The split is done randomly with a seed of {}.", documents.size(), fraction, randomSeed);
        } else {
            log.info("Training on {} documents without validation set.", documents.size());
        }
        Map<String, RankList> rankLists = this.convertToRankList(documents);
        if (rankLists.isEmpty()) {
            throw new IllegalArgumentException("Cannot train a ranker because none of the " + documents.size() + " input documents has a non-empty feature vector.");
        }
        this.normalize(rankLists);
        // Split before touching any state so that an invalid fraction leaves the ranker unchanged.
        Pair<Map<String, RankList>, Map<String, RankList>> trainValData = doValidation ? this.makeValidationSplit(rankLists, fraction, randomSeed) : null;
        if (this.features == null) {
            this.features = FeatureManager.getFeatureFromSampleVector(new ArrayList<RankList>(rankLists.values()));
        }
        MetricScorer scorer = this.scorerFactory().createScorer(this.trainMetric, this.k);
        RankerTrainer trainer = new RankerTrainer();
        if (trainValData != null) {
            List<RankList> train = new ArrayList<RankList>(trainValData.getLeft().values());
            List<RankList> validation = new ArrayList<RankList>(trainValData.getRight().values());
            this.ranker = trainer.train(this.rType, train, validation, this.features, scorer);
        } else {
            // Use the overload without validation data. An empty validation list would make RankLib
            // average metric scores over zero rank lists.
            this.ranker = trainer.train(this.rType, new ArrayList<RankList>(rankLists.values()), this.features, scorer);
        }
        if (!documents.isEmpty()) {
            log.trace("LtR features: {}", documents.getAlphabet());
        }
    }

    /**
     * Randomly splits the queries into a training part and a validation part.
     *
     * @param allData    rank lists by query ID
     * @param fraction   share of the queries that goes into the training part; must be in (0, 1)
     * @param randomSeed seed for the shuffle
     * @return a pair of (training rank lists, validation rank lists), both keyed by query ID
     * @throws IllegalArgumentException if {@code fraction} is out of range or leaves no training query
     */
    private Pair<Map<String, RankList>, Map<String, RankList>> makeValidationSplit(Map<String, RankList> allData, float fraction, int randomSeed) {
        // Negated form so that NaN is rejected, too.
        if (!(fraction > 0.0f && fraction < 1.0f)) {
            throw new IllegalArgumentException("The fraction of the data to be used for training is specified as " + fraction + " but it must be in (0, 1); the remaining data is used for validation.");
        }
        int trainSize = (int) (fraction * allData.size());
        if (trainSize == 0) {
            throw new IllegalArgumentException("A training fraction of " + fraction + " of " + allData.size() + " queries leaves no query for training.");
        }
        log.info("Splitting into training size of {} and validation size of {} queries", trainSize, allData.size() - trainSize);
        List<RankList> shuffledData = new ArrayList<RankList>(allData.values());
        Collections.shuffle(shuffledData, new Random(randomSeed));
        Map<String, RankList> train = new HashMap<String, RankList>();
        Map<String, RankList> val = new HashMap<String, RankList>();
        for (int i = 0; i < shuffledData.size(); ++i) {
            RankList rl = shuffledData.get(i);
            (i < trainSize ? train : val).put(rl.getID(), rl);
        }
        return new ImmutablePair<Map<String, RankList>, Map<String, RankList>>(train, val);
    }

    /**
     * Converts documents into RankLib rank lists grouped by query ID (instance name). The map keeps
     * the order in which the queries first appear, and each list keeps document order. Documents
     * with an empty feature vector are skipped.
     */
    private Map<String, RankList> convertToRankList(InstanceList documents) {
        Map<String, List<DataPoint>> dataPointsByQueryId = new LinkedHashMap<String, List<DataPoint>>();
        for (Instance document : documents) {
            DataPoint dp = this.toDataPoint(document);
            if (dp != null) {
                dataPointsByQueryId.computeIfAbsent(dp.getID(), id -> new ArrayList<DataPoint>()).add(dp);
            }
        }
        Map<String, RankList> rankLists = new LinkedHashMap<String, RankList>();
        dataPointsByQueryId.forEach((queryId, dataPoints) -> rankLists.put(queryId, new RankList(dataPoints)));
        return rankLists;
    }

    /**
     * Converts one document into a RankLib data point.
     *
     * <p>If the feature vector stores no values (binary features), every present feature gets
     * value 1. The data point description is {@code "#" + source}; {@link #rank(InstanceList)}
     * relies on this to map scores back to documents.
     *
     * @return the data point, or {@code null} if the feature vector has neither values nor indices
     */
    private SparseDataPoint toDataPoint(Instance document) {
        FeatureVector fv = (FeatureVector) document.getData();
        if (fv == null) {
            throw new IllegalArgumentException("Cannot train a ranker because the input documents have no feature vector.");
        }
        double[] values = fv.getValues();
        int[] indices = fv.getIndices();
        boolean hasValues = values != null && values.length > 0;
        boolean hasIndices = indices != null && indices.length > 0;
        if (!hasValues && !hasIndices) {
            return null;
        }
        if (document.getName() == null) {
            throw new IllegalArgumentException("Document " + document.getSource() + " has no name; the instance name is used as query ID.");
        }
        int numLocations = fv.numLocations();
        float[] ranklibValues = new float[numLocations];
        int[] ranklibIndices = new int[numLocations];
        for (int i = 0; i < numLocations; ++i) {
            ranklibValues[i] = values != null ? (float) values[i] : 1.0f;
            // MALLET indices are 0-based, RankLib feature IDs 1-based; dense vectors use the position.
            ranklibIndices[i] = (indices != null ? indices[i] : i) + 1;
        }
        String queryId = document.getName().toString();
        // Highest feature ID: from the configured feature set if there is one, else from this document.
        int numFeatures = this.features != null && this.features.length > 0 ? maxOf(this.features) : maxOf(ranklibIndices);
        SparseDataPoint dp = new SparseDataPoint(ranklibValues, ranklibIndices, numFeatures, queryId, relevanceOf(document));
        dp.setDescription("#" + document.getSource());
        return dp;
    }

    /** Returns the largest element of {@code values}, or 0 if it is empty or has no positive element. */
    private static int maxOf(int[] values) {
        int max = 0;
        for (int value : values) {
            max = Math.max(max, value);
        }
        return max;
    }

    /**
     * Returns the relevance of a document from its target.
     *
     * <p>The target may be a {@link Label} whose entry is a {@link Number}, or a {@link Number}
     * itself. Unlabeled documents (e.g. when ranking new data) get relevance 0.
     */
    private static float relevanceOf(Instance document) {
        Object target = document.getTarget();
        if (target == null) {
            return 0.0f;
        }
        Object entry = target instanceof Label ? ((Label) target).getEntry() : target;
        if (entry instanceof Number) {
            return ((Number) entry).floatValue();
        }
        throw new IllegalArgumentException("The target of document " + document.getSource() + " is not numeric: " + entry);
    }

    /** Applies the configured feature normalizer, if any, to every rank list in place. */
    private void normalize(Map<String, RankList> rankLists) {
        if (this.featureNormalizer != null) {
            for (RankList rankList : rankLists.values()) {
                this.featureNormalizer.normalize(rankList);
            }
        }
    }

    /** Loads a RankLib model from a file previously written by {@link #save(File)} or RankLib. */
    public void load(File modelFile) throws IOException {
        try (BufferedReader br = FileUtilities.getReaderFromFile(modelFile)) {
            String model = br.lines().collect(Collectors.joining(System.lineSeparator()));
            this.loadFromString(model);
        }
    }

    /**
     * Saves the model via RankLib, creating missing parent directories first.
     *
     * @throws IllegalStateException if no model has been trained or loaded
     */
    public void save(File modelFile) {
        this.requireRanker();
        // getAbsoluteFile(): a bare file name has no parent otherwise.
        File parent = modelFile.getAbsoluteFile().getParentFile();
        if (parent != null && !parent.exists()) {
            parent.mkdirs();
        }
        this.ranker.save(modelFile.getAbsolutePath());
    }

    /**
     * Returns the model in RankLib's textual format.
     *
     * @throws IllegalStateException if no model has been trained or loaded
     */
    public String getModelAsString() {
        this.requireRanker();
        return this.ranker.model();
    }

    /** Replaces the model with one parsed from RankLib's textual model format. */
    public void loadFromString(String modelString) {
        this.ranker = new RankerFactory().loadRankerFromString(modelString);
    }

    /**
     * Scores and sorts documents with the current model.
     *
     * <p>Each passed instance gets its score in the instance property {@code "score"}, so the input
     * instances are mutated. Documents without features get {@code -Double.MAX_VALUE}. The result
     * is a new list containing the same instances, sorted by descending score. The sort is stable,
     * so ties keep their input order.
     *
     * @param documents documents to rank; the pairs (name, source) must be unique
     * @return the documents sorted by descending score
     * @throws IllegalArgumentException if two documents share the same name and source
     * @throws IllegalStateException    if no model has been trained or loaded
     */
    public InstanceList rank(InstanceList documents) {
        this.requireRanker();
        // Key "name#source" equals queryId + description of the data points built in toDataPoint.
        Map<String, Instance> docsById = new HashMap<String, Instance>();
        for (Instance document : documents) {
            String id = document.getName() + "#" + document.getSource();
            if (docsById.putIfAbsent(id, document) != null) {
                throw new IllegalArgumentException("The passed documents do not have unique IDs. The input document list has size " + documents.size() + " but the ID " + id + " occurs more than once.");
            }
        }
        Map<String, RankList> rankLists = this.convertToRankList(documents);
        this.normalize(rankLists);
        // Tracks documents not scored in this call, so that stale scores from earlier calls get reset.
        Map<String, Instance> unscored = new HashMap<String, Instance>(docsById);
        for (RankList rl : rankLists.values()) {
            for (int i = 0; i < rl.size(); ++i) {
                DataPoint dp = rl.get(i);
                String id = dp.getID() + dp.getDescription();
                docsById.get(id).setProperty(SCORE_PROPERTY, this.ranker.eval(dp));
                unscored.remove(id);
            }
        }
        unscored.values().forEach(d -> d.setProperty(SCORE_PROPERTY, UNSCORED));
        InstanceList ret = new InstanceList(documents.getDataAlphabet(), documents.getTargetAlphabet());
        ret.addAll(documents);
        Collections.sort(ret, Comparator.comparingDouble((Instance d) -> (Double) d.getProperty(SCORE_PROPERTY)).reversed());
        return ret;
    }

    public Ranker getRankLibRanker() {
        return this.ranker;
    }

    @Override
    public Alphabet getAlphabet() {
        return this.getDataAlphabet();
    }

    @Override
    public Alphabet[] getAlphabets() {
        return new Alphabet[]{this.getDataAlphabet(), this.getTargetAlphabet()};
    }
}

