package org.apache.samza.test.framework;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Iterables;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import org.apache.samza.context.Context;
import org.apache.samza.operators.MessageStream;
import org.apache.samza.operators.functions.SinkFunction;
import org.apache.samza.serializers.KVSerde;
import org.apache.samza.serializers.Serde;
import org.apache.samza.serializers.StringSerde;
import org.apache.samza.system.SystemStreamPartition;
import org.apache.samza.system.mock.MockSystemFactory;
import org.apache.samza.task.MessageCollector;
import org.apache.samza.task.TaskCoordinator;
import org.hamcrest.Matchers;
import org.junit.Assert;

/**
 * Test-only assertion helper for the contents of a {@link MessageStream}.
 *
 * <p>Usage: build an assertion with {@link #that(String, MessageStream, Serde)}, optionally call
 * {@link #forEachTask()}, then call {@link #containsInAnyOrder(Collection)} to attach a checking
 * sink to the stream. After the job has been started, {@link #waitForComplete()} blocks until
 * every registered assertion has been evaluated.
 *
 * <p>Coordination between the sinks and {@link #waitForComplete()} happens through a static map
 * of latches keyed by assertion id, so it only works when both run in the same JVM.
 *
 * <p>Implementation note: no {@code $deserializeLambda$} method is declared by hand; that method
 * is synthesized by the compiler for serializable lambdas such as the ones used in
 * {@link #containsInAnyOrder(Collection)}.
 *
 * @param <M> type of the messages in the stream under test
 */
@VisibleForTesting
public class MessageStreamAssert<M> {
    /** Latches of the pending assertions, keyed by assertion id. */
    private static final Map<String, CountDownLatch> LATCHES = new ConcurrentHashMap<>();

    /**
     * Marker stored in {@link #LATCHES} between {@link #containsInAnyOrder(Collection)} and the
     * moment a sink's {@code init} installs the real latch. Compared by identity.
     */
    private static final CountDownLatch PLACE_HOLDER = new CountDownLatch(0);

    private final String id;
    private final MessageStream<M> messageStream;
    private final Serde<M> serde;
    private boolean checkEachTask = false;

    /**
     * Sink that collects every message it receives and compares the collected messages against
     * the expected ones, ignoring order. The comparison runs as soon as at least as many messages
     * as expected have arrived, or when the timeout scheduled in {@link #init(Context)} fires.
     *
     * @param <M> type of the messages being collected
     */
    private static final class CheckAgainstExpected<M> implements SinkFunction<M> {
        /** Delay, in milliseconds, before the check is forced by the timer. */
        private static final long TIMEOUT = 5000;

        private final String id;
        private final boolean checkEachTask;

        // Deliberately NOT transient: readObject() has no way to rebuild it, so a transient field
        // would be null in every deserialized copy and apply() would fail on expected.size().
        // NOTE(review): this requires the supplied collection to be serializable whenever the
        // sink itself is serialized - confirm against callers.
        private final Collection<M> expected;

        // Runtime-only state; rebuilt by readObject() after deserialization.
        private transient Timer timer = new Timer();
        private transient List<M> actual = Collections.synchronizedList(new ArrayList<>());
        private transient TimerTask timerTask = newCheckTask();

        /**
         * @param str        assertion id, used as the key into the shared latch map
         * @param collection messages the stream is expected to contain
         * @param z          whether every task checks independently (affects the latch count)
         */
        CheckAgainstExpected(String str, Collection<M> collection, boolean z) {
            this.id = str;
            this.expected = collection;
            this.checkEachTask = z;
        }

        /**
         * Installs the real latch for this assertion and arms the timeout, but only from the task
         * whose first system stream partition has partition id 0, so that a single task performs
         * the registration.
         *
         * @param context context giving access to the task model and the container model
         */
        public void init(Context context) {
            SystemStreamPartition firstPartition =
                    Iterables.getFirst(context.getTaskContext().getTaskModel().getSystemStreamPartitions(), null);
            // Both conditions must hold: a task without partitions must be skipped (not
            // dereferenced), and only partition 0 registers the latch.
            if (firstPartition != null && firstPartition.getPartition().getPartitionId() == 0) {
                int latchCount = this.checkEachTask
                        ? context.getContainerContext().getContainerModel().getTasks().size()
                        : 1;
                MessageStreamAssert.LATCHES.put(this.id, new CountDownLatch(latchCount));
                this.timer.schedule(this.timerTask, TIMEOUT);
            }
        }

        /**
         * Records the message and runs the check once at least as many messages as expected have
         * been collected.
         *
         * @param m                received message
         * @param messageCollector unused
         * @param taskCoordinator  unused
         */
        public void apply(M m, MessageCollector messageCollector, TaskCoordinator taskCoordinator) {
            this.actual.add(m);
            if (this.actual.size() >= this.expected.size()) {
                // The check happens right now, so the pending timeout check is no longer needed.
                this.timerTask.cancel();
                check();
            }
        }

        /** Restores the serialized fields and recreates the transient runtime state. */
        private void readObject(ObjectInputStream objectInputStream) throws IOException, ClassNotFoundException {
            objectInputStream.defaultReadObject();
            this.timer = new Timer();
            this.actual = Collections.synchronizedList(new ArrayList<>());
            this.timerTask = newCheckTask();
        }

        /** Creates the timer task that forces the check when the timeout elapses. */
        private TimerTask newCheckTask() {
            return new TimerTask() {
                @Override
                public void run() {
                    check();
                }
            };
        }

        /**
         * Asserts that the collected messages equal the expected ones in any order, and counts the
         * assertion's latch down whether or not the assertion passes, so that
         * {@link MessageStreamAssert#waitForComplete()} is always released.
         */
        private void check() {
            CountDownLatch latch = MessageStreamAssert.LATCHES.get(this.id);
            try {
                Assert.assertThat(this.actual, Matchers.containsInAnyOrder(this.expected.toArray()));
            } finally {
                // The entry may already have been removed by waitForComplete(); a missing latch
                // must not replace an assertion failure with a NullPointerException.
                if (latch != null) {
                    latch.countDown();
                }
            }
        }
    }

    /**
     * Starts an assertion on the given stream.
     *
     * @param str           id of the assertion; also the key of its latch
     * @param messageStream stream whose content is to be checked
     * @param serde         serde for the messages, used when the stream is repartitioned
     * @param <M>           message type
     * @return a new assertion builder
     */
    public static <M> MessageStreamAssert<M> that(String str, MessageStream<M> messageStream, Serde<M> serde) {
        return new MessageStreamAssert<>(str, messageStream, serde);
    }

    private MessageStreamAssert(String str, MessageStream<M> messageStream, Serde<M> serde) {
        this.id = str;
        this.messageStream = messageStream;
        this.serde = serde;
    }

    /**
     * Makes the assertion apply to the stream as seen by each task, instead of repartitioning all
     * messages to a single key first.
     *
     * @return this assertion, for chaining
     */
    public MessageStreamAssert<M> forEachTask() {
        this.checkEachTask = true;
        return this;
    }

    /**
     * Attaches a sink that verifies the stream contains exactly the given messages, in any order.
     * Unless {@link #forEachTask()} was requested, the stream is first repartitioned with a
     * {@code null} key and mapped back to its values before being checked.
     *
     * @param collection the expected messages
     */
    public void containsInAnyOrder(Collection<M> collection) {
        // Announce the assertion right away so waitForComplete() waits for the sink's init.
        LATCHES.putIfAbsent(this.id, PLACE_HOLDER);
        MessageStream<M> streamToCheck = this.checkEachTask
                ? this.messageStream
                : this.messageStream
                        .partitionBy(m -> null, m -> m, KVSerde.of(new StringSerde(), this.serde), (String) null)
                        .map(kv -> kv.value);
        streamToCheck.sink(new CheckAgainstExpected<>(this.id, collection, this.checkEachTask));
    }

    /**
     * Blocks until every registered assertion has been evaluated. For each assertion this polls
     * every 100 ms while only the placeholder is present, then awaits the real latch and removes
     * the entry.
     *
     * @throws RuntimeException wrapping any exception raised while waiting; if the thread was
     *                          interrupted, its interrupt flag is set again before throwing
     */
    public static void waitForComplete() {
        try {
            while (!LATCHES.isEmpty()) {
                // Iterate over a snapshot because entries are removed while looping.
                for (String assertionId : new HashSet<>(LATCHES.keySet())) {
                    while (LATCHES.get(assertionId) == PLACE_HOLDER) {
                        Thread.sleep(100L);
                    }
                    CountDownLatch latch = LATCHES.get(assertionId);
                    if (latch != null) {
                        latch.await();
                        LATCHES.remove(assertionId);
                    }
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
