/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.confignode.manager.load.balancer.router.leader;

import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId;
import org.apache.iotdb.common.rpc.thrift.TConsensusGroupType;
import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation;
import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet;
import org.apache.iotdb.commons.cluster.NodeStatus;
import org.apache.iotdb.commons.cluster.RegionStatus;
import org.apache.iotdb.confignode.manager.load.balancer.router.leader.AbstractLeaderBalancer;
import org.apache.iotdb.confignode.manager.load.balancer.router.leader.GreedyLeaderBalancer;
import org.apache.iotdb.confignode.manager.load.balancer.router.leader.MinCostFlowLeaderBalancer;
import org.apache.iotdb.confignode.manager.load.cache.node.NodeStatistics;
import org.apache.iotdb.confignode.manager.load.cache.region.RegionStatistics;
import org.apache.tsfile.utils.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Offline comparison harness for {@link GreedyLeaderBalancer} and {@link MinCostFlowLeaderBalancer}.
 *
 * <p>For every cluster size from {@code TEST_REPLICA_NUM} (3) to {@code TEST_MAX_DATA_NODE_NUM}
 * (100) DataNodes it generates a random region layout and runs both balancers through three
 * scenarios: all DataNodes running, a random subset of DataNodes reported as Unknown, and all
 * DataNodes running again. Each balancer/scenario pair appends one CSV line
 * {@code rounds,switchTimes,range,variance} to the file opened by {@link #prepareWriter()}.
 */
public class LeaderBalancerComparisonTest {
    /** Not read anywhere in this class. */
    private static final boolean isCommandLineMode = false;
    private static final Logger LOGGER = LoggerFactory.getLogger(LeaderBalancerComparisonTest.class);
    /** Destination of the CSV results; opened by prepareWriter() and released by closeWriter(). */
    private static FileWriter WRITER;
    private static final GreedyLeaderBalancer GREEDY_LEADER_BALANCER = new GreedyLeaderBalancer();
    private static final MinCostFlowLeaderBalancer MIN_COST_FLOW_LEADER_BALANCER = new MinCostFlowLeaderBalancer();
    private static final Random RANDOM = new Random();
    private static final int TEST_MAX_DATA_NODE_NUM = 100;
    private static final int TEST_CPU_CORE_NUM = 16;
    private static final int TEST_REPLICA_NUM = 3;
    /** Fraction of generated RegionGroups whose replicas are placed greedily; the rest are random. */
    private static final double GREEDY_INIT_RATE = 0.9;
    /** Fraction of DataNodes (rounded up) reported as Unknown in the second scenario. */
    private static final double DISABLE_DATA_NODE_RATE = 0.05;
    /** Upper bound on balancing rounds; Statistics.rounds stays -1 when the bound is reached. */
    private static final int MAX_BALANCE_ROUNDS = 1000;

    /**
     * Opens (truncating) {@code ./leaderBalancerTest.txt} as the destination of all results. Any
     * writer left open by a previous call is closed first so it is not leaked.
     *
     * @throws IOException if the previous writer cannot be closed or the file cannot be opened
     */
    public static void prepareWriter() throws IOException {
        closeWriter();
        WRITER = new FileWriter("./leaderBalancerTest.txt");
    }

    /**
     * Closes the writer opened by {@link #prepareWriter()}; does nothing when no writer is open.
     *
     * @throws IOException if closing the underlying file fails
     */
    public static void closeWriter() throws IOException {
        if (WRITER != null) {
            try {
                WRITER.close();
            } finally {
                WRITER = null;
            }
        }
    }

    /**
     * Runs the three-scenario comparison for every cluster size and writes six CSV lines per size
     * (greedy then min-cost-flow, for each scenario).
     *
     * @throws IOException if writing a result line fails
     */
    public void leaderBalancerComparisonTest() throws IOException {
        // Start at TEST_REPLICA_NUM: generateTestData needs that many distinct DataNodes per region.
        for (int dataNodeNum = TEST_REPLICA_NUM; dataNodeNum <= TEST_MAX_DATA_NODE_NUM; ++dataNodeNum) {
            int regionGroupNum = TEST_CPU_CORE_NUM * dataNodeNum / TEST_REPLICA_NUM;
            Map<TConsensusGroupId, Set<Integer>> regionReplicaSetMap = new HashMap<>();
            Map<TConsensusGroupId, Integer> regionLeaderMap = new HashMap<>();
            generateTestData(dataNodeNum, regionGroupNum, regionReplicaSetMap, regionLeaderMap);

            Set<Integer> noDisabledDataNodes = new HashSet<>();
            Map<Integer, NodeStatistics> allRunningDataNodeStatistics =
                    buildNodeStatistics(dataNodeNum, noDisabledDataNodes);
            Map<TConsensusGroupId, Map<Integer, RegionStatistics>> allRunningRegionStatistics =
                    buildRegionStatistics(regionReplicaSetMap, noDisabledDataNodes);

            // Scenario 1: every DataNode is running; both balancers start from the generated leaders.
            Map<TConsensusGroupId, Integer> greedyLeaderDistribution = new ConcurrentHashMap<>();
            Map<TConsensusGroupId, Integer> mcfLeaderDistribution = new ConcurrentHashMap<>();
            Statistics greedyStatistics = doBalancing(dataNodeNum, regionGroupNum, GREEDY_LEADER_BALANCER,
                    regionReplicaSetMap, regionLeaderMap, allRunningDataNodeStatistics,
                    allRunningRegionStatistics, greedyLeaderDistribution);
            Statistics mcfStatistics = doBalancing(dataNodeNum, regionGroupNum, MIN_COST_FLOW_LEADER_BALANCER,
                    regionReplicaSetMap, regionLeaderMap, allRunningDataNodeStatistics,
                    allRunningRegionStatistics, mcfLeaderDistribution);
            greedyStatistics.toFile();
            mcfStatistics.toFile();

            // Scenario 2: a random subset of DataNodes (and their replicas) is reported Unknown.
            // Each balancer continues from its own stable distribution; passing the same map as
            // input and output is safe because doBalancing copies the input before clearing it.
            Set<Integer> disabledDataNodeSet = pickDisabledDataNodes(dataNodeNum);
            Map<Integer, NodeStatistics> disabledDataNodeStatistics =
                    buildNodeStatistics(dataNodeNum, disabledDataNodeSet);
            Map<TConsensusGroupId, Map<Integer, RegionStatistics>> disabledRegionStatistics =
                    buildRegionStatistics(regionReplicaSetMap, disabledDataNodeSet);
            greedyStatistics = doBalancing(dataNodeNum, regionGroupNum, GREEDY_LEADER_BALANCER,
                    regionReplicaSetMap, greedyLeaderDistribution, disabledDataNodeStatistics,
                    disabledRegionStatistics, greedyLeaderDistribution);
            mcfStatistics = doBalancing(dataNodeNum, regionGroupNum, MIN_COST_FLOW_LEADER_BALANCER,
                    regionReplicaSetMap, mcfLeaderDistribution, disabledDataNodeStatistics,
                    disabledRegionStatistics, mcfLeaderDistribution);
            greedyStatistics.toFile();
            mcfStatistics.toFile();

            // Scenario 3: the disabled DataNodes recover and every node is running again.
            greedyStatistics = doBalancing(dataNodeNum, regionGroupNum, GREEDY_LEADER_BALANCER,
                    regionReplicaSetMap, greedyLeaderDistribution, allRunningDataNodeStatistics,
                    allRunningRegionStatistics, greedyLeaderDistribution);
            mcfStatistics = doBalancing(dataNodeNum, regionGroupNum, MIN_COST_FLOW_LEADER_BALANCER,
                    regionReplicaSetMap, mcfLeaderDistribution, allRunningDataNodeStatistics,
                    allRunningRegionStatistics, mcfLeaderDistribution);
            greedyStatistics.toFile();
            mcfStatistics.toFile();
        }
    }

    /** Randomly picks ceil(dataNodeNum * DISABLE_DATA_NODE_RATE) distinct ids in [0, dataNodeNum). */
    private static Set<Integer> pickDisabledDataNodes(int dataNodeNum) {
        int disabledDataNodeNum = (int) Math.ceil((double) dataNodeNum * DISABLE_DATA_NODE_RATE);
        Set<Integer> disabledDataNodeSet = new HashSet<>();
        while (disabledDataNodeSet.size() < disabledDataNodeNum) {
            disabledDataNodeSet.add(RANDOM.nextInt(dataNodeNum));
        }
        return disabledDataNodeSet;
    }

    /** Builds per-DataNode statistics: Unknown for ids in unknownDataNodes, Running otherwise. */
    private static Map<Integer, NodeStatistics> buildNodeStatistics(int dataNodeNum, Set<Integer> unknownDataNodes) {
        Map<Integer, NodeStatistics> nodeStatistics = new TreeMap<>();
        for (int dataNodeId = 0; dataNodeId < dataNodeNum; ++dataNodeId) {
            nodeStatistics.put(dataNodeId, unknownDataNodes.contains(dataNodeId)
                    ? new NodeStatistics(NodeStatus.Unknown)
                    : new NodeStatistics(NodeStatus.Running));
        }
        return nodeStatistics;
    }

    /**
     * Builds per-replica statistics for every RegionGroup: a replica is Unknown when it lives on a
     * DataNode in unknownDataNodes and Running otherwise.
     */
    private static Map<TConsensusGroupId, Map<Integer, RegionStatistics>> buildRegionStatistics(
            Map<TConsensusGroupId, Set<Integer>> regionReplicaSetMap, Set<Integer> unknownDataNodes) {
        Map<TConsensusGroupId, Map<Integer, RegionStatistics>> allRegionStatistics = new TreeMap<>();
        regionReplicaSetMap.forEach((regionGroupId, regionReplicaSet) -> {
            Map<Integer, RegionStatistics> regionStatistics = new TreeMap<>();
            regionReplicaSet.forEach(dataNodeId -> regionStatistics.put(dataNodeId,
                    unknownDataNodes.contains(dataNodeId)
                            ? new RegionStatistics(RegionStatus.Unknown)
                            : new RegionStatistics(RegionStatus.Running)));
            allRegionStatistics.put(regionGroupId, regionStatistics);
        });
        return allRegionStatistics;
    }

    /**
     * Generates regionGroupNum DataRegion groups with TEST_REPLICA_NUM replicas each over DataNodes
     * 0..dataNodeNum-1, filling both output maps.
     *
     * <p>About GREEDY_INIT_RATE of the groups are placed greedily: replicas go to the DataNodes
     * hosting the fewest regions so far, and the replica whose DataNode leads the fewest regions
     * becomes leader. The rest are placed on random distinct DataNodes with a random leader.
     *
     * @param regionReplicaSetMap output: RegionGroup id to the DataNode ids hosting its replicas
     * @param regionLeaderMap output: RegionGroup id to the DataNode id of its initial leader
     */
    private void generateTestData(int dataNodeNum, int regionGroupNum,
            Map<TConsensusGroupId, Set<Integer>> regionReplicaSetMap,
            Map<TConsensusGroupId, Integer> regionLeaderMap) {
        Map<Integer, AtomicInteger> regionCounter = new ConcurrentHashMap<>();
        Map<Integer, AtomicInteger> leaderCounter = new ConcurrentHashMap<>();
        for (int i = 0; i < dataNodeNum; ++i) {
            regionCounter.put(i, new AtomicInteger(0));
            leaderCounter.put(i, new AtomicInteger(0));
        }
        int greedyNum = (int) (GREEDY_INIT_RATE * (double) regionGroupNum);
        int randomNum = regionGroupNum - greedyNum;
        for (int index = 0; index < regionGroupNum; ++index) {
            int leaderId = -1;
            TConsensusGroupId regionGroupId = new TConsensusGroupId(TConsensusGroupType.DataRegion, index);
            TRegionReplicaSet regionReplicaSet = new TRegionReplicaSet().setRegionId(regionGroupId);
            // Pick the placement strategy in proportion to how many greedy/random groups remain,
            // so exactly greedyNum groups end up greedy overall.
            int seed = RANDOM.nextInt(greedyNum + randomNum);
            if (seed < greedyNum) {
                int leaderWeight = Integer.MAX_VALUE;
                // Pair(dataNodeId, regionCount), ordered by ascending region count.
                Comparator<Pair<Integer, Integer>> byRegionCount = Comparator.comparingInt(Pair::getRight);
                PriorityQueue<Pair<Integer, Integer>> dataNodePriorityQueue = new PriorityQueue<>(byRegionCount);
                regionCounter.forEach((dataNodeId, regionGroupCount) ->
                        dataNodePriorityQueue.offer(new Pair<>(dataNodeId, regionGroupCount.get())));
                for (int i = 0; i < TEST_REPLICA_NUM; ++i) {
                    int dataNodeId = Objects.requireNonNull(dataNodePriorityQueue.poll()).getLeft();
                    regionReplicaSet.addToDataNodeLocations(new TDataNodeLocation().setDataNodeId(dataNodeId));
                    int leaderCount = leaderCounter.get(dataNodeId).get();
                    if (leaderCount < leaderWeight) {
                        leaderWeight = leaderCount;
                        leaderId = dataNodeId;
                    }
                }
                --greedyNum;
            } else {
                Set<Integer> randomSet = new HashSet<>();
                while (randomSet.size() < TEST_REPLICA_NUM) {
                    int dataNodeId = RANDOM.nextInt(dataNodeNum);
                    if (randomSet.add(dataNodeId)) {
                        regionReplicaSet.addToDataNodeLocations(new TDataNodeLocation().setDataNodeId(dataNodeId));
                    }
                }
                leaderId = new ArrayList<>(randomSet).get(RANDOM.nextInt(TEST_REPLICA_NUM));
                --randomNum;
            }
            regionReplicaSetMap.put(regionGroupId, regionReplicaSet.getDataNodeLocations().stream()
                    .map(TDataNodeLocation::getDataNodeId)
                    .collect(Collectors.toSet()));
            regionReplicaSet.getDataNodeLocations().forEach(dataNodeLocation ->
                    regionCounter.get(dataNodeLocation.getDataNodeId()).getAndIncrement());
            regionLeaderMap.put(regionGroupId, leaderId);
            leaderCounter.get(leaderId).getAndIncrement();
        }
    }

    /**
     * Repeatedly asks leaderBalancer for a leader distribution, starting from regionLeaderMap,
     * until two consecutive rounds agree or MAX_BALANCE_ROUNDS rounds have run.
     *
     * @param regionLeaderMap starting distribution; copied, never modified
     * @param stableLeaderDistribution output: cleared, then filled with the final distribution
     *     (may be the same map as regionLeaderMap)
     * @return rounds to converge (-1 if it never did), total leader switches, and the range and
     *     variance of per-DataNode leader counts in the final distribution
     */
    private Statistics doBalancing(int dataNodeNum, int regionGroupNum, AbstractLeaderBalancer leaderBalancer,
            Map<TConsensusGroupId, Set<Integer>> regionReplicaSetMap,
            Map<TConsensusGroupId, Integer> regionLeaderMap,
            Map<Integer, NodeStatistics> nodeStatisticsMap,
            Map<TConsensusGroupId, Map<Integer, RegionStatistics>> regionStatisticsMap,
            Map<TConsensusGroupId, Integer> stableLeaderDistribution) {
        Statistics result = new Statistics();
        result.rounds = -1;
        Map<TConsensusGroupId, Integer> lastDistribution = new ConcurrentHashMap<>(regionLeaderMap);
        for (int rounds = 0; rounds < MAX_BALANCE_ROUNDS; ++rounds) {
            Map<TConsensusGroupId, Integer> currentDistribution = leaderBalancer.generateOptimalLeaderDistribution(
                    new TreeMap<>(), regionReplicaSetMap, lastDistribution, nodeStatisticsMap, regionStatisticsMap);
            if (currentDistribution.equals(lastDistribution)) {
                result.rounds = rounds;
                break;
            }
            int switchTimes = 0;
            for (Map.Entry<TConsensusGroupId, Integer> entry : lastDistribution.entrySet()) {
                if (!Objects.equals(entry.getValue(), currentDistribution.get(entry.getKey()))) {
                    ++switchTimes;
                }
            }
            result.switchTimes += switchTimes;
            lastDistribution.clear();
            lastDistribution.putAll(currentDistribution);
        }
        stableLeaderDistribution.clear();
        stableLeaderDistribution.putAll(lastDistribution);

        // NOTE(review): only DataNodes that lead at least one region appear in leaderCounter, so
        // DataNodes leading nothing (including Unknown nodes in the disabled scenario) add nothing
        // to sum or min even though avg and the variance divisor use dataNodeNum. Confirm that is
        // the intended metric.
        Map<Integer, Integer> leaderCounter = new HashMap<>();
        lastDistribution.forEach((regionGroupId, leaderId) -> leaderCounter.merge(leaderId, 1, Integer::sum));
        double avg = (double) regionGroupNum / (double) dataNodeNum;
        double sum = 0.0;
        int minLeaderCount = Integer.MAX_VALUE;
        int maxLeaderCount = Integer.MIN_VALUE;
        for (int leaderCount : leaderCounter.values()) {
            sum += Math.pow((double) leaderCount - avg, 2.0);
            minLeaderCount = Math.min(minLeaderCount, leaderCount);
            maxLeaderCount = Math.max(maxLeaderCount, leaderCount);
        }
        result.range = maxLeaderCount - minLeaderCount;
        result.variance = sum / (double) dataNodeNum;
        return result;
    }

    /** Outcome of one doBalancing run. */
    private static final class Statistics {
        private int rounds = 0;
        private int switchTimes = 0;
        private int range = 0;
        private double variance = 0.0;

        private Statistics() {
        }

        /**
         * Appends {@code rounds,switchTimes,range,variance} as one line to the shared writer. The
         * variance is formatted with Locale.ROOT so a comma-decimal default locale cannot corrupt
         * the CSV columns.
         *
         * @throws IllegalStateException if prepareWriter() has not opened a writer
         * @throws IOException if writing or flushing fails
         */
        private void toFile() throws IOException {
            if (WRITER == null) {
                throw new IllegalStateException("prepareWriter() must be called before writing results");
            }
            WRITER.write(this.rounds + "," + this.switchTimes + "," + this.range + ","
                    + String.format(Locale.ROOT, "%.6f", this.variance) + "\n");
            WRITER.flush();
        }

        @Override
        public String toString() {
            return "Statistics{rounds=" + this.rounds + ", switchTimes=" + this.switchTimes + ", range=" + this.range + ", variance=" + this.variance + '}';
        }

        /** Equal when the integer fields match and the variances differ by at most 0.1. */
        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            Statistics that = (Statistics) o;
            return this.rounds == that.rounds && this.switchTimes == that.switchTimes && this.range == that.range && Math.abs(this.variance - that.variance) <= 0.1;
        }

        /**
         * Hashes only the fields equals() compares exactly. Variance is excluded because equals()
         * accepts a tolerance, and hashing the exact value would give equal objects different hashes.
         */
        @Override
        public int hashCode() {
            return Objects.hash(this.rounds, this.switchTimes, this.range);
        }
    }
}

