/*
 * Decompiled with CFR 0.152.
 */
package com.github.chen0040.glm.solvers;

import com.github.chen0040.data.frame.BasicDataFrame;
import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import com.github.chen0040.data.utils.TupleTwo;
import com.github.chen0040.glm.solvers.Glm;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import java.util.stream.Collectors;

public class OneVsOneGlmClassifier {
    protected List<TupleTwo<Glm, Glm>> classifiers;
    private double alpha = 0.1;
    private boolean shuffleData = false;
    private List<String> classLabels = new ArrayList<String>();
    private Supplier<Glm> generator = () -> Glm.linear();
    private static String BINARY_LABEL = "success";

    public OneVsOneGlmClassifier(List<String> classLabels) {
        this.classLabels.addAll(classLabels);
        this.classifiers = new ArrayList<TupleTwo<Glm, Glm>>();
    }

    public OneVsOneGlmClassifier() {
        this.classifiers = new ArrayList<TupleTwo<Glm, Glm>>();
    }

    public OneVsOneGlmClassifier(Supplier<Glm> binaryClassifierGenerator) {
        this.classifiers = new ArrayList<TupleTwo<Glm, Glm>>();
        this.generator = binaryClassifierGenerator;
    }

    public boolean isShuffleData() {
        return this.shuffleData;
    }

    public void setShuffleData(boolean shuffleData) {
        this.shuffleData = shuffleData;
    }

    public double getAlpha() {
        return this.alpha;
    }

    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    protected void createClassifiers(DataFrame dataFrame) {
        this.classifiers = new ArrayList<TupleTwo<Glm, Glm>>();
        if (this.classLabels.size() == 0) {
            this.classLabels.addAll(dataFrame.stream().map(DataRow::categoricalTarget).distinct().collect(Collectors.toList()));
        }
        for (int i = 0; i < this.classLabels.size() - 1; ++i) {
            for (int j = i + 1; j < this.classLabels.size(); ++j) {
                Glm svr1 = this.createClassifier(this.classLabels.get(i));
                Glm svr2 = this.createClassifier(this.classLabels.get(j));
                this.classifiers.add((TupleTwo<Glm, Glm>)new TupleTwo((Object)svr1, (Object)svr2));
            }
        }
    }

    protected Glm createClassifier(String classLabel) {
        Glm svr = this.generator.get();
        svr.setName(classLabel);
        return svr;
    }

    protected double getClassifierScore(DataRow tuple, Glm classifier) {
        return classifier.transform(tuple);
    }

    protected List<DataFrame> split(DataFrame dataFrame, int n) {
        ArrayList<DataFrame> miniFrames = new ArrayList<DataFrame>();
        for (int i = 0; i < n; ++i) {
            miniFrames.add((DataFrame)new BasicDataFrame());
        }
        int index = 0;
        for (DataRow tuple : dataFrame) {
            int batchIndex = index % n;
            ((DataFrame)miniFrames.get(batchIndex)).addRow(tuple);
            ++index;
        }
        return miniFrames;
    }

    protected List<DataFrame> remerge(List<DataFrame> batches, int k) {
        ArrayList<DataFrame> newBatches = new ArrayList<DataFrame>();
        for (int i = 0; i < batches.size(); ++i) {
            BasicDataFrame newBatch = new BasicDataFrame();
            for (int j = 0; j < k; ++j) {
                int d = (i + j) % batches.size();
                DataFrame batch = batches.get(d);
                for (DataRow tuple : batch) {
                    newBatch.addRow(tuple.makeCopy());
                }
            }
            newBatches.add((DataFrame)newBatch);
        }
        return newBatches;
    }

    public double transform(DataRow row) {
        String label = this.classify(row);
        return this.classLabels.indexOf(label);
    }

    public void fit(DataFrame dataFrame) {
        this.createClassifiers(dataFrame);
        if (this.shuffleData) {
            dataFrame.shuffle();
        }
        List<DataFrame> batches = this.split(dataFrame, this.classifiers.size());
        int k = Math.max(1, (int)this.alpha * batches.size());
        batches = this.remerge(batches, k);
        for (int i = 0; i < this.classifiers.size(); ++i) {
            TupleTwo<Glm, Glm> pair = this.classifiers.get(i);
            Glm classifier1 = (Glm)pair._1();
            Glm classifier2 = (Glm)pair._2();
            classifier1.fit(this.createBinaryBatch(batches.get(i), classifier1.getName()));
            classifier2.fit(this.createBinaryBatch(batches.get(i), classifier2.getName()));
        }
    }

    private DataFrame createBinaryBatch(DataFrame dataFrame, String classLabel) {
        BasicDataFrame binaryBatch = new BasicDataFrame();
        for (DataRow row : dataFrame) {
            String label = row.categoricalTarget();
            DataRow rowWithBinaryTargetOutput = row.makeCopy();
            rowWithBinaryTargetOutput.setTargetCell(BINARY_LABEL, label.equals(classLabel) ? 1.0 : 0.0);
            binaryBatch.addRow(rowWithBinaryTargetOutput);
        }
        return binaryBatch;
    }

    public String classify(DataRow row) {
        if ((row = row.makeCopy()).getTargetColumnNames().isEmpty()) {
            row.setTargetColumnNames(Collections.singletonList(BINARY_LABEL));
        }
        Map<String, Integer> scores = this.score(row);
        String predicatedClassLabel = null;
        int maxScore = 0;
        for (Map.Entry<String, Integer> entry : scores.entrySet()) {
            String label = entry.getKey();
            int score = entry.getValue();
            if (score <= maxScore) continue;
            maxScore = score;
            predicatedClassLabel = label;
        }
        if (predicatedClassLabel == null) {
            predicatedClassLabel = "NA";
        }
        return predicatedClassLabel;
    }

    public void reset() {
        this.classifiers.clear();
        this.classLabels.clear();
    }

    public List<String> getClassLabels() {
        return this.classLabels;
    }

    public Map<String, Integer> score(DataRow row) {
        HashMap<String, Integer> scores = new HashMap<String, Integer>();
        for (int i = 0; i < this.classifiers.size(); ++i) {
            double score2;
            TupleTwo<Glm, Glm> pair = this.classifiers.get(i);
            Glm classifier1 = (Glm)pair._1();
            Glm classifier2 = (Glm)pair._2();
            double score1 = this.getClassifierScore(row, classifier1);
            if (score1 == (score2 = this.getClassifierScore(row, classifier2))) continue;
            String winningLabel = score1 > score2 ? classifier1.getName() : classifier2.getName();
            if (scores.containsKey(winningLabel)) {
                scores.put(winningLabel, (Integer)scores.get(winningLabel) + 1);
                continue;
            }
            scores.put(winningLabel, 1);
        }
        return scores;
    }
}

