package water.rapids.ast.prims.reducers;

import java.util.ArrayList;
import java.util.Iterator;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import water.DKV;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Rapids;
import water.rapids.Val;
import water.rapids.vals.ValFrame;

/**
 * Tests for the Rapids {@code mean} primitive ({@link AstMean}).
 *
 * <p>The expressions exercised here have the form {@code (mean <frame> <na_rm> <axis>)}, where
 * {@code na_rm} is 0 or 1 and {@code axis} 0 averages each column while 1 averages each row.
 */
public class AstMeanTest extends TestUtil {
    // Shared input vectors: created once in setup(), removed in teardown().
    private static Vec vi1;
    private static Vec vd1;
    private static Vec vd2;
    private static Vec vd3;
    private static Vec vs1;
    private static Vec vt1;
    private static Vec vt2;
    private static Vec vc1;
    private static Vec vc2;
    // Every frame passed through register(); deleted in teardown().
    private static ArrayList<Frame> allFrames;

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
        vi1 = ivec(-1, -2, 0, 2, 1);
        vd1 = dvec(1.5, 2.5, 3.5, 4.5, 8.0);
        vd2 = dvec(0.2, 0.4, 0.6, 0.8, 1.0);
        vd3 = dvec(1.0, 2.0, Double.NaN, 3.0, Double.NaN);
        vs1 = svec("a", "b", "c", "d", "e");
        vt1 = tvec(10000000, 10000020, 10000030, 10000040, 10000060);
        vt2 = tvec(20000000, 20000020, 20000030, 20000040, 20000060);
        vc1 = cvec(ar("N", "Y"), "Y", "N", "Y", "Y", "N");
        vc2 = cvec("a", "c", "c", "b", "a");
        allFrames = new ArrayList<>(10);
    }

    @AfterClass
    public static void teardown() {
        // Null guards: if setup() failed part-way, an NPE here would hide the original failure.
        for (Vec vec : aro(vi1, vd1, vd2, vd3, vs1, vt1, vt2, vc1, vc2)) {
            if (vec != null) {
                vec.remove();
            }
        }
        if (allFrames != null) {
            for (Frame frame : allFrames) {
                frame.delete();
            }
        }
    }

    @Test
    public void testAstMeanGeneralStructure() {
        AstMean astMean = new AstMean();
        Assert.assertEquals(3L, astMean.args().length);
        Assert.assertTrue(astMean.example().startsWith("(mean "));
        Assert.assertTrue("Description for AstMean is too short", astMean.description().length() > 100);
    }

    @Test
    public void testColumnwiseMeanWithoutNaRm() {
        Frame input = register(new Frame(Key.make(), ar("I", "D", "DD", "DN", "T", "S", "C"),
                aro(vi1, vd1, vd2, vd3, vt1, vs1, vc2)));
        Val val = Rapids.exec("(mean " + input._key + " 0 0)");
        Assert.assertTrue(val instanceof ValFrame);
        Frame result = register(val.getFrame());
        Assert.assertArrayEquals(input.names(), result.names());
        Assert.assertArrayEquals(
                new byte[]{Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_TIME, Vec.T_NUM, Vec.T_NUM},
                result.types());
        assertRowFrameEquals(ard(0.0, 4.0, 0.6, Double.NaN, 1.000003E7, Double.NaN, Double.NaN), result);
    }

    @Test
    public void testColumnwiseMeanWithNaRm() {
        Frame input = register(new Frame(Key.make(), ar("I", "D", "DD", "DN", "T", "S", "C"),
                aro(vi1, vd1, vd2, vd3, vt1, vs1, vc2)));
        Val val = Rapids.exec("(mean " + input._key + " 1 0)");
        Assert.assertTrue(val instanceof ValFrame);
        Frame result = register(val.getFrame());
        Assert.assertArrayEquals(input.names(), result.names());
        Assert.assertArrayEquals(
                new byte[]{Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_NUM, Vec.T_TIME, Vec.T_NUM, Vec.T_NUM},
                result.types());
        assertRowFrameEquals(ard(0.0, 4.0, 0.6, 2.0, 1.000003E7, Double.NaN, Double.NaN), result);
    }

    @Test
    public void testColumnwiseMeanOnEmptyFrame() {
        Frame empty = register(new Frame(Key.make()));
        Val val = Rapids.exec("(mean " + empty._key + " 0 0)");
        Assert.assertTrue(val instanceof ValFrame);
        Frame result = register(val.getFrame());
        Assert.assertEquals(0, result.numCols());
        Assert.assertEquals(0L, result.numRows());
    }

    @Test
    public void testColumnwiseMeanBinaryVec() {
        Assert.assertTrue(vc1.isBinary() && !vc2.isBinary());
        Frame input = register(new Frame(Key.make(), ar("C1", "C2"), aro(vc1, vc2)));
        Val val = Rapids.exec("(mean " + input._key + " 1 0)");
        Assert.assertTrue(val instanceof ValFrame);
        Frame result = register(val.getFrame());
        Assert.assertArrayEquals(input.names(), result.names());
        Assert.assertArrayEquals(new byte[]{Vec.T_NUM, Vec.T_NUM}, result.types());
        // Only the binary categorical column gets a numeric mean (fraction of the second level).
        assertRowFrameEquals(ard(0.6, Double.NaN), result);
    }

    @Test
    public void testRowwiseMeanWithoutNaRm() {
        Frame input = register(new Frame(Key.make(), ar("i1", "d1", "d2", "d3"), aro(vi1, vd1, vd2, vd3)));
        Val val = Rapids.exec("(mean " + input._key + " 0 1)");
        Assert.assertTrue(val instanceof ValFrame);
        Frame result = register(val.getFrame());
        assertColFrameEquals(ard(0.425, 0.725, Double.NaN, 2.575, Double.NaN), result);
        Assert.assertEquals("mean", result.name(0));
    }

    @Test
    public void testRowwiseMeanWithoutNaRmAndNonnumericColumn() {
        Frame input = register(new Frame(Key.make(), ar("i1", "d1", "d2", "d3", "s1"),
                aro(vi1, vd1, vd2, vd3, vs1)));
        Val val = Rapids.exec("(mean " + input._key + " 0 1)");
        Assert.assertTrue(val instanceof ValFrame);
        Frame result = register(val.getFrame());
        assertColFrameEquals(ard(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN), result);
        Assert.assertEquals("mean", result.name(0));
    }

    @Test
    public void testRowwiseMeanWithNaRm() {
        Frame input = register(new Frame(Key.make(), ar("i1", "d1", "d2", "d3", "s1"),
                aro(vi1, vd1, vd2, vd3, vs1)));
        Val val = Rapids.exec("(mean " + input._key + " 1 1)");
        Assert.assertTrue(val instanceof ValFrame);
        Frame result = register(val.getFrame());
        Assert.assertEquals("Unexpected column name", "mean", result.name(0));
        Assert.assertEquals("Unexpected column type", Vec.T_NUM, result.types()[0]);
        assertColFrameEquals(ard(0.425, 0.725, 1.3666666666666665, 2.575, 3.3333333333333335), result);
    }

    @Test
    public void testRowwiseMeanOnFrameWithTimeColumnsOnly() {
        Frame input = register(new Frame(Key.make(), ar("t1", "s", "t2"), aro(vt1, vs1, vt2)));
        Val val = Rapids.exec("(mean " + input._key + " 1 1)");
        Assert.assertTrue(val instanceof ValFrame);
        Frame result = register(val.getFrame());
        Assert.assertEquals("Unexpected column name", "mean", result.name(0));
        Assert.assertEquals("Unexpected column type", Vec.T_TIME, result.types()[0]);
        assertColFrameEquals(ard(1.5E7, 1.500002E7, 1.500003E7, 1.500004E7, 1.500006E7), result);
    }

    @Test
    public void testRowwiseMeanOnFrameWithTimeAndNumericColumn() {
        Frame input = register(new Frame(Key.make(), ar("t1", "i1"), aro(vt1, vi1)));
        Val val = Rapids.exec("(mean " + input._key + " 1 1)");
        Assert.assertTrue(val instanceof ValFrame);
        // Expected values equal i1: the time column does not contribute when a numeric column exists.
        assertColFrameEquals(ard(-1.0, -2.0, 0.0, 2.0, 1.0), register(val.getFrame()));
    }

    @Test
    public void testRowwiseMeanOnEmptyFrame() {
        Frame empty = register(new Frame(Key.make()));
        Val val = Rapids.exec("(mean " + empty._key + " 0 1)");
        Assert.assertTrue(val instanceof ValFrame);
        Frame result = register(val.getFrame());
        Assert.assertEquals(0, result.numCols());
        Assert.assertEquals(0L, result.numRows());
    }

    @Test
    public void testRowwiseMeanOnFrameWithNonnumericColumnsOnly() {
        Frame input = register(new Frame(Key.make(), ar("c1", "s1"), aro(vc2, vs1)));
        Val val = Rapids.exec("(mean " + input._key + " 1 1)");
        Assert.assertTrue(val instanceof ValFrame);
        Frame result = register(val.getFrame());
        Assert.assertEquals("Unexpected column name", "mean", result.name(0));
        Assert.assertEquals("Unexpected column type", Vec.T_NUM, result.types()[0]);
        assertColFrameEquals(ard(Double.NaN, Double.NaN, Double.NaN, Double.NaN, Double.NaN), result);
    }

    @Test
    public void testBadFirstArgument() {
        assertRapidsRejects("(mean " + vi1._key + " 1 0)");
        assertRapidsRejects("(mean hello 1 0)");
        assertRapidsRejects("(mean 2 1 0)");
    }

    @Test
    public void testValRowArgument() {
        Frame input = register(new Frame(Key.make(), ar("i1", "d1", "d2", "d3"), aro(vi1, vd1, vd2, vd3)));
        Val withNaRm = Rapids.exec("(apply " + input._key + " 1 {x . (mean x 1)})");
        Assert.assertTrue(withNaRm instanceof ValFrame);
        assertColFrameEquals(ard(0.425, 0.725, 1.3666666666666665, 2.575, 3.3333333333333335),
                register(withNaRm.getFrame()));
        Val withoutNaRm = Rapids.exec("(apply " + input._key + " 1 {x . (mean x 0)})");
        Assert.assertTrue(withoutNaRm instanceof ValFrame);
        assertColFrameEquals(ard(0.425, 0.725, Double.NaN, 2.575, Double.NaN), register(withoutNaRm.getFrame()));
    }

    /**
     * Asserts that evaluating the given Rapids expression throws {@link IllegalArgumentException}.
     *
     * @param rapids the Rapids expression to evaluate
     */
    private static void assertRapidsRejects(String rapids) {
        try {
            Rapids.exec(rapids);
        } catch (IllegalArgumentException expected) {
            return; // the expected outcome
        }
        Assert.fail("Expected IllegalArgumentException for: " + rapids);
    }

    /**
     * Asserts that {@code frame} is a single row whose column {@code i} equals {@code expected[i]}
     * within 1e-8 (NaN matches NaN).
     */
    private static void assertRowFrameEquals(double[] expected, Frame frame) {
        Assert.assertEquals(1L, frame.numRows());
        Assert.assertEquals(expected.length, frame.numCols());
        for (int col = 0; col < expected.length; col++) {
            Assert.assertEquals("Wrong average in column " + frame.name(col), expected[col], frame.vec(col).at(0L), 1.0E-8);
        }
    }

    /**
     * Asserts that {@code frame} is a single column whose row {@code i} equals {@code expected[i]}
     * within 1e-8 (NaN matches NaN).
     */
    private static void assertColFrameEquals(double[] expected, Frame frame) {
        Assert.assertEquals(1L, frame.numCols());
        Assert.assertEquals(expected.length, frame.numRows());
        for (int row = 0; row < expected.length; row++) {
            Assert.assertEquals("Wrong average in row " + row, expected[row], frame.vec(0).at(row), 1.0E-8);
        }
    }

    /**
     * Publishes {@code frame} in the DKV (when it has a key) and records it for deletion in
     * {@link #teardown()}.
     *
     * @return the same frame, for call chaining
     */
    private static Frame register(Frame frame) {
        if (frame._key != null) {
            DKV.put(frame._key, frame);
        }
        allFrames.add(frame);
        return frame;
    }
}
