package water.rapids.ast.prims.mungers;

import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.mockito.Mockito;
import water.Key;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.TestFrameBuilder;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.Rapids;
import water.rapids.ast.AstRoot;
import water.rapids.ast.params.AstNumList;
import water.rapids.vals.ValFrame;

/**
 * Tests for the {@code scale} and {@code scale_inplace} Rapids primitives backed by {@link AstScale}.
 *
 * <p>Every test runs once per primitive name returned by {@link #scaleFunc()}.
 */
@RunWith(Parameterized.class)
public class AstScaleTest extends TestUtil {

    /** Error expected from {@link AstScale} when the per-column values do not match the frame width. */
    private static final String WRONG_LENGTH_MSG =
        "Values must be the same length as is the number of columns of the Frame to scale (fill 0 for non-numeric columns).";

    @Rule
    public transient ExpectedException ee = ExpectedException.none();

    /** Name of the Rapids primitive under test, injected by the Parameterized runner. */
    @Parameterized.Parameter
    public String scaleFunc;

    @BeforeClass
    public static void setup() {
        stall_till_cloudsize(1);
    }

    /**
     * Supplies the Rapids primitive names to test.
     *
     * @return the primitive names; {@code {0}} in the name pattern makes each run identifiable in reports
     */
    @Parameterized.Parameters(name = "{0}")
    public static Object[] scaleFunc() {
        return new Object[]{"scale", "scale_inplace"};
    }

    @Test
    public void testScaleNumeric() {
        Scope.enter();
        try {
            Frame input = new TestFrameBuilder()
                .withColNames("C1", "C2")
                .withVecTypes(Vec.T_NUM, Vec.T_NUM)
                .withDataForCol(0, ard(0.0, 2.0, 4.0, 6.0))
                .withDataForCol(1, ard(0.0, 0.0, 0.0, 1.0))
                .build();
            // Each column is (x - mean) / sample standard deviation,
            // e.g. C2: mean 0.25, sd 0.5 -> (0 - 0.25) / 0.5 = -0.5.
            Frame expected = new TestFrameBuilder()
                .withColNames("C1", "C2")
                .withVecTypes(Vec.T_NUM, Vec.T_NUM)
                .withDataForCol(0, ard(-1.161895003862225, -0.3872983346207417, 0.3872983346207417, 1.161895003862225))
                .withDataForCol(1, ard(-0.5, -0.5, -0.5, 1.5))
                .build();

            ValFrame result = runScale(input);

            compareFrames(expected, result.getFrame(), 1e-10);
            checkInPlace(input, result);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testScaleNumericWithCategoricals() {
        Scope.enter();
        try {
            Frame input = new TestFrameBuilder()
                .withColNames("C1", "C2")
                .withVecTypes(Vec.T_CAT, Vec.T_NUM)
                .withDataForCol(0, ar("a", "b", "c", "d"))
                .withDataForCol(1, ard(0.0, 0.0, 0.0, 1.0))
                .build();
            // The categorical column is expected to pass through untouched.
            Frame expected = new TestFrameBuilder()
                .withColNames("C1", "C2")
                .withVecTypes(Vec.T_CAT, Vec.T_NUM)
                .withDataForCol(0, ar("a", "b", "c", "d"))
                .withDataForCol(1, ard(-0.5, -0.5, -0.5, 1.5))
                .build();

            ValFrame result = runScale(input);

            compareFrames(expected, result.getFrame(), 1e-10);
            checkInPlace(input, result);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testScaleNoNumeric() {
        Scope.enter();
        try {
            Frame input = new TestFrameBuilder()
                .withColNames("C1", "C2")
                .withVecTypes(Vec.T_CAT, Vec.T_CAT)
                .withDataForCol(0, ar("a", "b", "c", "d"))
                .withDataForCol(1, ar("a", "b", "c", "d"))
                .build();
            Frame expected = new TestFrameBuilder()
                .withColNames("C1", "C2")
                .withVecTypes(Vec.T_CAT, Vec.T_CAT)
                .withDataForCol(0, ar("a", "b", "c", "d"))
                .withDataForCol(1, ar("a", "b", "c", "d"))
                .build();

            ValFrame result = runScale(input);

            compareFrames(expected, result.getFrame());
            checkInPlace(input, result);
        } finally {
            Scope.exit();
        }
    }

    @Test
    public void testCalcMeans_invalidCols() {
        Frame origFrame = Mockito.mock(Frame.class);
        Frame frame = Mockito.mock(Frame.class);
        AstNumList means = Mockito.mock(AstNumList.class);
        // 43 columns vs. 4 supplied values -> length mismatch.
        Mockito.when(frame.numCols()).thenReturn(43);
        Mockito.when(means.expand()).thenReturn(new double[4]);

        ee.expectMessage(WRONG_LENGTH_MSG);
        AstScale.calcMeans((Env) null, means, frame, origFrame);
    }

    @Test
    public void testCalcMults_invalidCols() {
        Frame origFrame = Mockito.mock(Frame.class);
        Frame frame = Mockito.mock(Frame.class);
        AstNumList mults = Mockito.mock(AstNumList.class);
        // 43 columns vs. 4 supplied values -> length mismatch.
        Mockito.when(frame.numCols()).thenReturn(43);
        Mockito.when(mults.expand()).thenReturn(new double[4]);

        ee.expectMessage(WRONG_LENGTH_MSG);
        AstScale.calcMults((Env) null, mults, frame, origFrame);
    }

    /**
     * Executes the primitive under test on {@code input}, passing {@code 1 1} as its two trailing
     * arguments (presumably center = TRUE, scale = TRUE; the expected values above are both centered and scaled).
     */
    private ValFrame runScale(Frame input) {
        return (ValFrame) Rapids.exec("(" + scaleFunc + " " + input._key + " 1 1)");
    }

    /**
     * Verifies how the result relates to the input frame.
     *
     * <p>For {@code scale_inplace}, the input frame must equal the result exactly.
     * For {@code scale}, every numeric input column must keep its original maximum, which differs
     * from the scaled one. Categorical columns must be identical.
     */
    private void checkInPlace(Frame original, ValFrame valFrame) {
        Frame result = valFrame.getFrame();
        if ("scale_inplace".equals(scaleFunc)) {
            assertFrameEquals(original, result, 0.0);
            return;
        }
        for (int i = 0; i < original.numCols(); i++) {
            if (original.vec(i).get_type() == Vec.T_NUM) {
                Assert.assertNotEquals("column " + i + " of the input frame must not be modified",
                    original.vec(i).max(), result.vec(i).max(), 0.0);
            } else {
                assertCatVecEquals(original.vec(i), result.vec(i));
            }
        }
    }
}
