package com.datatorrent.stram;

import com.datatorrent.api.Context;
import com.datatorrent.api.DAG;
import com.datatorrent.common.partitioner.StatelessPartitioner;
import com.datatorrent.stram.StreamingContainerAgent;
import com.datatorrent.stram.engine.GenericTestOperator;
import com.datatorrent.stram.plan.logical.LogicalPlan;
import com.datatorrent.stram.plan.physical.PTContainer;
import com.datatorrent.stram.plan.physical.PTOperator;
import com.datatorrent.stram.support.StramTestSupport;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import java.io.File;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.hadoop.yarn.api.records.NodeReport;
import org.apache.hadoop.yarn.api.records.NodeState;
import org.apache.hadoop.yarn.server.utils.BuilderUtils;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:com/datatorrent/stram/LocalityTest.class */
/**
 * Checks that the hosts chosen for container start requests keep operators that are
 * connected through a {@code NODE_LOCAL} stream on the same host.
 */
public class LocalityTest {
    /** Memory booked against a node for every container that is placed on it. */
    private static final int CONTAINER_MEMORY = 2000;

    /**
     * Builds the plan {@code o1 -> partitioned -> partitionedParallel -> single}, where
     * {@code partitioned} uses a {@code StatelessPartitioner(2)} and feeds
     * {@code partitionedParallel} over a {@code NODE_LOCAL}, {@code PARTITION_PARALLEL} stream.
     * It then asks a {@link ResourceRequestHandler} that knows two nodes ("host1", "host2")
     * for a host per container start request and verifies that:
     * <ul>
     *   <li>six container start requests are produced,</li>
     *   <li>the set of assigned hosts is exactly the set of known node hosts,</li>
     *   <li>all members of a node-local operator set with more than one operator sit in
     *       containers that have the same, non-null host.</li>
     * </ul>
     */
    @Test
    public void testNodeLocal() {
        LogicalPlan dag = new LogicalPlan();
        dag.getAttributes().put(Context.DAGContext.APPLICATION_PATH, new File("target", LocalityTest.class.getName()).getAbsolutePath());
        dag.setAttribute(Context.OperatorContext.STORAGE_AGENT, new StramTestSupport.MemoryStorageAgent());

        GenericTestOperator o1 = dag.addOperator("o1", GenericTestOperator.class);

        GenericTestOperator partitioned = dag.addOperator("partitioned", GenericTestOperator.class);
        dag.getMeta(partitioned).getAttributes().put(Context.OperatorContext.PARTITIONER, new StatelessPartitioner<GenericTestOperator>(2));

        GenericTestOperator partitionedParallel = dag.addOperator("partitionedParallel", GenericTestOperator.class);

        // The first stream explicitly carries no locality; only the second one is node local.
        dag.addStream("o1_outport1", o1.outport1, partitioned.inport1).setLocality((DAG.Locality) null);
        dag.addStream("partitioned_outport1", partitioned.outport1, partitionedParallel.inport2).setLocality(DAG.Locality.NODE_LOCAL);
        dag.setInputPortAttribute(partitionedParallel.inport2, Context.PortContext.PARTITION_PARALLEL, true);

        GenericTestOperator single = dag.addOperator("single", GenericTestOperator.class);
        dag.addStream("partitionedParallel_outport1", partitionedParallel.outport1, single.inport1);

        dag.setAttribute(LogicalPlan.CONTAINERS_MAX_COUNT, 7);

        StreamingContainerManager scm = new StreamingContainerManager(dag);
        Assert.assertEquals("number required containers", 6, scm.containerStartRequests.size());

        ResourceRequestHandler requestHandler = new ResourceRequestHandler();

        // Node reports keyed by host name so the used memory can be updated after each placement.
        Map<String, NodeReport> nodeReports = Maps.newHashMap();
        for (String hostName : new String[] {"host1", "host2"}) {
            NodeReport report = newRunningNode(hostName);
            nodeReports.put(report.getNodeId().getHost(), report);
        }
        requestHandler.updateNodeReports(Lists.newArrayList(nodeReports.values()));

        Map<PTContainer, String> requestedHosts = Maps.newHashMap();
        for (Object pending : scm.containerStartRequests) {
            StreamingContainerAgent.ContainerStartRequest request = (StreamingContainerAgent.ContainerStartRequest) pending;
            String host = requestHandler.getHost(request, true);
            request.container.host = host;
            if (host == null) {
                continue;
            }
            requestedHosts.put(request.container, host);
            // Book the container's memory on the chosen node, as the original test does.
            NodeReport chosen = nodeReports.get(host);
            Assert.assertNotNull("unknown host " + host + " for " + request.container, chosen);
            chosen.getUsed().setMemory(chosen.getUsed().getMemory() + CONTAINER_MEMORY);
        }

        Assert.assertEquals("" + requestedHosts, nodeReports.keySet(), Sets.newHashSet(requestedHosts.values()));

        for (PTContainer container : requestedHosts.keySet()) {
            for (PTOperator operator : container.getOperators()) {
                if (operator.getNodeLocalOperators().getOperatorSet().size() > 1) {
                    assertSameHost(operator);
                }
            }
        }
    }

    /**
     * Creates a report for a node in state {@code RUNNING} with nothing used and a capacity
     * of twice {@link #CONTAINER_MEMORY} and 2 as the second resource dimension.
     *
     * @param hostName host name of the node id; the port is 0
     * @return the new node report
     */
    private static NodeReport newRunningNode(String hostName) {
        return BuilderUtils.newNodeReport(BuilderUtils.newNodeId(hostName, 0), NodeState.RUNNING, "httpAddress", "rackName", BuilderUtils.newResource(0, 0), BuilderUtils.newResource(CONTAINER_MEMORY * 2, 2), 0, (String) null, 0L);
    }

    /**
     * Asserts that every operator in the node-local set of {@code operator} lives in a
     * container with a non-null host and that all of those hosts are equal.
     *
     * @param operator operator whose node-local operator set is inspected
     */
    private static void assertSameHost(PTOperator operator) {
        String expectedHost = null;
        for (PTOperator nodeLocal : operator.getNodeLocalOperators().getOperatorSet()) {
            String actualHost = nodeLocal.getContainer().host;
            Assert.assertNotNull("host null " + nodeLocal.getContainer(), actualHost);
            if (expectedHost == null) {
                // The first member of the set defines the host all others must share.
                expectedHost = actualHost;
            } else {
                Assert.assertEquals("expected same host " + nodeLocal, expectedHost, actualHost);
            }
        }
    }
}
