package org.apache.tez.runtime.common.resources;

import org.apache.hadoop.conf.Configuration;
import org.apache.tez.dag.api.InputDescriptor;
import org.apache.tez.dag.api.OutputDescriptor;
import org.apache.tez.dag.api.ProcessorDescriptor;
import org.apache.tez.runtime.api.MemoryUpdateCallback;
import org.apache.tez.runtime.api.TezInputContext;
import org.apache.tez.runtime.api.TezOutputContext;
import org.apache.tez.runtime.api.TezProcessorContext;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Mockito;

/* loaded from: input_file:org/apache/tez/runtime/common/resources/TestMemoryDistributor.class */
/**
 * Tests for {@link MemoryDistributor}: several inputs/outputs/processors request memory,
 * {@code makeInitialAllocations()} is invoked, and the amount delivered to each callback is checked.
 *
 * <p>Every scaling test below expects the requests to be scaled down proportionally so that they sum
 * to a fraction of the configured JVM memory (70% unless a reserve fraction is configured) —
 * NOTE(review): this is inferred from the asserted numbers; confirm against the allocator's defaults.
 */
public class TestMemoryDistributor {
    /** Base configuration used by the tests; {@link #setup()} turns scaling on before each test. */
    protected Configuration conf = new Configuration();

    /**
     * Callback that simply remembers the last amount passed to {@link #memoryAssigned(long)}.
     * The field starts at the sentinel {@code -1000} so that a callback which was never invoked
     * cannot match any of the values asserted in the tests.
     */
    private static class MemoryUpdateCallbackForTest implements MemoryUpdateCallback {
        long assigned = -1000L;

        public void memoryAssigned(long assignedSize) {
            assigned = assignedSize;
        }
    }

    /** Enables memory scaling and selects {@link ScalingAllocator} as the allocator class. */
    @Before
    public void setup() {
        conf.setBoolean("tez.task.scale.memory.enabled", true);
        conf.set("tez.task.scale.memory.allocator.class", ScalingAllocator.class.getName());
    }

    /**
     * Two inputs asking for 10000 each and one output asking for 5000, with JVM memory set to 10000.
     * Expected allocations are 2800 / 2800 / 1400 (the 2:2:1 request ratio applied to 7000).
     */
    @Test(timeout = 5000)
    public void testScalingNoProcessor() {
        // presumably the constructor arguments are (input count, output count, conf) — TODO confirm
        MemoryDistributor distributor = new MemoryDistributor(2, 1, conf);
        distributor.setJvmMemory(10000L);

        MemoryUpdateCallbackForTest firstInput = requestForInput(distributor, 10000L);
        MemoryUpdateCallbackForTest secondInput = requestForInput(distributor, 10000L);
        MemoryUpdateCallbackForTest output = requestForOutput(distributor, 5000L);

        distributor.makeInitialAllocations();

        Assert.assertEquals(2800L, firstInput.assigned);
        Assert.assertEquals(2800L, secondInput.assigned);
        Assert.assertEquals(1400L, output.assigned);
    }

    /**
     * Two inputs asking for 104857600 and 157286400 with JVM memory set to 209715200.
     * Expected allocations are 58720256 and 88080384 (the 2:3 request ratio applied to 146800640).
     */
    @Test(timeout = 5000)
    public void testScalingNoProcessor2() {
        MemoryDistributor distributor = new MemoryDistributor(2, 0, conf);
        distributor.setJvmMemory(209715200L);

        MemoryUpdateCallbackForTest firstInput = requestForInput(distributor, 104857600L);
        MemoryUpdateCallbackForTest secondInput = requestForInput(distributor, 157286400L);

        distributor.makeInitialAllocations();

        Assert.assertEquals(58720256L, firstInput.assigned);
        Assert.assertEquals(88080384L, secondInput.assigned);
    }

    /**
     * Same as {@link #testScalingNoProcessor()} plus a processor asking for 5000.
     * The shares do not divide evenly, so each allocation is only required to fall in a two-value range.
     */
    @Test(timeout = 5000)
    public void testScalingProcessor() {
        MemoryDistributor distributor = new MemoryDistributor(2, 1, conf);
        distributor.setJvmMemory(10000L);

        MemoryUpdateCallbackForTest firstInput = requestForInput(distributor, 10000L);
        MemoryUpdateCallbackForTest secondInput = requestForInput(distributor, 10000L);
        MemoryUpdateCallbackForTest output = requestForOutput(distributor, 5000L);
        MemoryUpdateCallbackForTest processor = requestForProcessor(distributor, 5000L);

        distributor.makeInitialAllocations();

        // Accept either rounding direction of 7000 split 2:2:1:1.
        assertWithin(2333, 2334, firstInput.assigned);
        assertWithin(2333, 2334, secondInput.assigned);
        assertWithin(1166, 1167, output.assigned);
        assertWithin(1166, 1167, processor.assigned);
    }

    /**
     * With scaling switched off on a copy of the base configuration, each callback is expected to
     * receive exactly the amount it requested.
     */
    @Test(timeout = 5000)
    public void testScalingDisabled() {
        // Copy so that disabling scaling does not leak into the shared configuration.
        Configuration unscaledConf = new Configuration(conf);
        unscaledConf.setBoolean("tez.task.scale.memory.enabled", false);

        MemoryDistributor distributor = new MemoryDistributor(2, 0, unscaledConf);
        distributor.setJvmMemory(207093760L);

        MemoryUpdateCallbackForTest firstInput = requestForInput(distributor, 104857600L);
        MemoryUpdateCallbackForTest secondInput = requestForInput(distributor, 144965632L);

        distributor.makeInitialAllocations();

        Assert.assertEquals(104857600L, firstInput.assigned);
        Assert.assertEquals(144965632L, secondInput.assigned);
    }

    /**
     * Same requests as {@link #testScalingNoProcessor()} but with the reserve fraction set to 0.5.
     * Expected allocations are 2000 / 2000 / 1000 (the 2:2:1 request ratio applied to 5000).
     */
    @Test(timeout = 5000)
    public void testReserveFractionConfigured() {
        Configuration halfReservedConf = new Configuration(conf);
        halfReservedConf.setDouble("tez.task.scale.memory.reserve-fraction", 0.5d);

        MemoryDistributor distributor = new MemoryDistributor(2, 1, halfReservedConf);
        distributor.setJvmMemory(10000L);

        MemoryUpdateCallbackForTest firstInput = requestForInput(distributor, 10000L);
        MemoryUpdateCallbackForTest secondInput = requestForInput(distributor, 10000L);
        MemoryUpdateCallbackForTest output = requestForOutput(distributor, 5000L);

        distributor.makeInitialAllocations();

        Assert.assertEquals(2000L, firstInput.assigned);
        Assert.assertEquals(2000L, secondInput.assigned);
        Assert.assertEquals(1000L, output.assigned);
    }

    /** Registers a request on behalf of a freshly mocked input and returns the recording callback. */
    private MemoryUpdateCallbackForTest requestForInput(MemoryDistributor distributor, long requested) {
        MemoryUpdateCallbackForTest callback = new MemoryUpdateCallbackForTest();
        distributor.requestMemory(requested, callback, createTestInputContext(), createTestInputDescriptor());
        return callback;
    }

    /** Registers a request on behalf of a freshly mocked output and returns the recording callback. */
    private MemoryUpdateCallbackForTest requestForOutput(MemoryDistributor distributor, long requested) {
        MemoryUpdateCallbackForTest callback = new MemoryUpdateCallbackForTest();
        distributor.requestMemory(requested, callback, createTestOutputContext(), createTestOutputDescriptor());
        return callback;
    }

    /** Registers a request on behalf of a freshly mocked processor and returns the recording callback. */
    private MemoryUpdateCallbackForTest requestForProcessor(MemoryDistributor distributor, long requested) {
        MemoryUpdateCallbackForTest callback = new MemoryUpdateCallbackForTest();
        distributor.requestMemory(requested, callback, createTestProcessortContext(), createTestProcessorDescriptor());
        return callback;
    }

    /** Fails unless {@code lowest <= actual <= highest}. */
    private static void assertWithin(long lowest, long highest, long actual) {
        Assert.assertTrue(actual >= lowest && actual <= highest);
    }

    /** @return a mocked {@link InputDescriptor} whose class name is {@code "InputClass"} */
    protected InputDescriptor createTestInputDescriptor() {
        InputDescriptor descriptor = Mockito.mock(InputDescriptor.class);
        Mockito.doReturn("InputClass").when(descriptor).getClassName();
        return descriptor;
    }

    /** @return a mocked {@link OutputDescriptor} whose class name is {@code "OutputClass"} */
    protected OutputDescriptor createTestOutputDescriptor() {
        OutputDescriptor descriptor = Mockito.mock(OutputDescriptor.class);
        Mockito.doReturn("OutputClass").when(descriptor).getClassName();
        return descriptor;
    }

    /** @return a mocked {@link ProcessorDescriptor} whose class name is {@code "ProcessorClass"} */
    protected ProcessorDescriptor createTestProcessorDescriptor() {
        ProcessorDescriptor descriptor = Mockito.mock(ProcessorDescriptor.class);
        Mockito.doReturn("ProcessorClass").when(descriptor).getClassName();
        return descriptor;
    }

    /** @return a mocked {@link TezInputContext} with source vertex {@code "input"} and task vertex {@code "task"} */
    protected TezInputContext createTestInputContext() {
        TezInputContext context = Mockito.mock(TezInputContext.class);
        Mockito.doReturn("input").when(context).getSourceVertexName();
        Mockito.doReturn("task").when(context).getTaskVertexName();
        return context;
    }

    /** @return a mocked {@link TezOutputContext} with destination vertex {@code "output"} and task vertex {@code "task"} */
    protected TezOutputContext createTestOutputContext() {
        TezOutputContext context = Mockito.mock(TezOutputContext.class);
        Mockito.doReturn("output").when(context).getDestinationVertexName();
        Mockito.doReturn("task").when(context).getTaskVertexName();
        return context;
    }

    /**
     * @return a mocked {@link TezProcessorContext} with task vertex {@code "task"}
     *         (method name keeps its original spelling because it is part of the protected interface)
     */
    protected TezProcessorContext createTestProcessortContext() {
        TezProcessorContext context = Mockito.mock(TezProcessorContext.class);
        Mockito.doReturn("task").when(context).getTaskVertexName();
        return context;
    }
}
