package org.apache.nemo.compiler.optimizer.pass.compiletime.annotating;

import com.fasterxml.jackson.core.TreeNode;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.optim.BaseOptimizer;
import org.apache.commons.math3.optim.MaxIter;
import org.apache.commons.math3.optim.OptimizationData;
import org.apache.commons.math3.optim.linear.LinearConstraint;
import org.apache.commons.math3.optim.linear.LinearConstraintSet;
import org.apache.commons.math3.optim.linear.LinearObjectiveFunction;
import org.apache.commons.math3.optim.linear.Relationship;
import org.apache.commons.math3.optim.linear.SimplexSolver;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.util.Incrementor;
import org.apache.nemo.common.ir.IRDAG;
import org.apache.nemo.common.ir.edge.IREdge;
import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.ResourceSiteProperty;
import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Annotates every vertex with a {@link ResourceSiteProperty}: a map from node name to the number of the
 * vertex's tasks assigned to that node.
 *
 * <p>The node list and per-node bandwidths come from a JSON string installed with
 * {@link #setBandwidthSpecificationString(String)}. If no specification is set, or it names no nodes, every
 * vertex gets an empty map. Otherwise vertices are visited in topological order and:
 * <ul>
 *   <li>vertices without incoming edges get an empty map;</li>
 *   <li>vertices whose only incoming edge is one-to-one get a copy of their parent's map;</li>
 *   <li>all other vertices get shares derived from a linear program over their parents' combined shares
 *       (see {@link #optimize(BandwidthSpecification, Map)}).</li>
 * </ul>
 * When a parent's map is empty, it is treated as an even spread of that parent's parallelism over all nodes.
 */
@Annotates({ResourceSiteProperty.class})
@Requires({ParallelismProperty.class})
public final class ResourceSitePass extends AnnotatingPass {
    private static final Logger LOG = LoggerFactory.getLogger(ResourceSitePass.class);

    /** Reused JSON reader; creating an ObjectMapper per parse is needlessly expensive. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    /** Index of the objective variable (the quantity the LP minimizes) in every coefficient vector. */
    private static final int OBJECTIVE_COEFFICIENT_INDEX = 0;

    /** JSON bandwidth specification consumed by {@link #apply(IRDAG)}; blank means "not specified". */
    private static String bandwidthSpecificationString = "";

    /**
     * Node names with their uplink and downlink bandwidths, parsed from a JSON array of the form
     * {@code [{"name": "...", "up": <int>, "down": <int>}, ...]}.
     */
    public static final class BandwidthSpecification {
        private final List<String> nodeNames = new ArrayList<>();
        private final Map<String, Integer> uplinkBandwidth = new HashMap<>();
        private final Map<String, Integer> downlinkBandwidth = new HashMap<>();

        private BandwidthSpecification() {
        }

        /**
         * Parses a bandwidth specification.
         *
         * @param jsonString JSON array whose elements each have a string {@code name} and integer {@code up}
         *                   and {@code down} fields
         * @return the parsed specification, with nodes in array order
         * @throws RuntimeException if {@code jsonString} is not valid JSON (the parse error is the cause)
         * @throws IllegalArgumentException if the JSON is not an array, an element lacks a field or has a
         *                                  field of the wrong type, or a node name appears twice
         */
        static BandwidthSpecification fromJsonString(final String jsonString) {
            final JsonNode root;
            try {
                root = OBJECT_MAPPER.readTree(jsonString);
            } catch (final IOException e) {
                throw new RuntimeException(e);
            }
            if (root == null || !root.isArray()) {
                throw new IllegalArgumentException("Bandwidth specification must be a JSON array: " + jsonString);
            }
            final BandwidthSpecification specification = new BandwidthSpecification();
            for (final JsonNode location : root) {
                // Read values straight from the tree: a parser from traverse() has no current token until it
                // is advanced, so numeric accessors on it cannot succeed.
                final String name = requireText(location, "name");
                if (specification.uplinkBandwidth.containsKey(name)) {
                    throw new IllegalArgumentException("Duplicate node in bandwidth specification: " + name);
                }
                specification.nodeNames.add(name);
                specification.uplinkBandwidth.put(name, requireInt(location, "up"));
                specification.downlinkBandwidth.put(name, requireInt(location, "down"));
            }
            return specification;
        }

        /** Returns the string value of {@code field} in {@code location}, failing if absent or not a string. */
        private static String requireText(final JsonNode location, final String field) {
            final JsonNode value = location.get(field);
            if (value == null || !value.isTextual()) {
                throw new IllegalArgumentException(
                    "Bandwidth specification entry needs a string '" + field + "': " + location);
            }
            return value.textValue();
        }

        /** Returns the int value of {@code field} in {@code location}, failing if absent or not an int. */
        private static int requireInt(final JsonNode location, final String field) {
            final JsonNode value = location.get(field);
            if (value == null || !value.canConvertToInt()) {
                throw new IllegalArgumentException(
                    "Bandwidth specification entry needs an integer '" + field + "': " + location);
            }
            return value.intValue();
        }

        /** Uplink bandwidth of {@code node}; the node must be part of this specification. */
        int up(final String node) {
            return uplinkBandwidth.get(node);
        }

        /** Downlink bandwidth of {@code node}; the node must be part of this specification. */
        int down(final String node) {
            return downlinkBandwidth.get(node);
        }

        /** Node names in specification order (read-only view). */
        List<String> getNodes() {
            return Collections.unmodifiableList(nodeNames);
        }
    }

    /** Creates the pass. */
    public ResourceSitePass() {
        super(ResourceSitePass.class);
    }

    /**
     * Sets a {@link ResourceSiteProperty} on every vertex of {@code dag}.
     *
     * @param dag the DAG to annotate; every vertex must already carry a {@link ParallelismProperty} when a
     *            non-empty bandwidth specification is set
     * @return {@code dag}, annotated in place
     */
    @Override
    public IRDAG apply(final IRDAG dag) {
        final String specString = bandwidthSpecificationString;
        final BandwidthSpecification specification =
            specString.trim().isEmpty() ? null : BandwidthSpecification.fromJsonString(specString);
        if (specification == null || specification.getNodes().isEmpty()) {
            // Nothing to place on: an empty node list would otherwise lead to a division by zero.
            dag.topologicalDo(vertex -> vertex.setProperty(ResourceSiteProperty.of(new HashMap<>())));
        } else {
            assignNodeShares(dag, specification);
        }
        return dag;
    }

    /**
     * Installs the JSON bandwidth specification used by subsequent {@link #apply(IRDAG)} calls.
     *
     * @param str the specification; {@code null} or blank disables site assignment
     */
    public static void setBandwidthSpecificationString(final String str) {
        bandwidthSpecificationString = str == null ? "" : str;
    }

    /**
     * Spreads {@code parallelism} tasks as evenly as possible over {@code nodes}; the first
     * {@code parallelism % nodes.size()} nodes get one extra task.
     *
     * @return a fresh map from node name to task count (empty when {@code nodes} is empty)
     */
    private static HashMap<String, Integer> getEvenShares(final List<String> nodes, final int parallelism) {
        final HashMap<String, Integer> shares = new HashMap<>();
        if (nodes.isEmpty()) {
            return shares;
        }
        final int base = parallelism / nodes.size();
        final int extra = parallelism % nodes.size();
        for (int i = 0; i < nodes.size(); i++) {
            shares.put(nodes.get(i), base + (i < extra ? 1 : 0));
        }
        return shares;
    }

    /** Visits {@code dag} in topological order and sets each vertex's resource-site shares. */
    private static void assignNodeShares(final IRDAG dag, final BandwidthSpecification specification) {
        final List<String> nodes = specification.getNodes();
        dag.topologicalDo(vertex -> {
            final Collection<IREdge> inEdges = dag.getIncomingEdgesOf(vertex);
            final int parallelism = getParallelism(vertex);
            if (inEdges.isEmpty()) {
                vertex.setProperty(ResourceSiteProperty.of(new HashMap<>()));
                return;
            }
            if (isOneToOneEdge(inEdges)) {
                // Copy so parent and child never share one mutable map instance.
                final IRVertex parent = inEdges.iterator().next().getSrc();
                vertex.setProperty(ResourceSiteProperty.of(new HashMap<>(getResourceSite(parent))));
                return;
            }
            final Map<String, Integer> parentShares = aggregateParentShares(inEdges, nodes);
            final double[] ratios = optimize(specification, parentShares);
            vertex.setProperty(ResourceSiteProperty.of(apportion(nodes, ratios, parallelism)));
        });
    }

    /** Reads the vertex's parallelism, failing if the required property is missing. */
    private static int getParallelism(final IRVertex vertex) {
        return vertex.getPropertyValue(ParallelismProperty.class)
            .orElseThrow(() -> new RuntimeException("Parallelism property required"));
    }

    /** Reads the resource-site map already assigned to {@code vertex} (parents are visited first). */
    private static HashMap<String, Integer> getResourceSite(final IRVertex vertex) {
        return vertex.getPropertyValue(ResourceSiteProperty.class)
            .orElseThrow(() -> new IllegalStateException("Resource site not yet assigned to " + vertex));
    }

    /**
     * Sums, per node, the task shares of every parent behind {@code inEdges}. A parent with an empty map
     * contributes an even spread of its parallelism over {@code nodes}.
     */
    private static Map<String, Integer> aggregateParentShares(final Collection<IREdge> inEdges,
                                                              final List<String> nodes) {
        final Map<String, Integer> totals = new HashMap<>();
        for (final IREdge edge : inEdges) {
            final IRVertex parent = edge.getSrc();
            final Map<String, Integer> parentShares = getResourceSite(parent);
            final Map<String, Integer> effective =
                parentShares.isEmpty() ? getEvenShares(nodes, getParallelism(parent)) : parentShares;
            effective.forEach((node, share) -> totals.merge(node, share, Integer::sum));
        }
        return totals;
    }

    /**
     * Converts per-node ratios into integer task counts that sum to {@code parallelism}, using the
     * largest-remainder method.
     *
     * @param nodes       node names, aligned with {@code ratios}
     * @param ratios      fraction of tasks per node, expected to sum to 1
     * @param parallelism total number of tasks
     * @return a fresh map from node name to task count
     */
    private static HashMap<String, Integer> apportion(final List<String> nodes, final double[] ratios,
                                                      final int parallelism) {
        final HashMap<String, Integer> shares = new HashMap<>();
        final double[] fractions = new double[nodes.size()];
        int assigned = 0;
        for (int i = 0; i < nodes.size(); i++) {
            // Clamp tiny negative values that floating-point round-off can produce.
            final double exact = Math.max(0.0, ratios[i]) * parallelism;
            final int whole = (int) Math.floor(exact);
            shares.put(nodes.get(i), whole);
            fractions[i] = exact - whole;
            assigned += whole;
        }
        // Hand leftover tasks to the nodes whose exact share was truncated the most. The sort is stable, so
        // ties go to the earlier node in specification order and the result is deterministic.
        final List<Integer> byFraction = new ArrayList<>();
        for (int i = 0; i < nodes.size(); i++) {
            byFraction.add(i);
        }
        byFraction.sort(Comparator.comparingDouble((Integer i) -> fractions[i]).reversed());
        int remaining = parallelism - assigned;
        for (int k = 0; remaining > 0 && !byFraction.isEmpty(); k = (k + 1) % byFraction.size()) {
            shares.merge(nodes.get(byFraction.get(k)), 1, Integer::sum);
            remaining--;
        }
        return shares;
    }

    /** Returns true iff {@code inEdges} is exactly one edge whose communication pattern is one-to-one. */
    private static boolean isOneToOneEdge(final Collection<IREdge> inEdges) {
        if (inEdges.size() != 1) {
            return false;
        }
        final IREdge edge = inEdges.iterator().next();
        final CommunicationPatternProperty.Value pattern = edge.getPropertyValue(CommunicationPatternProperty.class)
            .orElseThrow(() -> new IllegalStateException("Communication pattern not set on " + edge));
        return pattern == CommunicationPatternProperty.Value.ONE_TO_ONE;
    }

    /**
     * Solves a linear program for the fraction {@code r_i} of a vertex's tasks to place on each node.
     *
     * <p>Variables are {@code t} (index 0, minimized) and {@code r_1..r_n}. With {@code s_i} the parent
     * shares on node {@code i} and {@code S} their total, the constraints are:
     * {@code up_i * t >= s_i * (1 - r_i)}, {@code down_i * t >= (S - s_i) * r_i}, {@code r_i >= 0} and
     * {@code sum(r_i) = 1}. Presumably {@code t} bounds the time to move data over the node links.
     *
     * @param specification node list and bandwidths
     * @param parentShares  aggregated parent task counts per node; missing nodes count as 0
     * @return {@code r_1..r_n}, aligned with {@code specification.getNodes()}
     */
    private static double[] optimize(final BandwidthSpecification specification,
                                     final Map<String, Integer> parentShares) {
        final int totalShares = parentShares.values().stream().mapToInt(Integer::intValue).sum();
        final List<String> nodes = specification.getNodes();
        final int numVariables = nodes.size() + 1;
        final List<LinearConstraint> constraints = new ArrayList<>();
        for (int i = 0; i < nodes.size(); i++) {
            final String node = nodes.get(i);
            final int ratioIndex = i + 1;
            final int localShares = parentShares.getOrDefault(node, 0);

            // up * t + s_i * r_i >= s_i
            final double[] uplink = new double[numVariables];
            uplink[OBJECTIVE_COEFFICIENT_INDEX] = specification.up(node);
            uplink[ratioIndex] = localShares;
            constraints.add(new LinearConstraint(uplink, Relationship.GEQ, localShares));

            // down * t + (s_i - S) * r_i >= 0
            final double[] downlink = new double[numVariables];
            downlink[OBJECTIVE_COEFFICIENT_INDEX] = specification.down(node);
            downlink[ratioIndex] = localShares - totalShares;
            constraints.add(new LinearConstraint(downlink, Relationship.GEQ, 0.0));

            // r_i >= 0
            final double[] nonNegative = new double[numVariables];
            nonNegative[ratioIndex] = 1.0;
            constraints.add(new LinearConstraint(nonNegative, Relationship.GEQ, 0.0));
        }
        // sum(r_i) = 1
        final double[] sumToOne = new double[numVariables];
        Arrays.fill(sumToOne, 1, numVariables, 1.0);
        constraints.add(new LinearConstraint(sumToOne, Relationship.EQ, 1.0));

        final double[] objective = new double[numVariables];
        objective[OBJECTIVE_COEFFICIENT_INDEX] = 1.0;

        final SimplexSolver solver = new SimplexSolver();
        // MaxIter is the public way to lift the iteration cap; no reflection into optimizer internals.
        final double[] point = solver.optimize(
            new MaxIter(Integer.MAX_VALUE),
            new LinearConstraintSet(constraints),
            new LinearObjectiveFunction(objective, 0.0),
            GoalType.MINIMIZE).getPoint();
        LOG.debug("Simplex solver finished after {} iterations", solver.getIterations());
        return Arrays.copyOfRange(point, 1, numVariables);
    }
}
