package org.apache.beam.runners.spark;

import org.apache.beam.runners.spark.SparkRunner;
import org.apache.beam.runners.spark.translation.EvaluationContext;
import org.apache.beam.runners.spark.translation.SparkContextFactory;
import org.apache.beam.runners.spark.translation.TransformTranslator;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.values.PCollection;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

/* loaded from: input_file:org/apache/beam/runners/spark/CacheTest.class */
/**
 * Tests for the Spark runner's caching decisions.
 *
 * <p>Covers two things: how {@link SparkRunner.CacheVisitor} records {@link PCollection}s in
 * {@link EvaluationContext#getCacheCandidates()}, and what {@link EvaluationContext#shouldCache}
 * returns for different transforms and option settings.
 */
public class CacheTest {

    /**
     * A {@link PCollection} consumed by two {@link Count#globally()} transforms should be recorded
     * with a count of 2 in the context's cache candidates after the {@link SparkRunner.CacheVisitor}
     * traverses the pipeline.
     */
    @Test
    public void cacheCandidatesUpdaterTest() {
        SparkPipelineOptions options = createOptions();
        Pipeline pipeline = Pipeline.create(options);
        PCollection<String> pCollection = pipeline.apply(Create.of("foo", "bar"));
        // Two independent consumers of the same PCollection.
        pCollection.apply(Count.globally());
        pCollection.apply(Count.globally());

        EvaluationContext ctxt =
                new EvaluationContext(SparkContextFactory.getSparkContext(options), pipeline, options);
        pipeline.traverseTopologically(
                new SparkRunner.CacheVisitor(new TransformTranslator.Translator(), ctxt));

        // Compare boxed values so that a missing entry is reported as an assertion failure
        // (expected 2, was null) instead of a NullPointerException while unboxing.
        Assert.assertEquals(Long.valueOf(2L), ctxt.getCacheCandidates().get(pCollection));
    }

    /**
     * Checks three cases for {@link EvaluationContext#shouldCache}:
     *
     * <ul>
     *   <li>It returns false while caching is disabled in the options.
     *   <li>It returns true for a {@link Create.Values} transform whose output has a candidate count
     *       of 2, once caching is enabled.
     *   <li>It returns false for a {@link GroupByKey} transform.
     * </ul>
     */
    @Test
    public void shouldCacheTest() {
        SparkPipelineOptions options = createOptions();
        options.setCacheDisabled(true);
        Pipeline pipeline = Pipeline.create(options);

        Create.Values<String> valuesTransform = Create.of("foo", "bar");
        PCollection<?> pCollection = Mockito.mock(PCollection.class);

        EvaluationContext ctxt =
                new EvaluationContext(SparkContextFactory.getSparkContext(options), pipeline, options);
        // Pretend the (mocked) PCollection was seen twice by the cache visitor.
        ctxt.getCacheCandidates().put(pCollection, 2L);

        Assert.assertFalse(ctxt.shouldCache(valuesTransform, pCollection));

        // The assertions below rely on the context observing later changes to this same
        // options instance.
        options.setCacheDisabled(false);
        Assert.assertTrue(ctxt.shouldCache(valuesTransform, pCollection));

        GroupByKey<String, String> gbkTransform = GroupByKey.create();
        Assert.assertFalse(ctxt.shouldCache(gbkTransform, pCollection));
    }

    /**
     * Creates fresh test options configured to run on {@link TestSparkRunner}.
     *
     * @return a new {@link TestSparkPipelineOptions} instance viewed as {@link SparkPipelineOptions}
     */
    private SparkPipelineOptions createOptions() {
        SparkPipelineOptions options =
                PipelineOptionsFactory.create().as(TestSparkPipelineOptions.class);
        options.setRunner(TestSparkRunner.class);
        return options;
    }
}
