/*
 * Decompiled with CFR 0.152.
 */
package org.apache.storm.hive.common;

import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.hadoop.hive.metastore.txn.TxnDbUtil;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hive.hcatalog.streaming.HiveEndPoint;
import org.apache.hive.hcatalog.streaming.RecordWriter;
import org.apache.hive.hcatalog.streaming.SerializationError;
import org.apache.hive.hcatalog.streaming.StreamingConnection;
import org.apache.hive.hcatalog.streaming.StreamingException;
import org.apache.hive.hcatalog.streaming.TransactionBatch;
import org.apache.storm.Config;
import org.apache.storm.hive.bolt.HiveSetupUtil;
import org.apache.storm.hive.bolt.mapper.DelimitedRecordHiveMapper;
import org.apache.storm.hive.bolt.mapper.HiveMapper;
import org.apache.storm.hive.common.HiveWriter;
import org.apache.storm.task.GeneralTopologyContext;
import org.apache.storm.topology.TopologyBuilder;
import org.apache.storm.tuple.Fields;
import org.apache.storm.tuple.Tuple;
import org.apache.storm.tuple.TupleImpl;
import org.apache.storm.tuple.Values;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.mockito.verification.VerificationMode;

public class TestHiveWriter {
    public static final String PART1_NAME = "city";
    public static final String PART2_NAME = "state";
    public static final String[] partNames = new String[]{"city", "state"};
    static final String dbName = "testdb";
    static final String tblName = "test_table2";
    final String[] partitionVals = new String[]{"sunnyvale", "ca"};
    final String[] colNames = new String[]{"id", "msg"};
    private final int port;
    private final String metaStoreURI;
    private final HiveConf conf;
    int timeout = 10000;
    UserGroupInformation ugi = null;
    private ExecutorService callTimeoutPool;

    public TestHiveWriter() throws Exception {
        this.port = 9083;
        this.metaStoreURI = null;
        int callTimeoutPoolSize = 1;
        this.callTimeoutPool = Executors.newFixedThreadPool(callTimeoutPoolSize, new ThreadFactoryBuilder().setNameFormat("hiveWriterTest").build());
        this.conf = HiveSetupUtil.getHiveConf();
        TxnDbUtil.setConfValues((HiveConf)this.conf);
        if (this.metaStoreURI != null) {
            this.conf.setVar(HiveConf.ConfVars.METASTOREURIS, this.metaStoreURI);
        }
    }

    @Test
    public void testInstantiate() throws Exception {
        DelimitedRecordHiveMapper mapper = new MockedDelemiteredRecordHiveMapper().withColumnFields(new Fields(this.colNames)).withPartitionFields(new Fields(partNames));
        HiveEndPoint endPoint = new HiveEndPoint(this.metaStoreURI, dbName, tblName, Arrays.asList(this.partitionVals));
        TestingHiveWriter writer = new TestingHiveWriter(endPoint, 10, true, this.timeout, this.callTimeoutPool, (HiveMapper)mapper, this.ugi, false);
        writer.close();
    }

    @Test
    public void testWriteBasic() throws Exception {
        DelimitedRecordHiveMapper mapper = new MockedDelemiteredRecordHiveMapper().withColumnFields(new Fields(this.colNames)).withPartitionFields(new Fields(partNames));
        HiveEndPoint endPoint = new HiveEndPoint(this.metaStoreURI, dbName, tblName, Arrays.asList(this.partitionVals));
        TestingHiveWriter writer = new TestingHiveWriter(endPoint, 10, true, this.timeout, this.callTimeoutPool, (HiveMapper)mapper, this.ugi, false);
        this.writeTuples(writer, (HiveMapper)mapper, 3);
        writer.flush(false);
        writer.close();
        ((TransactionBatch)Mockito.verify((Object)writer.getMockedTxBatch(), (VerificationMode)Mockito.times((int)3))).write((byte[])Mockito.any(byte[].class));
    }

    @Test
    public void testWriteMultiFlush() throws Exception {
        DelimitedRecordHiveMapper mapper = new MockedDelemiteredRecordHiveMapper().withColumnFields(new Fields(this.colNames)).withPartitionFields(new Fields(partNames));
        HiveEndPoint endPoint = new HiveEndPoint(this.metaStoreURI, dbName, tblName, Arrays.asList(this.partitionVals));
        TestingHiveWriter writer = new TestingHiveWriter(endPoint, 10, true, this.timeout, this.callTimeoutPool, (HiveMapper)mapper, this.ugi, false);
        Tuple tuple = this.generateTestTuple("1", "abc");
        writer.write(mapper.mapRecord(tuple));
        tuple = this.generateTestTuple("2", "def");
        writer.write(mapper.mapRecord(tuple));
        Assertions.assertEquals((int)writer.getTotalRecords(), (int)2);
        ((TransactionBatch)Mockito.verify((Object)writer.getMockedTxBatch(), (VerificationMode)Mockito.times((int)2))).write((byte[])Mockito.any(byte[].class));
        ((TransactionBatch)Mockito.verify((Object)writer.getMockedTxBatch(), (VerificationMode)Mockito.never())).commit();
        writer.flush(true);
        Assertions.assertEquals((int)writer.getTotalRecords(), (int)0);
        ((TransactionBatch)Mockito.verify((Object)writer.getMockedTxBatch(), (VerificationMode)Mockito.atLeastOnce())).commit();
        tuple = this.generateTestTuple("3", "ghi");
        writer.write(mapper.mapRecord(tuple));
        writer.flush(true);
        tuple = this.generateTestTuple("4", "klm");
        writer.write(mapper.mapRecord(tuple));
        writer.flush(true);
        writer.close();
        ((TransactionBatch)Mockito.verify((Object)writer.getMockedTxBatch(), (VerificationMode)Mockito.times((int)4))).write((byte[])Mockito.any(byte[].class));
    }

    private Tuple generateTestTuple(Object id, Object msg) {
        TopologyBuilder builder = new TopologyBuilder();
        GeneralTopologyContext topologyContext = new GeneralTopologyContext(builder.createTopology(), (Map)new Config(), new HashMap(), new HashMap(), new HashMap(), ""){

            public Fields getComponentOutputFields(String componentId, String streamId) {
                return new Fields(new String[]{"id", "msg"});
            }
        };
        return new TupleImpl(topologyContext, (List)new Values(new Object[]{id, msg}), "", 1, "");
    }

    private void writeTuples(HiveWriter writer, HiveMapper mapper, int count) throws HiveWriter.WriteFailure, InterruptedException, SerializationError {
        Integer id = 100;
        String msg = "test-123";
        for (int i = 1; i <= count; ++i) {
            Tuple tuple = this.generateTestTuple(id, msg);
            writer.write(mapper.mapRecord(tuple));
        }
    }

    private static class MockedDelemiteredRecordHiveMapper
    extends DelimitedRecordHiveMapper {
        private final RecordWriter mockedRecordWriter = (RecordWriter)Mockito.mock(RecordWriter.class);

        public RecordWriter createRecordWriter(HiveEndPoint endPoint) throws StreamingException, IOException, ClassNotFoundException {
            return this.mockedRecordWriter;
        }

        public RecordWriter getMockedRecordWriter() {
            return this.mockedRecordWriter;
        }
    }

    private static class TestingHiveWriter
    extends HiveWriter {
        private StreamingConnection mockedStreamingConn;
        private TransactionBatch mockedTxBatch;

        public TestingHiveWriter(HiveEndPoint endPoint, int txnsPerBatch, boolean autoCreatePartitions, long callTimeout, ExecutorService callTimeoutPool, HiveMapper mapper, UserGroupInformation ugi, boolean tokenAuthEnabled) throws InterruptedException, HiveWriter.ConnectFailure {
            super(endPoint, txnsPerBatch, autoCreatePartitions, callTimeout, callTimeoutPool, mapper, ugi, tokenAuthEnabled);
        }

        synchronized StreamingConnection newConnection(UserGroupInformation ugi, boolean tokenAuthEnabled) throws InterruptedException, HiveWriter.ConnectFailure {
            if (this.mockedStreamingConn == null) {
                this.mockedStreamingConn = (StreamingConnection)Mockito.mock(StreamingConnection.class);
                this.mockedTxBatch = (TransactionBatch)Mockito.mock(TransactionBatch.class);
                try {
                    Mockito.when((Object)this.mockedStreamingConn.fetchTransactionBatch(Mockito.anyInt(), (RecordWriter)Mockito.any(RecordWriter.class))).thenReturn((Object)this.mockedTxBatch);
                }
                catch (StreamingException e) {
                    throw new RuntimeException(e);
                }
            }
            return this.mockedStreamingConn;
        }

        public TransactionBatch getMockedTxBatch() {
            return this.mockedTxBatch;
        }
    }
}

