/*
 * Decompiled with CFR 0.152.
 */
package de.datexis.sector.eval;

import de.datexis.annotator.AnnotatorEvaluation;
import de.datexis.model.Annotation;
import de.datexis.model.Document;
import de.datexis.sector.model.SectionAnnotation;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
import org.slf4j.LoggerFactory;

/**
 * Evaluates a text segmentation by comparing the section structure of an expected annotation
 * source (by default {@code GOLD}) with a predicted one (by default {@code PRED}) using the Pk
 * and WindowDiff (WD) error rates.
 *
 * <p>Each document is reduced to an array holding one section index per sentence (see
 * {@link #getPositionsArray}). Per-document scores are summed and averaged over the number of
 * evaluated documents. Both measures are fractions of disagreeing windows, so 0.0 means no
 * window disagrees and lower is better.
 */
public class SegmentationEvaluation extends AnnotatorEvaluation {
    /** Score constant 0.0. NOTE(review): not referenced inside this class; possibly used by subclasses — confirm. */
    protected final double DEFAULT_SCORE = 0.0;
    /** Number of expected segments seen so far, summed over all evaluated documents. */
    protected double countExp = 0.0;
    /** Number of predicted segments seen so far, summed over all evaluated documents. */
    protected double countPred = 0.0;
    /** Sum of per-document Pk values; divided by {@code countDocs} in {@link #getPk()}. */
    protected double pksum = 0.0;
    /** Sum of per-document WD values; divided by {@code countDocs} in {@link #getWD()}. */
    protected double wdsum = 0.0;
    /** If true, k is recomputed from each document's own segments instead of once per batch. */
    protected boolean enableKPerDocument = false;
    /** If true, consecutive annotations with the same label/heading count as one section. */
    protected boolean enableMergeSections = true;

    /**
     * Creates an evaluation that compares {@code GOLD} (expected) with {@code PRED} (predicted).
     *
     * @param experimentName name handed to {@link AnnotatorEvaluation}
     */
    public SegmentationEvaluation(String experimentName) {
        this(experimentName, Annotation.Source.GOLD, Annotation.Source.PRED);
    }

    /**
     * Creates an evaluation that compares two arbitrary annotation sources.
     *
     * @param experimentName name handed to {@link AnnotatorEvaluation}
     * @param expected       source used as the reference segmentation
     * @param predicted      source used as the hypothesis segmentation
     */
    public SegmentationEvaluation(String experimentName, Annotation.Source expected, Annotation.Source predicted) {
        super(experimentName, expected, predicted);
        this.log = LoggerFactory.getLogger(SegmentationEvaluation.class);
        this.clear();
    }

    /**
     * Chooses whether k is recomputed per document rather than once per call to
     * {@link #calculateScoresFromAnnotations}.
     *
     * @return this instance, for chaining
     */
    public SegmentationEvaluation withRecalculateK(boolean enabled) {
        this.enableKPerDocument = enabled;
        return this;
    }

    /**
     * Chooses whether consecutive annotations with the same label/heading are merged into one section.
     *
     * @return this instance, for chaining
     */
    public SegmentationEvaluation withMergeEnabled(boolean enabled) {
        this.enableMergeSections = enabled;
        return this;
    }

    /** Resets all counters and running sums. */
    protected void clear() {
        this.countDocs = 0;
        this.countExamples = 0;
        this.pksum = 0.0;
        this.wdsum = 0.0;
        this.countExp = 0.0;
        this.countPred = 0.0;
    }

    /** Returns the mean WindowDiff over all evaluated documents (see {@link #getWD()}). */
    public double getScore() {
        return this.getWD();
    }

    /** Evaluates {@code docs} using their {@link SectionAnnotation}s. */
    public void calculateScores(Collection<Document> docs) {
        this.calculateScoresFromAnnotations(docs, SectionAnnotation.class);
    }

    /**
     * Adds the WD and Pk scores and the segment counts of {@code docs} to the running totals.
     *
     * <p>k is computed once from the whole batch, or per document when
     * {@link #withRecalculateK(boolean)} is enabled.
     *
     * <p>NOTE(review): {@code annotationClass} is currently ignored. Positions are always derived
     * from {@link SectionAnnotation}s (see {@link #getPositionsArray}).
     */
    public void calculateScoresFromAnnotations(Collection<Document> docs, Class<? extends Annotation> annotationClass) {
        this.countDocs += docs.size();
        int batchK = this.calculateK(docs);
        for (Document doc : docs) {
            int k = this.enableKPerDocument ? this.calculateK(doc) : batchK;
            this.wdsum += this.calculateWD(doc, k);
            this.pksum += this.calculatePk(doc, k);
            this.countExp += this.getMassesArray(doc, this.expectedSource).length;
            this.countPred += this.getMassesArray(doc, this.predictedSource).length;
        }
    }

    /** Returns the mean WD per document; NaN if no documents have been evaluated. */
    public double getWD() {
        return this.wdsum / (double) this.countDocs;
    }

    /** Returns the mean Pk per document; NaN if no documents have been evaluated. */
    public double getPk() {
        return this.pksum / (double) this.countDocs;
    }

    /** Returns the total number of expected segments seen so far. */
    public double getCountExpected() {
        return this.countExp;
    }

    /** Returns the total number of predicted segments seen so far. */
    public double getCountPredicted() {
        return this.countPred;
    }

    /**
     * Computes Pk for one document.
     *
     * <p>For every sentence position t with t + k inside the document, the measure checks whether
     * t and t + k lie in the same section. It then reports the fraction of positions where the
     * reference and the hypothesis disagree.
     *
     * <p>A one-sentence document scores 0.0. A two-sentence document is scored on its single
     * sentence pair. A document with no window scores 0.0.
     *
     * @param doc document carrying both annotation sources
     * @param k   window distance in sentences
     * @return error rate in [0, 1]
     */
    public double calculatePk(Document doc, int k) {
        int[] reference = this.getPositionsArray(doc, this.expectedSource);
        int[] hypothesis = this.getPositionsArray(doc, this.predictedSource);
        double sum = 0.0;
        double count = 0.0;
        for (int t = 0; t < reference.length - k; ++t) {
            boolean agreeRef = reference[t] == reference[t + k];
            boolean agreeHyp = hypothesis[t] == hypothesis[t + k];
            if (agreeRef != agreeHyp) {
                sum += 1.0;
            }
            count += 1.0;
        }
        if (reference.length == 2) {
            // For k < 2 the loop may already have run. It then compared only this same pair (k == 1)
            // or identical positions (k <= 0), so the pair score is still the right answer.
            return scoreSentencePair(reference, hypothesis);
        }
        if (reference.length == 1) {
            return 0.0;
        }
        return count > 0.0 ? sum / count : 0.0;
    }

    /**
     * Computes WindowDiff for one document.
     *
     * <p>For every window of k adjacent sentence pairs starting at t, the measure counts the
     * within-section pairs in the reference and in the hypothesis. It then reports the fraction of
     * windows where these counts differ.
     *
     * <p>If a scanned window contains a reference sentence with section index 0 (not covered by
     * any annotation), a warning is logged and 1.0 is returned. A one-sentence document scores
     * 0.0. A two-sentence document is scored on its single pair. A document with no window
     * scores 0.0.
     *
     * @param doc document carrying both annotation sources
     * @param k   window size in sentence pairs
     * @return error rate in [0, 1]
     */
    public double calculateWD(Document doc, int k) {
        int[] reference = this.getPositionsArray(doc, this.expectedSource);
        int[] hypothesis = this.getPositionsArray(doc, this.predictedSource);
        double sum = 0.0;
        double count = 0.0;
        for (int t = 0; t < reference.length - k; ++t) {
            int sumRef = 0;
            int sumHyp = 0;
            for (int j = 0; j < k; ++j) {
                if (reference[t + j] == 0) {
                    this.log.warn("document is not correctly annotated");
                    return 1.0;
                }
                if (reference[t + j] == reference[t + j + 1]) {
                    ++sumRef;
                }
                if (hypothesis[t + j] == hypothesis[t + j + 1]) {
                    ++sumHyp;
                }
            }
            // Both windows hold k pairs, so equal agreement counts imply equal boundary counts.
            if (sumRef != sumHyp) {
                sum += 1.0;
            }
            count += 1.0;
        }
        if (reference.length == 2) {
            return scoreSentencePair(reference, hypothesis);
        }
        if (reference.length == 1) {
            return 0.0;
        }
        return count > 0.0 ? sum / count : 0.0;
    }

    /**
     * Scores a two-sentence document.
     *
     * @return 0.0 if reference and hypothesis agree on whether both sentences share a section, else 1.0
     */
    private static double scoreSentencePair(int[] reference, int[] hypothesis) {
        boolean agreeRef = reference[0] == reference[1];
        boolean agreeHyp = hypothesis[0] == hypothesis[1];
        return agreeRef == agreeHyp ? 0.0 : 1.0;
    }

    /**
     * Computes k as half the mean expected segment length over {@code docs}, rounded, and at least 2.
     * An empty batch yields a NaN mean, which rounds to 0, so k falls back to 2.
     */
    public int calculateK(Collection<Document> docs) {
        int k = Math.max((int) Math.round(this.getMeanSegmentLength(docs) / 2.0), 2);
        this.log.trace("setting k to {}", k);
        return k;
    }

    /**
     * Computes k as half the mean expected segment length of {@code doc}, rounded, and at least 2.
     * A document without segments yields 0/0 = NaN, which rounds to 0, so k falls back to 2.
     */
    public int calculateK(Document doc) {
        int[] masses = this.getMassesArray(doc, this.expectedSource);
        return Math.max((int) Math.round(sumOf(masses) / (double) masses.length / 2.0), 2);
    }

    /**
     * Returns the mean length, in sentences, of the expected segments across {@code docs}.
     * The result is NaN if there are no segments at all.
     */
    public double getMeanSegmentLength(Collection<Document> docs) {
        double sum = 0.0;
        double count = 0.0;
        for (Document doc : docs) {
            int[] masses = this.getMassesArray(doc, this.expectedSource);
            sum += sumOf(masses);
            count += masses.length;
        }
        return sum / count;
    }

    /** Sums the values of {@code values} as a double. */
    private static double sumOf(int[] values) {
        double total = 0.0;
        for (int value : values) {
            total += value;
        }
        return total;
    }

    /**
     * Returns the segment lengths ("masses") of {@code doc} for {@code source}: the lengths, in
     * sentences, of the maximal runs of equal section index in {@link #getPositionsArray}, in
     * document order.
     */
    public int[] getMassesArray(Document doc, Annotation.Source source) {
        List<Integer> result = new ArrayList<>();
        int[] positions = this.getPositionsArray(doc, source);
        int last = 0;
        int count = 0;
        for (int curr : positions) {
            if (curr != last) {
                if (count > 0) {
                    result.add(count);
                }
                last = curr;
                count = 0;
            }
            ++count;
        }
        if (count > 0) {
            result.add(count);
        }
        return result.stream().mapToInt(Integer::intValue).toArray();
    }

    /**
     * Maps every sentence of {@code doc} to a section index derived from the
     * {@link SectionAnnotation}s of {@code source}.
     *
     * <p>Annotations are visited in sorted order. Sentences before the first annotation keep index 0.
     * Each annotation whose key differs from the previous annotation's key starts the next index
     * (1, 2, ...). The key is the label/heading when merging is enabled; otherwise it is the
     * annotation's begin offset, so each annotation opens its own section unless two share a
     * begin offset.
     *
     * @return array of length {@code doc.countSentences()} with non-decreasing section indices
     */
    public int[] getPositionsArray(Document doc, Annotation.Source source) {
        int[] array = new int[doc.countSentences()];
        int sectionIndex = 0;
        int cursor = 0;
        // null sentinel: the first annotation always opens section 1, even if its key is "".
        String lastSection = null;
        // The explicit cast fixes the element type regardless of streamAnnotations' declared generic type.
        List<SectionAnnotation> anns = doc.streamAnnotations(source, SectionAnnotation.class)
                .sorted()
                .map(SectionAnnotation.class::cast)
                .collect(Collectors.toList());
        for (SectionAnnotation ann : anns) {
            int begin = doc.getSentenceIndexAtPosition(ann.getBegin());
            if (begin < cursor) {
                this.log.warn("document '{}' is not properly annotated at sentence {}", doc.getId(), cursor);
            }
            if (begin < 0) {
                begin = 0;
            }
            // Sentences up to this annotation's first sentence belong to the previous section.
            for (; cursor < begin; ++cursor) {
                array[cursor] = sectionIndex;
            }
            String currentSection = this.enableMergeSections ? ann.getSectionLabelOrHeading() : Integer.toString(ann.getBegin());
            if (!currentSection.equals(lastSection)) {
                ++sectionIndex;
            }
            lastSection = currentSection;
        }
        for (int t = cursor; t < array.length; ++t) {
            array[t] = sectionIndex;
        }
        return array;
    }
}

