package water.rapids.ast.prims.models;

import hex.Model;
import hex.ModelContainer;
import hex.leaderboard.AlgoName;
import hex.leaderboard.Leaderboard;
import hex.leaderboard.LeaderboardCell;
import hex.leaderboard.LeaderboardExtensionsProvider;
import hex.leaderboard.ScoringTimePerRow;
import hex.leaderboard.TrainingTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.stream.Stream;
import water.DKV;
import water.Key;
import water.fvec.Frame;
import water.logging.LoggerFactory;
import water.rapids.Env;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.util.Log;

/* loaded from: input_file:water/rapids/ast/prims/models/AstMakeLeaderboard.class */
public class AstMakeLeaderboard extends AstPrimitive {
    @Override // water.rapids.ast.AstPrimitive
    public String[] args() {
        return new String[]{"models", "leaderboardFrame", "sortMetric", "extensions", "scoringData"};
    }

    @Override // water.rapids.ast.AstPrimitive
    public int nargs() {
        return 6;
    }

    @Override // water.rapids.ast.AstRoot
    public String str() {
        return "makeLeaderboard";
    }

    private static LeaderboardExtensionsProvider createLeaderboardExtensionProvider(final Frame frame) {
        return new LeaderboardExtensionsProvider() { // from class: water.rapids.ast.prims.models.AstMakeLeaderboard.1
            @Override // hex.leaderboard.LeaderboardExtensionsProvider
            public LeaderboardCell[] createExtensions(Model model) {
                return new LeaderboardCell[]{new TrainingTime(model), new ScoringTimePerRow(model, Frame.this), new AlgoName(model)};
            }
        };
    }

    @Override // water.rapids.ast.AstPrimitive
    public ValFrame apply(Env env, Env.StackHelp stackHelp, AstRoot[] astRootArr) {
        Key<Model>[] keyArr = (Key[]) Arrays.stream(stackHelp.track(astRootArr[1].exec(env)).getStrs()).flatMap(str -> {
            Cloneable get = DKV.getGet(str);
            if (get instanceof Model) {
                return Stream.of(Key.make(str));
            }
            if (get instanceof ModelContainer) {
                return Stream.of((Object[]) ((ModelContainer) get).getModelKeys());
            }
            throw new RuntimeException("Unsupported model/grid id: " + str + "!");
        }).toArray(i -> {
            return new Key[i];
        });
        String str2 = stackHelp.track(astRootArr[2].exec(env)).getStr();
        Frame frame = null;
        if (str2.isEmpty()) {
            str2 = null;
        } else {
            frame = (Frame) DKV.getGet(str2);
        }
        String str3 = stackHelp.track(astRootArr[3].exec(env)).getStr();
        String[] strs = stackHelp.track(astRootArr[4].exec(env)).getStrs();
        String lowerCase = stackHelp.track(astRootArr[5].exec(env)).getStr().toLowerCase();
        String str4 = str2 + "_" + Arrays.deepHashCode(keyArr) + "_" + lowerCase;
        Arrays.stream(keyArr).forEach(DKV::prefetch);
        if (str3.equalsIgnoreCase("auto")) {
            str3 = null;
        }
        boolean z = Arrays.stream(keyArr).map(key -> {
            return ((Model) DKV.getGet(key))._parms._train;
        }).distinct().count() == 1;
        boolean z2 = Arrays.stream(keyArr).map(key2 -> {
            return ((Model) DKV.getGet(key2))._parms._valid;
        }).distinct().count() == 1;
        boolean z3 = Arrays.stream(keyArr).map(key3 -> {
            return ((Model) DKV.getGet(key3))._parms;
        }).filter(parameters -> {
            return !parameters.algoName().equalsIgnoreCase("stackedensemble");
        }).map(parameters2 -> {
            return Integer.valueOf(parameters2._nfolds);
        }).distinct().count() == 1;
        boolean allMatch = Arrays.stream(keyArr).allMatch(key4 -> {
            return ((Model) DKV.getGet(key4))._output._cross_validation_metrics != null;
        });
        boolean allMatch2 = Arrays.stream(keyArr).allMatch(key5 -> {
            return ((Model) DKV.getGet(key5))._output._validation_metrics != null;
        });
        boolean z4 = false;
        boolean z5 = false;
        boolean z6 = false;
        boolean z7 = false;
        if (lowerCase.equals("auto") && frame == null) {
            z4 = true;
            z6 = true;
            lowerCase = "xval";
        }
        if (lowerCase.equals("xval")) {
            z4 = true;
            z6 = true;
            z7 = true;
            if (!allMatch) {
                lowerCase = "valid";
            }
        }
        if (lowerCase.equals("valid")) {
            z4 = false;
            z5 = true;
            z7 = true;
            if (!allMatch2) {
                lowerCase = "train";
            }
        }
        if (lowerCase.equals("train")) {
            z4 = true;
            z5 = false;
            z7 = true;
        }
        if (z4 && !z) {
            Log.warn("More than one training frame was used amongst the models provided to the leaderboard.");
        }
        if (z5 && !z2) {
            Log.warn("More than one validation frame was used amongst the models provided to the leaderboard.");
        }
        if (z6 && !z3) {
            Log.warn("More than one n-folds settings are present.");
        }
        if (z7 && frame != null) {
            Log.warn("Leaderboard frame present but scoring data are set to " + lowerCase + ". Using scores from " + lowerCase + ".");
        }
        Leaderboard orMake = Leaderboard.getOrMake(str4, LoggerFactory.getLogger(Leaderboard.class), frame, str3, Leaderboard.ScoreData.valueOf(lowerCase));
        orMake.setExtensionsProvider(createLeaderboardExtensionProvider(frame));
        orMake.addModels(keyArr);
        orMake.ensureSorted();
        return new ValFrame(orMake.toTwoDimTable(strs).asFrame(Key.make()));
    }
}
