package hex;

import hex.ModelMetricsRegression;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import water.DKV;
import water.Key;
import water.MRTask;
import water.Scope;
import water.exceptions.H2OIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.Merge;

/* loaded from: input_file:hex/ModelMetricsRegressionCoxPH.class */
/**
 * Regression metrics for Cox proportional-hazards models, extended with concordance statistics:
 * the concordance index plus the raw numbers of concordant, discordant and tied comparable pairs.
 */
public class ModelMetricsRegressionCoxPH extends ModelMetricsRegression {
    /** Concordance index; NaN when no pair of observations is comparable. */
    private double _concordance;
    /** Number of comparable pairs ordered consistently by the model estimate. */
    private long _concordant;
    /** Number of comparable pairs ordered inconsistently by the model estimate. */
    private long _discordant;
    /** Number of comparable pairs whose model estimates are tied. */
    private long _tied_y;

    /**
     * Metric builder that computes the plain regression metrics via the parent builder and adds the
     * concordance statistics, computed per stratum when the model is stratified.
     */
    public static class MetricBuilderRegressionCoxPH<T extends MetricBuilderRegressionCoxPH<T>> extends ModelMetricsRegression.MetricBuilderRegression<T> {
        private final String startVecName;
        private final String stopVecName;
        private final boolean isStratified;
        private final String[] stratifyBy;

        /**
         * Pair counts for one group of rows processed together by {@code handlePairs};
         * {@code next_ix} is the array index right after the group.
         */
        public static class PairStats {
            final long pairs;
            final long concordant;
            final long tied;
            final int next_ix;

            public PairStats(long pairs, long concordant, long tied, int next_ix) {
                this.pairs = pairs;
                this.concordant = concordant;
                this.tied = tied;
                this.next_ix = next_ix;
            }

            @Override
            public String toString() {
                return "PairStats{pairs=" + this.pairs + ", concordant=" + this.concordant + ", tied=" + this.tied + ", next_ix=" + this.next_ix + '}';
            }
        }

        /** Accumulated totals of comparable, concordant and tied pairs. */
        public static class Stats {
            final long ntotals;
            final long nconcordant;
            final long nties;

            Stats() {
                this(0L, 0L, 0L);
            }

            Stats(long ntotals, long nconcordant, long nties) {
                this.ntotals = ntotals;
                this.nconcordant = nconcordant;
                this.nties = nties;
            }

            /** Concordance index: (concordant + ties / 2) / comparable; NaN (0/0) when nothing is comparable. */
            double c() {
                return (this.nconcordant + (0.5d * this.nties)) / this.ntotals;
            }

            /** Comparable pairs that are neither concordant nor tied. */
            long discordant() {
                return (this.ntotals - this.nconcordant) - this.nties;
            }

            @Override
            public String toString() {
                return "Stats{ntotals=" + this.ntotals + ", nconcordant=" + this.nconcordant + ", ndiscordant=" + discordant() + ", nties=" + this.nties + '}';
            }

            /** Element-wise sum of two statistics; neither operand is modified. */
            Stats plus(Stats other) {
                return new Stats(this.ntotals + other.ntotals, this.nconcordant + other.nconcordant, this.nties + other.nties);
            }
        }

        /**
         * @param startVecName name of the start-time column, or a name that resolves to no column when
         *                     only stop times are used (a missing start vec means duration == stop)
         * @param stopVecName  name of the stop-time column
         * @param isStratified whether concordance is computed separately per stratum
         * @param stratifyBy   names of the stratification columns (only read when stratified)
         */
        public MetricBuilderRegressionCoxPH(String startVecName, String stopVecName, boolean isStratified, String[] stratifyBy) {
            this.startVecName = startVecName;
            this.stopVecName = stopVecName;
            this.isStratified = isStratified;
            this.stratifyBy = stratifyBy;
        }

        @Override // hex.ModelMetricsRegression.MetricBuilderRegression, hex.ModelMetrics.MetricBuilder
        public ModelMetricsRegressionCoxPH makeModelMetrics(Model model, Frame frame, Frame adaptedFrame, Frame preds) {
            final ModelMetricsRegression regression = super.computeModelMetrics(model, frame, adaptedFrame, preds);
            final Stats concordance = concordance(model, frame, adaptedFrame, preds);
            final ModelMetricsRegressionCoxPH metrics = new ModelMetricsRegressionCoxPH(model, frame, this._count,
                    regression.mse(), weightedSigma(), regression.mae(), regression.rmsle(),
                    regression.mean_residual_deviance(), this._customMetric,
                    concordance.c(), concordance.nconcordant, concordance.discordant(), concordance.nties);
            if (model != null) {
                model.addModelMetrics(metrics);
            }
            return metrics;
        }

        /**
         * Collects the vectors needed for concordance: start/stop/event from {@code adaptedFrame}
         * (event is its last column), strata from {@code frame}, and the estimate as the last column of {@code preds}.
         */
        private Stats concordance(Model model, Frame frame, Frame adaptedFrame, Frame preds) {
            final List<Vec> strataVecs = this.isStratified
                    ? Arrays.stream(this.stratifyBy).map(name -> frame.vec(name)).collect(Collectors.toList())
                    : Collections.<Vec>emptyList();
            return concordance(adaptedFrame.vec(this.startVecName), adaptedFrame.vec(this.stopVecName),
                    adaptedFrame.lastVec(), strataVecs, preds.lastVec());
        }

        /**
         * Computes concordance statistics inside an H2O {@link Scope}, so every temporary frame/vec
         * tracked while computing them is released on exit (also on failure).
         */
        static Stats concordance(Vec startVec, Vec stopVec, Vec eventVec, List<Vec> strataVecs, Vec estimateVec) {
            Scope.enter();
            try {
                final Vec durationVec = durations(startVec, stopVec);
                return concordanceStats(prepareFrameForConcordanceComputation(eventVec, strataVecs, estimateVec, durationVec));
            } finally {
                Scope.exit();
            }
        }

        /** Builds the working frame: duration (col 0), event (col 1), estimate (col 2), strata_i (cols 3..). */
        private static Frame prepareFrameForConcordanceComputation(Vec eventVec, List<Vec> strataVecs, Vec estimateVec, Vec durationVec) {
            final Frame fr = new Frame(new Vec[0]);
            fr.add("duration", durationVec);
            fr.add("event", eventVec);
            fr.add("estimate", estimateVec);
            for (int s = 0; s < strataVecs.size(); s++) {
                fr.add("strata_" + s, strataVecs.get(s));
            }
            return fr;
        }

        /**
         * Returns stop - start per row as a new scope-tracked vec, or {@code stopVec} itself when
         * there is no start vec.
         */
        private static Vec durations(Vec startVec, Vec stopVec) {
            if (null == startVec) {
                return stopVec;
            }
            final Vec durationVec = new MRTask() {
                @Override // water.MRTask
                public void map(Chunk start, Chunk stop, NewChunk out) {
                    for (int row = 0; row < start._len; row++) {
                        out.addNum(stop.atd(row) - start.atd(row));
                    }
                }
            }.doAll(Vec.T_NUM, startVec, stopVec).outputFrame(new String[]{"durations"}, (String[][]) null).vec(0);
            DKV.put(durationVec);
            Scope.track(durationVec);
            return durationVec;
        }

        /**
         * Sorts the working frame by strata and then by duration, splits it into contiguous strata
         * and sums the per-stratum pair statistics.
         */
        private static Stats concordanceStats(Frame fr) {
            final Frame withoutNAs = removeNAs(fr);
            // Columns 0..2 are duration, event and estimate; every further column is a stratum column.
            final int nStrata = withoutNAs.numCols() - 3;
            final int[] strataCols = new int[nStrata];
            final int[] sortCols = new int[nStrata + 1];
            for (int s = 0; s < nStrata; s++) {
                strataCols[s] = s + 3;
                sortCols[s] = s + 3;
            }
            // Duration is the last sort key, so each stratum is a run of rows ordered by duration.
            sortCols[nStrata] = 0;
            if (0 == withoutNAs.numRows()) {
                return new Stats();
            }
            final Frame sorted = withoutNAs.sort(sortCols);
            Scope.track(sorted);

            final List<Vec.Reader> strataReaders = new ArrayList<>(nStrata);
            for (int col : strataCols) {
                strataReaders.add(sorted.vec(col).new Reader());
            }
            final Vec.Reader durations = sorted.vec("duration").new Reader();
            final Vec.Reader events = sorted.vec("event").new Reader();
            final Vec.Reader estimates = sorted.vec("estimate").new Reader();

            final long nRows = sorted.numRows();
            Stats total = new Stats();
            List<Double> currentStrata = new ArrayList<>(nStrata);
            long strataStart = 0;
            for (long row = 0; row < nRows; row++) {
                final List<Double> rowStrata = new ArrayList<>(nStrata);
                for (Vec.Reader reader : strataReaders) {
                    rowStrata.add(reader.at(row));
                }
                if (!currentStrata.equals(rowStrata)) {
                    // A new stratum begins here: close the previous one (empty range for the very first row).
                    total = total.plus(statsForAStrata(durations, events, estimates, strataStart, row));
                    currentStrata = rowStrata;
                    strataStart = row;
                }
            }
            return total.plus(statsForAStrata(durations, events, estimates, strataStart, nRows));
        }

        /** Drops rows with NAs in columns 0 (duration) and 2 (estimate); the result and its vecs are scope-tracked. */
        private static Frame removeNAs(Frame fr) {
            final Frame withoutNAs = new Merge.RemoveNAsTask(0, 2).doAll(fr.types(), fr).outputFrame(fr.names(), fr.domains());
            Scope.track(withoutNAs);
            Arrays.stream(withoutNAs.vecs()).forEach(Scope::track);
            // NOTE(review): this replaces column 1 with the very same vec, which looks like a no-op — confirm intent.
            withoutNAs.replace(1, withoutNAs.vec("event"));
            return withoutNAs;
        }

        /**
         * Counts comparable / concordant / tied pairs among rows [firstRow, endRow) of a single stratum.
         * Rows with event 0 are treated as censored, all others as deaths. Rows are expected ordered
         * by ascending duration.
         */
        private static Stats statsForAStrata(Vec.Reader durations, Vec.Reader events, Vec.Reader estimates, long firstRow, long endRow) {
            if (endRow == firstRow) {
                return new Stats();
            }
            int nCensored = 0;
            int nDead = 0;
            for (long row = firstRow; row < endRow; row++) {
                if (0.0 == events.at(row)) {
                    nCensored++;
                } else {
                    nDead++;
                }
            }
            final long[] deadRows = new long[nDead];
            final long[] censoredRows = new long[nCensored];
            int c = 0;
            int d = 0;
            for (long row = firstRow; row < endRow; row++) {
                if (0.0 == events.at(row)) {
                    censoredRows[c++] = row;
                } else {
                    deadRows[d++] = row;
                }
            }
            assert censoredRows.length + deadRows.length == endRow - firstRow;

            // The tree holds every distinct estimate of a death; counts grow as deaths are inserted.
            final StatTree tree = new StatTree(Arrays.stream(deadRows).mapToDouble(row -> estimateTime(estimates, row)).distinct().sorted().toArray());

            long pairs = 0;
            long concordant = 0;
            long tied = 0;
            int deadIx = 0;
            int censoredIx = 0;
            while (true) {
                final boolean hasMoreCensored = censoredIx < censoredRows.length;
                final boolean hasMoreDead = deadIx < deadRows.length;
                final PairStats group;
                if (hasMoreCensored && (!hasMoreDead || deadTime(durations, deadRows[deadIx]) > deadTime(durations, censoredRows[censoredIx]))) {
                    // Censored rows are only compared against deaths that happened strictly earlier.
                    group = handleTiedPairs(censoredRows, estimates, durations, censoredIx, tree);
                    censoredIx = group.next_ix;
                } else if (hasMoreDead) {
                    // Deaths sharing one duration are ranked before any of them is inserted,
                    // so same-time deaths are not counted as comparable with each other.
                    group = handleTiedPairs(deadRows, estimates, durations, deadIx, tree);
                    for (int i = deadIx; i < group.next_ix; i++) {
                        tree.insert(estimateTime(estimates, deadRows[i]));
                    }
                    deadIx = group.next_ix;
                } else {
                    break;
                }
                pairs += group.pairs;
                concordant += group.concordant;
                tied += group.tied;
            }
            return new Stats(pairs, concordant, tied);
        }

        private static double deadTime(Vec.Reader durations, long row) {
            return durations.at(row);
        }

        /** Estimate as a "time" value: negated, so a larger estimate maps to a shorter time. */
        public static double estimateTime(Vec.Reader estimates, long row) {
            return -estimates.at(row);
        }

        /**
         * Ranks the group of rows starting at {@code first} whose row index equals {@code rows[first]}
         * against the tree. With distinct row indices the group is a single row. Kept for compatibility.
         */
        static PairStats handlePairs(long[] rows, Vec.Reader estimates, int first, StatTree tree) {
            int next = first;
            while (next < rows.length && rows[next] == rows[first]) {
                next++;
            }
            return rankGroup(rows, estimates, first, next, tree);
        }

        /** Ranks the group of rows starting at {@code first} that share the same duration against the tree. */
        private static PairStats handleTiedPairs(long[] rows, Vec.Reader estimates, Vec.Reader durations, int first, StatTree tree) {
            final double time = deadTime(durations, rows[first]);
            int next = first;
            while (next < rows.length && deadTime(durations, rows[next]) == time) {
                next++;
            }
            return rankGroup(rows, estimates, first, next, tree);
        }

        /** Every row in [first, next) is comparable with every value currently in the tree. */
        private static PairStats rankGroup(long[] rows, Vec.Reader estimates, int first, int next, StatTree tree) {
            final long pairs = tree.len() * (next - first);
            long concordant = 0;
            long tied = 0;
            for (int i = first; i < next; i++) {
                final StatTree.RankAndCount rc = tree.rankAndCount(estimateTime(estimates, rows[i]));
                concordant += rc.rank;
                tied += rc.count;
            }
            return new PairStats(pairs, concordant, tied, next);
        }
    }

    /**
     * Binary search tree over a fixed set of sorted values, stored in heap layout (children of node
     * i at 2i+1 and 2i+2), with per-node counts of inserted values in that node's subtree.
     */
    public static class StatTree {
        final double[] values;
        final long[] counts;

        /** Number of inserted values strictly below a query value, and the number equal to it. */
        public static class RankAndCount {
            final long rank;
            final long count;

            public RankAndCount(long rank, long count) {
                this.rank = rank;
                this.count = count;
            }

            @Override
            public String toString() {
                return "RankAndCount{rank=" + this.rank + ", count=" + this.count + '}';
            }
        }

        /** @param sortedValues values in ascending order (may be empty); insert() accepts only these */
        StatTree(double[] sortedValues) {
            assert null != sortedValues;
            assert sortedAscending(sortedValues);
            this.values = new double[sortedValues.length];
            addMissingValues(sortedValues, fillTree(sortedValues, 0, sortedValues.length, 0));
            this.counts = new long[sortedValues.length];
            assert containsAll(sortedValues, this.values);
            assert isSearchTree(this.values);
            assert allZeroes(this.counts);
        }

        /**
         * Writes the bottom, partially filled level of the tree (heap indices from {@code fullSize}):
         * those leaves are the sorted values at even positions 0, 2, 4, ... NOTE(review): fillTree
         * appears to place these values already, making this a re-write of the same data.
         */
        private void addMissingValues(double[] sorted, int fullSize) {
            final int ragged = sorted.length - fullSize;
            for (int k = 0; k < ragged; k++) {
                this.values[fullSize + k] = sorted[k * 2];
            }
        }

        /**
         * Recursively places sorted[lo, hi) into the subtree rooted at {@code node} as a complete
         * tree filled left to right; returns the node count of the full levels of that subtree.
         */
        private int fillTree(double[] sorted, int lo, int hi, int node) {
            final int n = hi - lo;
            if (n <= 0) {
                return 0;
            }
            final int fullLevels = (32 - Integer.numberOfLeadingZeros(n + 1)) - 1;
            final int fullSize = (1 << fullLevels) - 1;
            final int halfBottom = 1 << (fullLevels - 1);
            // Left subtree: its share of the full levels plus up to half of the ragged bottom row.
            final int leftSize = (halfBottom - 1) + Math.min(n - fullSize, halfBottom);
            this.values[node] = sorted[lo + leftSize];
            fillTree(sorted, lo, lo + leftSize, leftChild(node));
            fillTree(sorted, lo + leftSize + 1, hi, rightChild(node));
            return fullSize;
        }

        private static boolean sortedAscending(double[] arr) {
            for (int i = 1; i < arr.length; i++) {
                if (arr[i - 1] > arr[i]) {
                    return false;
                }
            }
            return true;
        }

        private static boolean containsAll(double[] haystack, double[] needles) {
            for (double needle : needles) {
                if (!ArrayUtils.contains(haystack, needle)) {
                    return false;
                }
            }
            return true;
        }

        private static boolean isSearchTree(double[] tree) {
            for (int i = 0; i < tree.length; i++) {
                final int left = leftChild(i);
                if (left < tree.length && tree[i] < tree[left]) {
                    return false;
                }
                final int right = rightChild(i);
                if (right < tree.length && tree[i] > tree[right]) {
                    return false;
                }
            }
            return true;
        }

        private static boolean allZeroes(long[] arr) {
            for (long v : arr) {
                if (0 != v) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Records one occurrence of {@code value}, incrementing counts along its search path.
         *
         * @throws IllegalArgumentException if the value is not one of the tree's values (counts are then corrupted)
         */
        void insert(double value) {
            int node = 0;
            while (node < this.values.length) {
                this.counts[node]++;
                final double pivot = this.values[node];
                if (value < pivot) {
                    node = leftChild(node);
                } else if (value <= pivot) {
                    return;
                } else {
                    node = rightChild(node);
                }
            }
            throw new IllegalArgumentException("Value " + value + " not contained in tree. Tree counts now in illegal state;");
        }

        public int size() {
            return this.values.length;
        }

        /** Total number of inserted values; 0 for an empty tree (which has no root node). */
        public long len() {
            return this.counts.length == 0 ? 0L : this.counts[0];
        }

        /** Counts inserted values strictly smaller than {@code value} (rank) and equal to it (count). */
        RankAndCount rankAndCount(double value) {
            int node = 0;
            long rank = 0;
            while (node < this.values.length) {
                final double pivot = this.values[node];
                if (value < pivot) {
                    node = leftChild(node);
                } else if (value <= pivot) {
                    // Exact hit: this node's own count is its subtree count minus both child subtrees.
                    long count = this.counts[node];
                    final int left = leftChild(node);
                    if (left < this.values.length) {
                        count -= this.counts[left];
                        rank += this.counts[left];
                        final int right = rightChild(node);
                        if (right < this.values.length) {
                            count -= this.counts[right];
                        }
                    }
                    return new RankAndCount(rank, count);
                } else {
                    // Everything in this subtree except the right child's subtree is smaller.
                    rank += this.counts[node];
                    final int right = rightChild(node);
                    if (right >= this.values.length) {
                        return new RankAndCount(rank, 0L);
                    }
                    rank -= this.counts[right];
                    node = right;
                }
            }
            return new RankAndCount(rank, 0L);
        }

        @Override
        public String toString() {
            return toString(new StringBuilder()).toString();
        }

        /** Appends "value(count) " per node, one tree level per line. */
        private StringBuilder toString(StringBuilder sb) {
            int levelEnd = 1;
            for (int i = 0; i < this.values.length; i++) {
                if (i == levelEnd) {
                    sb.append("\n");
                    levelEnd = 2 * levelEnd + 1;
                }
                sb.append(this.values[i]).append('(').append(this.counts[i]).append(')').append(" ");
            }
            if (this.values.length == levelEnd) {
                sb.append("\n");
            }
            return sb;
        }

        private static int leftChild(int node) {
            return (2 * node) + 1;
        }

        private static int rightChild(int node) {
            return (2 * node) + 2;
        }
    }

    public double concordance() {
        return this._concordance;
    }

    public long concordant() {
        return this._concordant;
    }

    public long discordant() {
        return this._discordant;
    }

    public long tiedY() {
        return this._tied_y;
    }

    public ModelMetricsRegressionCoxPH(Model model, Frame frame, long nobs, double mse, double sigma, double mae, double rmsle, double meanResidualDeviance, CustomMetric customMetric, double concordance, long concordant, long discordant, long tiedY) {
        super(model, frame, nobs, mse, sigma, mae, rmsle, meanResidualDeviance, customMetric);
        this._concordance = concordance;
        this._concordant = concordant;
        this._discordant = discordant;
        this._tied_y = tiedY;
    }

    /**
     * Looks up the metrics of {@code model} on {@code frame}.
     *
     * @throws H2OIllegalArgumentException if no CoxPH metrics are stored (including when nothing is found)
     */
    public static ModelMetricsRegressionCoxPH getFromDKV(Model model, Frame frame) {
        final ModelMetrics found = ModelMetrics.getFromDKV(model, frame);
        if (found instanceof ModelMetricsRegressionCoxPH) {
            return (ModelMetricsRegressionCoxPH) found;
        }
        // Guard against a missing entry so the intended exception is thrown instead of an NPE.
        final Object foundType = null == found ? "null" : found.getClass();
        throw new H2OIllegalArgumentException("Expected to find a Regression ModelMetrics for model: " + model._key.toString() + " and frame: " + frame._key.toString(), "Expected to find a ModelMetricsRegression for model: " + model._key.toString() + " and frame: " + frame._key.toString() + " but found a: " + foundType);
    }

    @Override // hex.ModelMetricsRegression, hex.ModelMetricsSupervised, hex.ModelMetrics
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        if (Double.isNaN(this._concordance)) {
            sb.append(" concordance: N/A\n");
        } else {
            sb.append(" concordance: " + ((float) this._concordance) + "\n");
        }
        sb.append(" concordant: " + this._concordant + "\n");
        sb.append(" discordant: " + this._discordant + "\n");
        sb.append(" tied.y: " + this._tied_y + "\n");
        return sb.toString();
    }
}
