package org.apache.parquet.hadoop;

import java.io.File;
import java.lang.management.ManagementFactory;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.mapreduce.RecordWriter;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.parquet.hadoop.example.GroupWriteSupport;
import org.apache.parquet.hadoop.metadata.CompressionCodecName;
import org.apache.parquet.schema.MessageTypeParser;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/* loaded from: input_file:org/apache/parquet/hadoop/TestMemoryManager.class */
/**
 * Tests for the memory manager returned by {@link ParquetOutputFormat#getMemoryManager()},
 * which divides a heap-derived memory pool between the record writers that are currently open.
 *
 * <p>The manager is obtained through a static accessor, so its state is shared across tests.
 * Every writer a test opens is therefore tracked and, if the test fails part-way, closed in a
 * {@code finally} block so its allocation does not leak into later tests.
 */
public class TestMemoryManager {
    /** Configuration key for the target row-group (block) size. */
    private static final String BLOCK_SIZE_KEY = "parquet.block.size";

    /**
     * Fraction of the maximum heap that the pool is expected to cover. Declared as the float
     * literal {@code 0.95f} widened to double, which is exactly the value the original code used
     * (0.949999988079071). NOTE(review): presumably mirrors the memory manager's default pool
     * ratio; confirm and reference that constant directly if it is accessible.
     */
    private static final double MEMORY_POOL_RATIO = 0.95f;

    long expectedPoolSize;
    ParquetOutputFormat<?> parquetOutputFormat;
    Configuration conf = new Configuration();
    String writeSchema = "message example {\nrequired int32 line;\nrequired binary content;\n}";
    int counter = 0;

    /** Writers opened by the current test that have not been closed yet. */
    private final List<RecordWriter<?, ?>> openWriters = new ArrayList<>();

    @Rule
    public TemporaryFolder temp = new TemporaryFolder();

    @Before
    public void setUp() throws Exception {
        parquetOutputFormat = new ParquetOutputFormat<>(new GroupWriteSupport());
        GroupWriteSupport.setSchema(MessageTypeParser.parseMessageType(writeSchema), conf);
        long maxHeap = ManagementFactory.getMemoryMXBean().getHeapMemoryUsage().getMax();
        expectedPoolSize = Math.round(maxHeap * MEMORY_POOL_RATIO);
        conf.setLong(BLOCK_SIZE_KEY, expectedPoolSize / 2);
        // Open and immediately close one writer. Presumably this ensures the shared memory
        // manager exists before the tests query it (TODO confirm against ParquetOutputFormat).
        closeWriter(createWriter(0));
    }

    @Test
    public void testMemoryManagerUpperLimit() {
        long poolSize = ParquetOutputFormat.getMemoryManager().getTotalMemoryPool();
        long tolerance = (long) (expectedPoolSize * 0.1d);
        Assert.assertTrue("Pool size should be within 10% of the expected value (expected = "
                + expectedPoolSize + " actual = " + poolSize + ")",
            Math.abs(expectedPoolSize - poolSize) < tolerance);
    }

    @Test
    public void testMemoryManager() throws Exception {
        long poolSize = ParquetOutputFormat.getMemoryManager().getTotalMemoryPool();
        long rowGroupSize = poolSize / 2;
        conf.setLong(BLOCK_SIZE_KEY, rowGroupSize);

        // Validate the assumptions the rest of the test is built on.
        Assert.assertTrue("Pool should hold 2 full row groups", 2 * rowGroupSize <= poolSize);
        Assert.assertTrue("Pool should not hold 3 full row groups", poolSize < 3 * rowGroupSize);
        Assert.assertEquals("Allocations should start out at 0", 0L, getTotalAllocation());

        try {
            RecordWriter<?, ?> writer1 = createWriter(1);
            Assert.assertTrue("Allocations should never exceed pool size", getTotalAllocation() <= poolSize);
            Assert.assertEquals("First writer should be limited by row group size",
                rowGroupSize, getTotalAllocation());

            RecordWriter<?, ?> writer2 = createWriter(2);
            Assert.assertTrue("Allocations should never exceed pool size", getTotalAllocation() <= poolSize);
            Assert.assertEquals("Second writer should be limited by row group size",
                2 * rowGroupSize, getTotalAllocation());

            // Three full row groups do not fit, yet the total must still stay within the pool.
            RecordWriter<?, ?> writer3 = createWriter(3);
            Assert.assertTrue("Allocations should never exceed pool size", getTotalAllocation() <= poolSize);

            closeWriter(writer1);
            Assert.assertTrue("Allocations should never exceed pool size", getTotalAllocation() <= poolSize);
            Assert.assertEquals("Allocations should be increased to the row group size",
                2 * rowGroupSize, getTotalAllocation());

            closeWriter(writer2);
            Assert.assertTrue("Allocations should never exceed pool size", getTotalAllocation() <= poolSize);
            Assert.assertEquals("Allocations should be increased to the row group size",
                rowGroupSize, getTotalAllocation());

            closeWriter(writer3);
            Assert.assertEquals("All allocations should be released once every writer is closed",
                0L, getTotalAllocation());
        } finally {
            closeOpenWriters();
        }
    }

    @Test
    public void testReallocationCallback() throws Exception {
        long poolSize = ParquetOutputFormat.getMemoryManager().getTotalMemoryPool();
        long rowGroupSize = poolSize / 2;
        conf.setLong(BLOCK_SIZE_KEY, rowGroupSize);

        // Validate the assumptions the rest of the test is built on.
        Assert.assertTrue("Pool should hold 2 full row groups", 2 * rowGroupSize <= poolSize);
        Assert.assertTrue("Pool should not hold 3 full row groups", poolSize < 3 * rowGroupSize);

        Runnable incrementCounter = () -> {
            counter++;
        };
        // NOTE(review): nothing in this test unregisters the callback, so it stays on the
        // shared manager after the test finishes.
        ParquetOutputFormat.getMemoryManager().registerScaleCallBack("increment-test-counter", incrementCounter);
        try {
            ParquetOutputFormat.getMemoryManager().registerScaleCallBack("increment-test-counter", incrementCounter);
            Assert.fail("Duplicated registering callback should throw duplicates exception.");
        } catch (IllegalArgumentException expected) {
            // Expected: a second registration under the same name must be rejected.
        }

        try {
            // Opening a third writer exceeds the pool once, which should fire the callback once.
            RecordWriter<?, ?> writer1 = createWriter(1);
            RecordWriter<?, ?> writer2 = createWriter(2);
            RecordWriter<?, ?> writer3 = createWriter(3);
            closeWriter(writer1);
            closeWriter(writer2);
            closeWriter(writer3);

            Assert.assertEquals("Allocations should be adjusted once", 1L, counter);
            Assert.assertEquals("Should not allow duplicate callbacks",
                1L, ParquetOutputFormat.getMemoryManager().getScaleCallBacks().size());
        } finally {
            closeOpenWriters();
        }
    }

    /**
     * Creates a writer for a fresh file named {@code <index>.parquet} in the temporary folder.
     * The writer is tracked so it can be closed if the test fails.
     *
     * @param index number used to name the output file
     * @return the new writer
     * @throws IllegalStateException if the placeholder file created by the temporary folder
     *     cannot be deleted
     */
    private RecordWriter<?, ?> createWriter(int index) throws Exception {
        File file = temp.newFile(index + ".parquet");
        // newFile() creates the file. Delete it so that only a unique, unused path remains.
        // Presumably the writer must create the file itself; confirm.
        if (!file.delete()) {
            throw new IllegalStateException("Could not delete file: " + file);
        }
        RecordWriter<?, ?> writer = parquetOutputFormat.getRecordWriter(
            conf, new Path(file.toString()), CompressionCodecName.UNCOMPRESSED);
        openWriters.add(writer);
        return writer;
    }

    /** Stops tracking {@code writer}, then closes it. */
    private void closeWriter(RecordWriter<?, ?> writer) throws Exception {
        // Untrack first so a failing close() is not retried by closeOpenWriters().
        openWriters.remove(writer);
        writer.close(null);
    }

    /**
     * Closes every writer the current test left open so its allocation is released from the
     * shared memory manager. Every writer is attempted. The first failure is rethrown, with any
     * later failures attached as suppressed exceptions.
     */
    private void closeOpenWriters() throws Exception {
        Exception failure = null;
        while (!openWriters.isEmpty()) {
            RecordWriter<?, ?> writer = openWriters.remove(openWriters.size() - 1);
            try {
                writer.close(null);
            } catch (Exception e) {
                if (failure == null) {
                    failure = e;
                } else {
                    failure.addSuppressed(e);
                }
            }
        }
        if (failure != null) {
            throw failure;
        }
    }

    /** Sums the row-group size threshold of every writer currently known to the memory manager. */
    private long getTotalAllocation() {
        long total = 0;
        for (InternalParquetRecordWriter<?> writer
                : ParquetOutputFormat.getMemoryManager().getWriterList().keySet()) {
            total += writer.getRowGroupSizeThreshold();
        }
        return total;
    }
}
