package hex;

import hex.DataInfo;
import hex.glm.GLMModel;
import hex.splitframe.ShuffleSplitFrame;
import java.util.Random;
import org.junit.BeforeClass;
import org.junit.Test;
import water.Key;
import water.MRTask;
import water.TestUtil;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.IcedHashMap;

/* loaded from: input_file:hex/DataInfoTestAdapt.class */
/**
 * Checks that {@link DataInfo} expands interaction columns consistently across a shuffle split.
 *
 * <p>Each test does the following:
 * <ul>
 *   <li>Builds a "golden" frame by expanding the interactions on the whole frame with
 *       {@code GLMModel.GLMOutput.expand}.</li>
 *   <li>Splits the raw frame and the golden frame with the same seed.</li>
 *   <li>Verifies, row by row, that a {@link DataInfo} built on a raw split reproduces the matching
 *       golden split. The same check is applied to a scoring {@code DataInfo} adapted to the raw
 *       test split.</li>
 * </ul>
 */
public class DataInfoTestAdapt extends TestUtil {

    /** Absolute tolerance when comparing an expanded value against the golden frame. */
    private static final double TOLERANCE = 1.0E-12d;

    /**
     * When missing rows are not skipped, differences above {@link #TOLERANCE} but below this value
     * are only printed, not failed.
     */
    private static final double NON_SKIPPING_FAIL_THRESHOLD = 10.0d;

    @BeforeClass
    public static void setup() {
        // TestUtil helper; presumably blocks until a 1-node cloud is available.
        stall_till_cloudsize(1);
    }

    /**
     * Iris: interactions between {@code class} and {@code sepal_len}, with response
     * {@code petal_wid}. The scoring frame is compared strictly, with bad rows skipped.
     */
    @Test
    public void testInteractionTrainTestSplitAdapt() {
        DataInfo dinfo = null;
        DataInfo scoreInfo = null;
        Frame fr = null;
        Frame expanded = null;
        Frame[] frSplits = null;
        Frame[] expandSplits = null;
        String[] interactions = {"class", "sepal_len"};
        try {
            fr = parse_test_file(Key.make("a.hex"), "smalldata/iris/iris_wheader.csv");
            // NOTE(review): presumably moves the response (petal_wid) to the last column;
            // confirm against the CSV header.
            fr.swap(3, 4);
            // "Golden" frame: interactions expanded on the full, unsplit frame.
            expanded = GLMModel.GLMOutput.expand(fr, interactions, false, false, true);

            // Both frames are split with the same seed. The row-by-row checks below assume
            // this yields aligned splits.
            // NOTE(review): the seed is random and never reported, so a failure here cannot be
            // reproduced; consider logging it.
            long seed = new Random().nextLong();
            frSplits = splitTrainTest(fr, seed);
            expandSplits = splitTrainTest(expanded, seed);
            checkSplits(frSplits, expandSplits, interactions, false, false);

            // Adapt the raw test split to a DataInfo built on the raw train split, then check the
            // resulting scoring DataInfo against the golden test split.
            dinfo = makeInfo(frSplits[0], interactions, false, false);
            GLMModel.GLMParameters parms = new GLMModel.GLMParameters();
            parms._response_column = "petal_wid";
            Model.adaptTestForTrain(frSplits[1], (String[]) null, (String[][]) null,
                    dinfo._adaptedFrame.names(), dinfo._adaptedFrame.domains(), parms, true, false,
                    interactions, (ToEigenVec) null, (IcedHashMap) null, false);
            scoreInfo = dinfo.scoringInfo(dinfo._adaptedFrame._names, frSplits[1]);
            checkFrame(scoreInfo, expandSplits[1]);
        } finally {
            cleanup(fr, expanded);
            cleanup(frSplits);
            cleanup(expandSplits);
            // NOTE(review): checkFrame already dropped/removed scoreInfo. Removing it again here
            // assumes DataInfo.remove() tolerates repeated calls; confirm.
            cleanup(dinfo, scoreInfo);
        }
    }

    /**
     * Airlines subset: interactions between {@code CRSDepTime} and {@code Origin}, with response
     * {@code IsDepDelayed}. Missing rows are not skipped, so small mismatches are only printed.
     */
    @Test
    public void testInteractionTrainTestSplitAdaptAirlines() {
        DataInfo dinfo = null;
        DataInfo scoreInfo = null;
        Frame fr = null;
        Frame sub = null;
        Frame expanded = null;
        Frame[] frSplits = null;
        Frame[] expandSplits = null;
        String[] interactions = {"CRSDepTime", "Origin"};
        String[] columns = {"Year", "Month", "DayofMonth", "DayOfWeek", "CRSDepTime", "CRSArrTime",
                "UniqueCarrier", "CRSElapsedTime", "Origin", "Dest", "Distance", "IsDepDelayed"};
        try {
            fr = parse_test_file(Key.make("a.hex"), "smalldata/airlines/allyears2k_headers.zip");
            sub = fr.subframe(columns);
            // "Golden" frame: interactions expanded on the full, unsplit subframe.
            expanded = GLMModel.GLMOutput.expand(sub, interactions, false, false, false);

            // Same-seed splits; see the note on seed reproducibility in the iris test.
            long seed = new Random().nextLong();
            frSplits = splitTrainTest(sub, seed);
            expandSplits = splitTrainTest(expanded, seed);
            checkSplits(frSplits, expandSplits, interactions, false, false, false);

            dinfo = makeInfo(frSplits[0], interactions, false, false, false);
            GLMModel.GLMParameters parms = new GLMModel.GLMParameters();
            parms._response_column = "IsDepDelayed";
            Model.adaptTestForTrain(frSplits[1], (String[]) null, (String[][]) null,
                    dinfo._adaptedFrame.names(), dinfo._adaptedFrame.domains(), parms, true, false,
                    interactions, (ToEigenVec) null, (IcedHashMap) null, false);
            scoreInfo = dinfo.scoringInfo(dinfo._adaptedFrame._names, frSplits[1]);
            checkFrame(scoreInfo, expandSplits[1], false);
        } finally {
            cleanup(sub, fr, expanded);
            cleanup(frSplits);
            cleanup(expandSplits);
            // NOTE(review): scoreInfo was already removed by checkFrame; see the iris test.
            cleanup(dinfo, scoreInfo);
        }
    }

    /**
     * Splits {@code fr} 80/20 with {@code ShuffleSplitFrame} under fresh keys.
     *
     * @param fr   frame to split
     * @param seed shuffle seed; pass the same value to split two frames identically
     * @return the two splits, in the order {train, test}
     */
    private static Frame[] splitTrainTest(Frame fr, long seed) {
        return ShuffleSplitFrame.shuffleSplitFrame(fr, new Key[]{Key.make(), Key.make()},
                new double[]{0.8d, 0.2d}, seed);
    }

    /**
     * Deletes every non-null frame.
     *
     * <p>Both a {@code null} array and {@code null} elements are tolerated. Callers run this from
     * {@code finally} blocks before every frame has been created. Throwing here would mask the
     * original test failure and skip the remaining cleanup.
     */
    private static void cleanup(Frame... frames) {
        if (frames == null) {
            return;
        }
        for (Frame f : frames) {
            if (f != null) {
                f.delete();
            }
        }
    }

    /**
     * Drops the interaction vecs of, and removes, every non-null {@link DataInfo}.
     * Tolerates a {@code null} array.
     */
    private static void cleanup(DataInfo... infos) {
        if (infos == null) {
            return;
        }
        for (DataInfo di : infos) {
            if (di != null) {
                di.dropInteractions();
                di.remove();
            }
        }
    }

    /**
     * Same as the six-argument overload with {@code skipMissing = false}.
     *
     * <p>NOTE(review): this default differs from the {@code true} default used by
     * {@code makeInfo(Frame, String[], boolean, boolean)} and {@code checkFrame(DataInfo, Frame)}.
     * Confirm it is intentional.
     */
    private static void checkSplits(Frame[] frSplits, Frame[] expandSplits, String[] interactions,
                                    boolean useAll, boolean standardize) {
        checkSplits(frSplits, expandSplits, interactions, useAll, standardize, false);
    }

    /**
     * For each index {@code i}, builds a {@link DataInfo} on {@code frSplits[i]} and checks its
     * expanded rows against {@code expandSplits[i]}. Each temporary {@code DataInfo} is released
     * by {@code checkFrame}.
     */
    private static void checkSplits(Frame[] frSplits, Frame[] expandSplits, String[] interactions,
                                    boolean useAll, boolean standardize, boolean skipMissing) {
        for (int i = 0; i < frSplits.length; i++) {
            DataInfo di = makeInfo(frSplits[i], interactions, useAll, standardize, skipMissing);
            checkFrame(di, expandSplits[i], skipMissing);
        }
    }

    /** Same as the five-argument overload with {@code skipMissing = true}. */
    private static DataInfo makeInfo(Frame fr, String[] interactions, boolean useAll,
                                     boolean standardize) {
        return makeInfo(fr, interactions, useAll, standardize, true);
    }

    /**
     * Builds a {@link DataInfo} over {@code fr} with one response column and the given
     * interactions.
     *
     * @param useAll      passed as the constructor's 4th argument (presumably
     *                    useAllFactorLevels; confirm)
     * @param standardize if true the predictor transform is {@code STANDARDIZE}, otherwise
     *                    {@code NONE}; the response transform is always {@code NONE}
     * @param skipMissing passed as the constructor's 7th argument
     */
    private static DataInfo makeInfo(Frame fr, String[] interactions, boolean useAll,
                                     boolean standardize, boolean skipMissing) {
        DataInfo.TransformType predictorTransform =
                standardize ? DataInfo.TransformType.STANDARDIZE : DataInfo.TransformType.NONE;
        return new DataInfo(fr, (Frame) null, 1, useAll, predictorTransform,
                DataInfo.TransformType.NONE, skipMissing, false, false, false, false, false,
                interactions);
    }

    /**
     * Asserts that the first {@code checkMe.numCols()} columns of {@code checkMe} match the same
     * columns of {@code gold} to within {@link #TOLERANCE}.
     *
     * <p>Both frames' vecs are placed side by side in one {@code Vec[]} so a single
     * {@code MRTask} sees column {@code c} of {@code checkMe} at {@code cs[c]} and of {@code gold}
     * at {@code cs[c + off]}. Previously the array was left full of nulls.
     *
     * @throws RuntimeException ("bonk") on the first mismatch
     */
    private static void checkFrame(final Frame checkMe, Frame gold) {
        final int off = checkMe.numCols();
        Vec[] vecs = new Vec[off + gold.numCols()];
        System.arraycopy(checkMe.vecs(), 0, vecs, 0, off);
        System.arraycopy(gold.vecs(), 0, vecs, off, gold.numCols());
        new MRTask() {
            @Override
            public void map(Chunk[] cs) {
                for (int c = 0; c < off; c++) {
                    for (int r = 0; r < cs[0]._len; r++) {
                        if (Math.abs(cs[c].atd(r) - cs[c + off].atd(r)) > TOLERANCE) {
                            throw new RuntimeException("bonk");
                        }
                    }
                }
            }
        }.doAll(vecs);
    }

    /** Same as the three-argument overload with {@code skipMissing = true}. */
    private static void checkFrame(DataInfo di, Frame gold) {
        checkFrame(di, gold, true);
    }

    /**
     * Expands every row of {@code di._adaptedFrame} with {@code di.extractDenseRow} and compares
     * the first {@code di.fullN()} values against the same row of {@code gold}.
     *
     * <ul>
     *   <li>If {@code skipMissing} is true, rows flagged {@code isBad()} are skipped and any
     *       difference above {@link #TOLERANCE} fails.</li>
     *   <li>If {@code skipMissing} is false, every row is checked. Differences below
     *       {@link #NON_SKIPPING_FAIL_THRESHOLD} are only printed. Larger differences fail.</li>
     * </ul>
     *
     * <p>Always calls {@code dropInteractions()} and {@code remove()} on {@code di} afterwards,
     * whether or not the check passes.
     *
     * @throws RuntimeException ("bonk") on a failing mismatch
     */
    private static void checkFrame(final DataInfo di, Frame gold, final boolean skipMissing) {
        try {
            final int off = di._adaptedFrame.numCols();
            Vec[] vecs = new Vec[off + gold.numCols()];
            System.arraycopy(di._adaptedFrame.vecs(), 0, vecs, 0, off);
            System.arraycopy(gold.vecs(), 0, vecs, off, gold.numCols());
            new MRTask() {
                @Override
                public void map(Chunk[] cs) {
                    int fullN = di.fullN();
                    DataInfo.Row row = di.newDenseRow();
                    for (int r = 0; r < cs[0]._len; r++) {
                        di.extractDenseRow(cs, r, row);
                        if (skipMissing && row.isBad()) {
                            continue;
                        }
                        for (int c = 0; c < fullN; c++) {
                            double diff = Math.abs(cs[off + c].atd(r) - row.get(c));
                            // A NaN diff compares false here and is therefore not flagged.
                            if (diff > TOLERANCE) {
                                if (skipMissing || diff >= NON_SKIPPING_FAIL_THRESHOLD) {
                                    throw new RuntimeException("bonk");
                                }
                                System.out.println("row mismatch: " + r + " column= " + c + "; diff= " + diff + " but not skipping missing, so due to discrepancies in taking mean on split frames");
                            }
                        }
                    }
                }
            }.doAll(vecs);
        } finally {
            di.dropInteractions();
            di.remove();
        }
    }
}
