package ai.h2o.xgboost4j.java;

import com.sun.jna.platform.win32.COM.tlb.imp.TlbBase;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileSystem;

/* loaded from: input_file:ai/h2o/xgboost4j/java/XGBoost.class */
/**
 * Static entry points for loading, training and cross-validating XGBoost {@link Booster}s.
 */
public class XGBoost {
    private static final Log logger;

    // Decompiler-synthesized assertion flag. No code reads it any more; it is kept only so this
    // package-visible field continues to exist. Do not combine with the `assert` keyword in this
    // class (javac would synthesize a field with the same name).
    static final /* synthetic */ boolean $assertionsDisabled;

    /**
     * One cross-validation fold: a training matrix, a test matrix and a booster created over both.
     */
    /* JADX INFO: Access modifiers changed from: private */
    public static class CVPack {
        DMatrix dtrain;
        DMatrix dtest;
        DMatrix[] dmats;
        String[] names = {"train", "test"};
        Booster booster;

        /**
         * Creates the fold and a new booster over {@code {dtrain, dtest}}.
         *
         * @param dtrain training slice of the fold
         * @param dtest  held-out slice of the fold
         * @param params booster parameters
         */
        public CVPack(DMatrix dtrain, DMatrix dtest, Map<String, Object> params) throws XGBoostError {
            this.dmats = new DMatrix[]{dtrain, dtest};
            this.booster = Booster.newBooster(params, this.dmats);
            this.dtrain = dtrain;
            this.dtest = dtest;
        }

        /** Runs one boosting update on the training slice for the given iteration. */
        public void update(int iter) throws XGBoostError {
            this.booster.update(this.dtrain, iter);
        }

        /** Runs one boosting update on the training slice using a custom objective. */
        public void update(IObjective obj) throws XGBoostError {
            this.booster.update(this.dtrain, obj);
        }

        /** Evaluates the booster on the "train" and "test" slices and returns the evaluation string. */
        public String eval(int iter) throws XGBoostError {
            return this.booster.evalSet(this.dmats, this.names, iter);
        }

        /** Evaluates the booster on the "train" and "test" slices with a custom evaluation. */
        public String eval(IEvaluation eval) throws XGBoostError {
            return this.booster.evalSet(this.dmats, this.names, eval);
        }
    }

    /** Loads a booster by delegating to {@link Booster#loadModel(String)}. */
    public static Booster loadModel(String modelPath) throws XGBoostError {
        return Booster.loadModel(modelPath);
    }

    /**
     * Reads the whole stream into memory and loads a booster from the bytes.
     * The stream is always closed, even when reading fails.
     */
    public static Booster loadModel(InputStream in) throws XGBoostError, IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try {
            byte[] chunk = new byte[1 << 20];
            int read;
            while ((read = in.read(chunk)) != -1) {
                buffer.write(chunk, 0, read);
            }
        } finally {
            in.close();
        }
        return Booster.loadModel(buffer.toByteArray());
    }

    /** Loads a booster from an in-memory model by delegating to {@link Booster#loadModel(byte[])}. */
    public static Booster loadModel(byte[] modelBytes) throws XGBoostError, IOException {
        return Booster.loadModel(modelBytes);
    }

    /** Trains a new booster without a metrics output array and without early stopping. */
    public static Booster train(DMatrix dtrain, Map<String, Object> params, int numRounds,
            Map<String, DMatrix> watches, IObjective obj, IEvaluation eval) throws XGBoostError {
        return train(dtrain, params, numRounds, watches, null, obj, eval, 0);
    }

    /** Trains a new booster; see {@link #trainAndSaveCheckpoint} for parameter meanings. */
    public static Booster train(DMatrix dtrain, Map<String, Object> params, int numRounds,
            Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval,
            int earlyStoppingRound) throws XGBoostError {
        return train(dtrain, params, numRounds, watches, metrics, obj, eval, earlyStoppingRound, null);
    }

    /**
     * Writes an external checkpoint when {@code iter} is one of {@code checkpointRounds}.
     * Any failure is logged and rethrown wrapped in {@link XGBoostError}.
     */
    private static void saveCheckpoint(Booster booster, int iter, Set<Integer> checkpointRounds,
            ExternalCheckpointManager ecm) throws XGBoostError {
        try {
            if (checkpointRounds.contains(iter)) {
                ecm.updateCheckpoint(booster);
            }
        } catch (Exception e) {
            logger.error("failed to save checkpoint in XGBoost4J at iteration " + iter, e);
            throw new XGBoostError("failed to save checkpoint in XGBoost4J at iteration " + iter, e);
        }
    }

    /**
     * Trains up to {@code numRounds} rounds, optionally continuing an existing booster and writing
     * external checkpoints.
     *
     * @param dtrain             training matrix
     * @param params             booster parameters; also read for {@code maximize_evaluation_metrics},
     *                           {@code silent} and {@code verbose_eval}
     * @param numRounds          round index at which training ends
     * @param watches            named matrices evaluated after each round; the metric of the LAST entry
     *                           (in the map's iteration order) drives best-iteration tracking and early stopping
     * @param metrics            optional output indexed {@code [watch][round]}; when null an array of
     *                           {@code [watches.size()][numRounds]} is allocated internally
     * @param obj                custom objective, or null to call {@code Booster.update(DMatrix, int)}
     * @param eval               custom evaluation, or null to call the iteration-based {@code evalSet}
     * @param earlyStoppingRound stop once the current round is this many rounds past the best one;
     *                           a value {@code <= 0} disables early stopping
     * @param booster            booster to continue, or null to create one over dtrain plus all watches
     * @param checkpointInterval passed to {@code ExternalCheckpointManager.getCheckpointRounds}
     * @param checkpointPath     external checkpoint location, or null to disable external checkpoints
     * @param fs                 file system handed to the checkpoint manager
     * @return the trained booster
     */
    public static Booster trainAndSaveCheckpoint(DMatrix dtrain, Map<String, Object> params, int numRounds,
            Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval,
            int earlyStoppingRound, Booster booster, int checkpointInterval, String checkpointPath,
            FileSystem fs) throws XGBoostError, IOException {
        ExternalCheckpointManager ecm =
                checkpointPath != null ? new ExternalCheckpointManager(checkpointPath, fs) : null;

        // Split the watch list into parallel name / matrix arrays (same iteration order).
        List<String> evalNameList = new ArrayList<>();
        List<DMatrix> evalMatList = new ArrayList<>();
        for (Map.Entry<String, DMatrix> entry : watches.entrySet()) {
            evalNameList.add(entry.getKey());
            evalMatList.add(entry.getValue());
        }
        String[] evalNames = evalNameList.toArray(new String[0]);
        DMatrix[] evalMats = evalMatList.toArray(new DMatrix[0]);

        // The parameter does not change during training, so read it once.
        boolean maximize = isMaximizeEvaluation(params);
        float bestScore = maximize ? -Float.MAX_VALUE : Float.MAX_VALUE;
        int bestIteration = 0;
        float[][] metricsOut = metrics == null ? new float[evalNames.length][numRounds] : metrics;

        // A new booster is built over the training matrix followed by every watch matrix.
        DMatrix[] allMats = new DMatrix[evalMats.length + 1];
        allMats[0] = dtrain;
        System.arraycopy(evalMats, 0, allMats, 1, evalMats.length);

        if (booster == null) {
            booster = Booster.newBooster(params, allMats);
            booster.loadRabitCheckpoint();
        } else {
            booster.setParams(params);
        }

        Set<Integer> checkpointRounds = Collections.emptySet();
        if (ecm != null) {
            checkpointRounds = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds));
        }

        // Training resumes at getVersion() / 2, and an update runs only when the version is even.
        // This assumes the version counter advances twice per round -- NOTE(review): confirm in Booster.
        for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) {
            if (booster.getVersion() % 2 == 0) {
                if (obj != null) {
                    booster.update(dtrain, obj);
                } else {
                    booster.update(dtrain, iter);
                }
                saveCheckpoint(booster, iter, checkpointRounds, ecm);
                booster.saveRabitCheckpoint();
            }

            if (evalMats.length > 0) {
                float[] roundMetrics = new float[evalMats.length];
                String evalInfo = eval != null
                        ? booster.evalSet(evalMats, evalNames, eval, roundMetrics)
                        : booster.evalSet(evalMats, evalNames, iter, roundMetrics);
                for (int m = 0; m < roundMetrics.length; m++) {
                    metricsOut[m][iter] = roundMetrics[m];
                }

                float score = roundMetrics[roundMetrics.length - 1];
                boolean improved = maximize ? score > bestScore : score < bestScore;
                if (improved) {
                    bestScore = score;
                    bestIteration = iter;
                    booster.setAttr("best_iteration", String.valueOf(bestIteration));
                    booster.setAttr("best_score", String.valueOf(bestScore));
                }

                if (shouldEarlyStop(earlyStoppingRound, iter, bestIteration)) {
                    if (shouldPrint(params, iter)) {
                        Rabit.trackerPrint(String.format(
                                "early stopping after %d rounds away from the best iteration",
                                earlyStoppingRound));
                    }
                    // Previously missing: without this, training continued and the message repeated.
                    break;
                }
                if (Rabit.getRank() == 0 && shouldPrint(params, iter)) {
                    Rabit.trackerPrint(evalInfo + '\n');
                }
            }
            booster.saveRabitCheckpoint();
        }
        return booster;
    }

    /**
     * Trains without external checkpoints. Any {@link IOException} is logged and rethrown wrapped in
     * {@link XGBoostError}. See {@link #trainAndSaveCheckpoint} for parameter meanings.
     */
    public static Booster train(DMatrix dtrain, Map<String, Object> params, int numRounds,
            Map<String, DMatrix> watches, float[][] metrics, IObjective obj, IEvaluation eval,
            int earlyStoppingRound, Booster booster) throws XGBoostError {
        try {
            return trainAndSaveCheckpoint(dtrain, params, numRounds, watches, metrics, obj, eval,
                    earlyStoppingRound, booster, -1, null, null);
        } catch (IOException e) {
            logger.error("training failed in xgboost4j", e);
            throw new XGBoostError("training failed in xgboost4j ", e);
        }
    }

    /** Returns {@code obj} as an Integer when it is an Integer or a parseable String, else null. */
    private static Integer tryGetIntFromObject(Object obj) {
        if (obj instanceof Integer) {
            return (Integer) obj;
        }
        if (obj instanceof String) {
            try {
                return Integer.valueOf((String) obj);
            } catch (NumberFormatException ignored) {
                // Not numeric; the caller treats null as "no integer value".
                return null;
            }
        }
        return null;
    }

    /**
     * Decides whether progress for {@code iter} should be printed.
     * Printing is suppressed when {@code silent} is "true"/"True" or a non-zero integer, and when
     * {@code verbose_eval} is "false"/"False" or 0. A positive integer {@code verbose_eval} prints
     * every that-many iterations.
     */
    private static boolean shouldPrint(Map<String, Object> params, int iter) {
        Object silent = params.get("silent");
        if (silent != null) {
            if (silent.equals("true") || silent.equals("True")) {
                return false;
            }
            Integer silentLevel = tryGetIntFromObject(silent);
            if (silentLevel != null && silentLevel.intValue() != 0) {
                return false;
            }
        }
        Object verboseEval = params.get("verbose_eval");
        if (verboseEval == null) {
            return true;
        }
        if (verboseEval.equals("false") || verboseEval.equals("False")) {
            return false;
        }
        Integer period = tryGetIntFromObject(verboseEval);
        if (period != null) {
            return period.intValue() != 0 && iter % period.intValue() == 0;
        }
        return true;
    }

    /** True when early stopping is enabled and {@code iter} is at least that many rounds past the best. */
    static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIteration) {
        return earlyStoppingRounds > 0 && iter - bestIteration >= earlyStoppingRounds;
    }

    /** Reads {@code maximize_evaluation_metrics}; a missing or non-"true" value yields false. */
    private static boolean isMaximizeEvaluation(Map<String, Object> params) {
        try {
            // String.valueOf never returns null (a missing key becomes "null", which parses to false).
            return Boolean.parseBoolean(String.valueOf(params.get("maximize_evaluation_metrics")));
        } catch (RuntimeException e) {
            logger.error("maximize_evaluation_metrics has to be specified for enabling early stop, allowed value: true/false", e);
            throw e;
        }
    }

    /**
     * Runs n-fold cross-validation and returns one aggregated evaluation line per round.
     *
     * @param data    full data set to split into folds
     * @param params  booster parameters
     * @param round   number of boosting rounds
     * @param nfold   number of folds (must be positive)
     * @param metrics optional {@code eval_metric} values set on every fold's booster
     * @param obj     custom objective, or null to use iteration-based updates
     * @param eval    custom evaluation, or null to use iteration-based evaluation
     * @return per-round strings with the fold-averaged metrics
     */
    public static String[] crossValidation(DMatrix data, Map<String, Object> params, int round, int nfold,
            String[] metrics, IObjective obj, IEvaluation eval) throws XGBoostError {
        CVPack[] folds = makeNFold(data, nfold, params, metrics);
        String[] evalHist = new String[round];
        String[] foldResults = new String[folds.length];
        for (int iter = 0; iter < round; iter++) {
            for (CVPack fold : folds) {
                if (obj != null) {
                    fold.update(obj);
                } else {
                    fold.update(iter);
                }
            }
            for (int f = 0; f < folds.length; f++) {
                foldResults[f] = eval != null ? folds[f].eval(eval) : folds[f].eval(iter);
            }
            evalHist[iter] = aggCVResults(foldResults);
            logger.info(evalHist[iter]);
        }
        return evalHist;
    }

    /**
     * Shuffles row indices and splits them into {@code nfold} folds. Fold {@code k} tests on the
     * permuted positions {@code [k*foldSize, (k+1)*foldSize)} and trains on all other rows. This makes
     * the test sets disjoint. When the row count is not divisible by nfold, the leftover rows are
     * always in training.
     */
    private static CVPack[] makeNFold(DMatrix data, int nfold, Map<String, Object> params,
            String[] evalMetrics) throws XGBoostError {
        if (nfold <= 0) {
            throw new IllegalArgumentException("nfold must be positive, got " + nfold);
        }
        List<Integer> permutation = genRandPermutationNums(0, (int) data.rowNum());
        int numRows = permutation.size();
        int foldSize = numRows / nfold;
        CVPack[] packs = new CVPack[nfold];
        for (int fold = 0; fold < nfold; fold++) {
            int testStart = fold * foldSize;
            int testEnd = testStart + foldSize;
            // Fresh arrays per fold so no fold's slice indices alias another's.
            int[] testIdx = new int[foldSize];
            int[] trainIdx = new int[numRows - foldSize];
            int t = 0;
            int r = 0;
            for (int pos = 0; pos < numRows; pos++) {
                int row = permutation.get(pos);
                if (pos >= testStart && pos < testEnd) {
                    testIdx[t++] = row;
                } else {
                    trainIdx[r++] = row;
                }
            }
            CVPack pack = new CVPack(data.slice(trainIdx), data.slice(testIdx), params);
            if (evalMetrics != null) {
                for (String metric : evalMetrics) {
                    pack.booster.setParam("eval_metric", metric);
                }
            }
            packs[fold] = pack;
        }
        return packs;
    }

    /** Returns the integers in {@code [start, end)} in random order. */
    private static List<Integer> genRandPermutationNums(int start, int end) {
        List<Integer> nums = new ArrayList<>();
        for (int n = start; n < end; n++) {
            nums.add(n);
        }
        Collections.shuffle(nums);
        return nums;
    }

    /**
     * Averages "name:value" metrics across the folds' tab-separated evaluation strings.
     * The leading field of the first result is reused as the prefix. Metrics appear in first-seen
     * order and are formatted locale-independently.
     */
    private static String aggCVResults(String[] results) {
        Map<String, List<Float>> scoresByMetric = new LinkedHashMap<>();
        StringBuilder aggregated = new StringBuilder(results[0].split("\t")[0]);
        for (String result : results) {
            String[] fields = result.split("\t");
            for (int k = 1; k < fields.length; k++) {
                String[] kv = fields[k].split(":");
                List<Float> values = scoresByMetric.get(kv[0]);
                if (values == null) {
                    values = new ArrayList<>();
                    scoresByMetric.put(kv[0], values);
                }
                values.add(Float.valueOf(kv[1]));
            }
        }
        for (Map.Entry<String, List<Float>> entry : scoresByMetric.entrySet()) {
            float sum = 0.0f;
            for (Float value : entry.getValue()) {
                sum += value.floatValue();
            }
            aggregated.append(String.format(Locale.ROOT, "\tcv-%s:%f",
                    entry.getKey(), Float.valueOf(sum / entry.getValue().size())));
        }
        return aggregated.toString();
    }

    static {
        $assertionsDisabled = !XGBoost.class.desiredAssertionStatus();
        logger = LogFactory.getLog(XGBoost.class);
    }
}
