package com.ibm.research.time_series.ml.clustering;

import com.ibm.research.time_series.core.observation.Observation;
import com.ibm.research.time_series.core.timeseries.MultiTimeSeries;
import com.ibm.research.time_series.core.utils.ObservationCollection;
import com.ibm.research.time_series.core.utils.Pair;
import com.ibm.research.time_series.ml.clustering.k_means.functions.DistanceComputer;
import com.ibm.research.time_series.transforms.reducers.distance.emd.algorithm.EMD;
import com.ibm.research.time_series.transforms.reducers.distance.hungarian.algorithm.WBM;
import java.io.IOException;
import java.io.OutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.math3.optimization.direct.CMAESOptimizer;
import org.codehaus.jackson.JsonEncoding;
import org.codehaus.jackson.JsonFactory;
import org.codehaus.jackson.JsonGenerator;
import org.codehaus.jackson.util.DefaultPrettyPrinter;

/* loaded from: input_file:com/ibm/research/time_series/ml/clustering/TimeSeriesClusteringModel.class */
/**
 * Base class for a clustering model over time series: a list of centroids plus
 * one metric value per centroid (intra/inter-cluster distance, silhouette
 * coefficient, cluster distribution and sum of squares).
 *
 * <p>Subclasses supply the distance function ({@link #computeDistance}) and the
 * JSON encoding of observation values and of any extra model fields.
 *
 * <p>NOTE: the class is {@link Serializable} and declares no
 * {@code serialVersionUID}, so the default-computed UID depends on the
 * non-private members. Helpers added here are therefore kept {@code private}.
 *
 * @param <T> the observation value type of the clustered time series
 */
public abstract class TimeSeriesClusteringModel<T> implements Serializable {
    public final List<ObservationCollection<T>> centroids;
    public final List<Double> intraClusterDistances;
    public final List<Double> interClusterDistances;
    public final List<Double> silhouetteCoefficients;
    public final List<Double> clusterDistributions;
    public final List<Double> sumSquares;

    /**
     * Creates a model from precomputed centroids and per-centroid metrics. The
     * lists are stored as given (no defensive copy).
     *
     * @param list  the centroids
     * @param list2 the intra-cluster distances, one per centroid
     * @param list3 the inter-cluster distances, one per centroid
     * @param list4 the silhouette coefficients, one per centroid
     * @param list5 the cluster distributions, one per centroid
     * @param list6 the sums of squares, one per centroid
     */
    public TimeSeriesClusteringModel(List<ObservationCollection<T>> list, List<Double> list2, List<Double> list3, List<Double> list4, List<Double> list5, List<Double> list6) {
        this.centroids = list;
        this.intraClusterDistances = list2;
        this.interClusterDistances = list3;
        this.silhouetteCoefficients = list4;
        this.clusterDistributions = list5;
        this.sumSquares = list6;
    }

    /**
     * Assigns a series to its nearest centroid.
     *
     * @param observationCollection the series to score
     * @return the index of the centroid with the smallest distance (the lowest
     *         index wins ties), or {@code -1} if the model has no centroids
     */
    public int score(ObservationCollection<T> observationCollection) {
        int nearest = -1;
        double nearestDistance = Double.NaN;
        for (int i = 0; i < this.centroids.size(); i++) {
            double distance = computeDistance(this.centroids.get(i), observationCollection);
            // Double.compare orders NaN after every other value, so a NaN distance
            // is only chosen when every distance is NaN.
            if (nearest < 0 || Double.compare(distance, nearestDistance) < 0) {
                nearest = i;
                nearestDistance = distance;
            }
        }
        return nearest;
    }

    /**
     * Computes the distance between two series; used for every centroid
     * comparison in this class.
     */
    protected abstract double computeDistance(ObservationCollection<T> observationCollection, ObservationCollection<T> observationCollection2);

    /**
     * Assigns a series to its nearest centroid and reports a silhouette-style
     * score {@code (d2 - d1) / d2}, where {@code d1} is the distance to the
     * nearest centroid and {@code d2} the distance to the second nearest.
     *
     * <p>Distances are computed on a parallel stream, so
     * {@link #computeDistance} may be invoked concurrently.
     *
     * @param observationCollection the series to score
     * @return a pair of (nearest centroid index, score). When the model has
     *         fewer than two centroids the result is {@code (-1, 0.0)}. When
     *         {@code d2} is zero the score is {@code 0.0} rather than the result
     *         of a division by zero, matching the guard used by
     *         {@link #perClusterMetrics}.
     */
    public Pair<Integer, Double> scoreWithSilhouette(ObservationCollection<T> observationCollection) {
        List<Pair<Integer, Double>> ranked = IntStream.range(0, this.centroids.size())
                .parallel()
                .mapToObj(i -> new Pair<Integer, Double>(i, computeDistance(this.centroids.get(i), observationCollection)))
                .sorted(Comparator.comparing((Pair<Integer, Double> candidate) -> candidate.right))
                .collect(Collectors.toList());
        if (ranked.size() < 2) {
            return new Pair<Integer, Double>(-1, 0.0d);
        }
        Pair<Integer, Double> nearest = ranked.get(0);
        double nearestDistance = nearest.right.doubleValue();
        double runnerUpDistance = ranked.get(1).right.doubleValue();
        double silhouette = runnerUpDistance == 0.0d
                ? 0.0d
                : (runnerUpDistance - nearestDistance) / runnerUpDistance;
        return new Pair<Integer, Double>(nearest.left, silhouette);
    }

    /**
     * Compares this model with another one by passing both cluster
     * distributions and the pairwise centroid distances to {@code EMD.solve}.
     *
     * @param timeSeriesClusteringModel the model to compare against
     * @return the left component of the pair returned by {@code EMD.solve}
     */
    public Double diffEMD(TimeSeriesClusteringModel<T> timeSeriesClusteringModel) {
        double[] ownWeights = this.clusterDistributions.stream().mapToDouble(Double::doubleValue).toArray();
        double[] otherWeights = timeSeriesClusteringModel.clusterDistributions.stream().mapToDouble(Double::doubleValue).toArray();
        // The matrix is sized by the distributions, not by the centroid lists.
        double[][] distances = centroidDistances(timeSeriesClusteringModel, ownWeights.length, otherWeights.length, false);
        return EMD.solve(ownWeights, otherWeights, distances).left;
    }

    /**
     * Matches the centroids of this model against those of another one by
     * passing the negated pairwise centroid distances to {@code WBM.execute}.
     *
     * @param timeSeriesClusteringModel the model to compare against
     * @return the {@code int[]} returned by {@code WBM.execute}, boxed into a list
     */
    public List<Integer> diffWBM(TimeSeriesClusteringModel<T> timeSeriesClusteringModel) {
        double[][] negatedDistances = centroidDistances(
                timeSeriesClusteringModel, this.centroids.size(), timeSeriesClusteringModel.centroids.size(), true);
        return IntStream.of(WBM.execute(negatedDistances)).boxed().collect(Collectors.toList());
    }

    /** Builds the rows x columns matrix of distances between this model's centroids and the other model's. */
    private double[][] centroidDistances(TimeSeriesClusteringModel<T> other, int rows, int columns, boolean negate) {
        double[][] distances = new double[rows][columns];
        for (int row = 0; row < rows; row++) {
            for (int column = 0; column < columns; column++) {
                double distance = computeDistance(this.centroids.get(row), other.centroids.get(column));
                distances[row][column] = negate ? distance * (-1.0d) : distance;
            }
        }
        return distances;
    }

    /**
     * Detects drift between two snapshots of a keyed set of series, scoring
     * both snapshots with this model.
     *
     * @see #detectDrift(MultiTimeSeries, MultiTimeSeries, TimeSeriesClusteringModel)
     */
    public <KEY> Map<KEY, Drift> detectDrift(MultiTimeSeries<KEY, T> multiTimeSeries, MultiTimeSeries<KEY, T> multiTimeSeries2) {
        return detectDrift(multiTimeSeries, multiTimeSeries2, this);
    }

    /**
     * Detects drift between two snapshots of a keyed set of series. The first
     * snapshot is scored with this model and the second with
     * {@code timeSeriesClusteringModel}, both via {@link #scoreWithSilhouette}.
     *
     * <p>For a key present in only one snapshot, the missing side of its
     * {@link Drift} is filled with cluster {@code -1} and score
     * {@code Double.NaN}.
     *
     * @param multiTimeSeries           the first snapshot
     * @param multiTimeSeries2          the second snapshot
     * @param timeSeriesClusteringModel the model used to score the second snapshot
     * @return one {@link Drift} per key found in either snapshot
     */
    public <KEY> Map<KEY, Drift> detectDrift(MultiTimeSeries<KEY, T> multiTimeSeries, MultiTimeSeries<KEY, T> multiTimeSeries2, TimeSeriesClusteringModel<T> timeSeriesClusteringModel) {
        Map<KEY, ObservationCollection<T>> firstSnapshot = multiTimeSeries.collectAsMap();
        Map<KEY, ObservationCollection<T>> secondSnapshot = multiTimeSeries2.collectAsMap();
        Map<KEY, Drift> drifts = new HashMap<>();

        // Keys of the first snapshot: either in both snapshots or only in the first.
        for (Map.Entry<KEY, ObservationCollection<T>> entry : firstSnapshot.entrySet()) {
            KEY key = entry.getKey();
            Pair<Integer, Double> first = scoreWithSilhouette(entry.getValue());
            if (secondSnapshot.containsKey(key)) {
                Pair<Integer, Double> second = timeSeriesClusteringModel.scoreWithSilhouette(secondSnapshot.get(key));
                drifts.put(key, new Drift(first.left.intValue(), second.left.intValue(), first.right.doubleValue(), second.right.doubleValue()));
            } else {
                drifts.put(key, new Drift(first.left.intValue(), -1, first.right.doubleValue(), Double.NaN));
            }
        }

        // Keys that only exist in the second snapshot.
        for (Map.Entry<KEY, ObservationCollection<T>> entry : secondSnapshot.entrySet()) {
            KEY key = entry.getKey();
            if (!firstSnapshot.containsKey(key)) {
                Pair<Integer, Double> second = timeSeriesClusteringModel.scoreWithSilhouette(entry.getValue());
                drifts.put(key, new Drift(-1, second.left.intValue(), Double.NaN, second.right.doubleValue()));
            }
        }
        return drifts;
    }

    /**
     * Writes the model as pretty-printed UTF-8 JSON: the five per-centroid
     * metric arrays, the centroids (each an array of objects holding a
     * {@code "timestamp"} plus whatever {@link #writeObservationValueToJSON}
     * emits), and finally whatever {@link #writeCustomFieldsToJSON} emits.
     *
     * <p>The generator is closed when this method returns, whether or not the
     * write succeeded.
     *
     * @param outputStream the stream to write to
     * @throws UncheckedIOException if writing fails (previously the failure was
     *         only printed to standard error and the output silently left
     *         incomplete)
     */
    public void save(OutputStream outputStream) {
        try (JsonGenerator generator = new JsonFactory().createJsonGenerator(outputStream, JsonEncoding.UTF8)) {
            generator.setPrettyPrinter(new DefaultPrettyPrinter());
            generator.writeStartObject();
            writeDoubleArray(generator, "intra-cluster-distance-per-centroid", this.intraClusterDistances);
            writeDoubleArray(generator, "inter-cluster-distance-per-centroid", this.interClusterDistances);
            writeDoubleArray(generator, "silhouette-coefficients-per-centroid", this.silhouetteCoefficients);
            writeDoubleArray(generator, "cluster-distributions-per-centroid", this.clusterDistributions);
            writeDoubleArray(generator, "sum-square-per-centroid", this.sumSquares);
            generator.writeFieldName("centroids");
            generator.writeStartArray();
            for (ObservationCollection<T> centroid : this.centroids) {
                generator.writeStartArray();
                for (Observation<T> observation : centroid) {
                    generator.writeStartObject();
                    generator.writeNumberField("timestamp", observation.getTimeTick());
                    writeObservationValueToJSON(generator, observation.getValue());
                    generator.writeEndObject();
                }
                generator.writeEndArray();
            }
            generator.writeEndArray();
            writeCustomFieldsToJSON(generator);
            generator.writeEndObject();
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to write clustering model as JSON", e);
        }
    }

    /** Writes {@code values} as a JSON array of numbers under {@code fieldName}. */
    private static void writeDoubleArray(JsonGenerator generator, String fieldName, List<Double> values) throws IOException {
        generator.writeFieldName(fieldName);
        generator.writeStartArray();
        for (Double value : values) {
            generator.writeNumber(value.doubleValue());
        }
        generator.writeEndArray();
    }

    /**
     * Writes the value of one centroid observation. Called by {@link #save}
     * inside the observation's JSON object, right after its {@code "timestamp"}
     * field.
     */
    protected abstract void writeObservationValueToJSON(JsonGenerator jsonGenerator, T t) throws IOException;

    /**
     * Writes any model-specific fields. Called by {@link #save} inside the
     * top-level JSON object, after the {@code "centroids"} array.
     */
    protected abstract void writeCustomFieldsToJSON(JsonGenerator jsonGenerator) throws IOException;

    /** Returns the centroid list held by this model. */
    public List<ObservationCollection<T>> centroids() {
        return this.centroids;
    }

    /** Returns the per-centroid inter-cluster distances. */
    public List<Double> interClusterDistances() {
        return this.interClusterDistances;
    }

    /** Returns the per-centroid intra-cluster distances. */
    public List<Double> intraClusterDistances() {
        return this.intraClusterDistances;
    }

    /** Returns the per-centroid silhouette coefficients. */
    public List<Double> silhouetteCoefficients() {
        return this.silhouetteCoefficients;
    }

    /** Returns the per-centroid cluster distributions. */
    public List<Double> clusterDistributions() {
        return this.clusterDistributions;
    }

    /**
     * Computes per-centroid quality metrics for a set of series.
     *
     * <p>Each series is assigned to its nearest centroid (the lowest index wins
     * ties). A {@code NaN} distance from {@code distanceComputer} is treated as
     * {@code 1.0}. For each centroid the method then derives, in this order:
     * <ol>
     *   <li>intra-cluster distance: mean distance of the assigned series to the
     *       centroid, or {@code 1.0} if none were assigned;</li>
     *   <li>inter-cluster distance: mean distance of the assigned series to
     *       their second-nearest centroid, or {@code 0.0} if none is
     *       available;</li>
     *   <li>silhouette coefficient: {@code (inter - intra) / inter}, or
     *       {@code 0.0} when {@code inter} is zero;</li>
     *   <li>cluster distribution: assigned series count divided by the total
     *       series count ({@code NaN} when there are no series at all);</li>
     *   <li>sum of squares: sum of the squared distances of the assigned series
     *       to the centroid.</li>
     * </ol>
     *
     * <p>With a single centroid no second-nearest distance exists, so the
     * inter-cluster distance and silhouette are {@code 0.0}. With no centroids
     * all five lists are empty. Both cases previously failed with an
     * {@code IndexOutOfBoundsException} whenever at least one series was given.
     *
     * @param multiTimeSeries  the series to evaluate
     * @param list             the centroids
     * @param distanceComputer the distance between a centroid and a series
     * @return five lists in the order described above, each indexed by centroid
     */
    public static <KEY, VALUE> List<List<Double>> perClusterMetrics(MultiTimeSeries<KEY, VALUE> multiTimeSeries, List<ObservationCollection<VALUE>> list, DistanceComputer<VALUE> distanceComputer) {
        List<ObservationCollection<VALUE>> centroids = list;
        int clusterCount = centroids.size();
        Map<KEY, ObservationCollection<VALUE>> seriesByKey = multiTimeSeries.collectAsMap();

        // Per centroid: distances of its members to it, and to their second-nearest centroid.
        List<List<Double>> memberDistances = new ArrayList<>(clusterCount);
        List<List<Double>> runnerUpDistances = new ArrayList<>(clusterCount);
        for (int cluster = 0; cluster < clusterCount; cluster++) {
            memberDistances.add(new ArrayList<>());
            runnerUpDistances.add(new ArrayList<>());
        }

        for (ObservationCollection<VALUE> series : seriesByKey.values()) {
            int nearest = -1;
            double nearestDistance = 0.0d;
            boolean hasRunnerUp = false;
            double runnerUpDistance = 0.0d;
            // Single pass that keeps the two smallest distances; strict comparisons
            // reproduce the order of a stable ascending sort.
            for (int cluster = 0; cluster < clusterCount; cluster++) {
                Double raw = distanceComputer.compute(centroids.get(cluster), series);
                double distance = raw.isNaN() ? 1.0d : raw.doubleValue();
                if (nearest < 0) {
                    nearest = cluster;
                    nearestDistance = distance;
                } else if (Double.compare(distance, nearestDistance) < 0) {
                    runnerUpDistance = nearestDistance;
                    hasRunnerUp = true;
                    nearest = cluster;
                    nearestDistance = distance;
                } else if (!hasRunnerUp || Double.compare(distance, runnerUpDistance) < 0) {
                    runnerUpDistance = distance;
                    hasRunnerUp = true;
                }
            }
            if (nearest < 0) {
                continue; // no centroids to assign to
            }
            memberDistances.get(nearest).add(nearestDistance);
            if (hasRunnerUp) {
                runnerUpDistances.get(nearest).add(runnerUpDistance);
            }
        }

        List<Double> intraClusterDistances = new ArrayList<>(clusterCount);
        List<Double> interClusterDistances = new ArrayList<>(clusterCount);
        List<Double> silhouetteCoefficients = new ArrayList<>(clusterCount);
        List<Double> clusterDistributions = new ArrayList<>(clusterCount);
        List<Double> sumSquares = new ArrayList<>(clusterCount);
        // Every list is filled by centroid index, so all five share the same order.
        for (int cluster = 0; cluster < clusterCount; cluster++) {
            List<Double> members = memberDistances.get(cluster);
            double intra = members.stream().mapToDouble(Double::doubleValue).average().orElse(1.0d);
            double inter = runnerUpDistances.get(cluster).stream().mapToDouble(Double::doubleValue).average().orElse(0.0d);
            intraClusterDistances.add(intra);
            interClusterDistances.add(inter);
            silhouetteCoefficients.add(inter == 0.0d ? 0.0d : (inter - intra) / inter);
            clusterDistributions.add((members.size() * 1.0d) / seriesByKey.size());
            sumSquares.add(members.stream().mapToDouble(distance -> Math.pow(distance.doubleValue(), 2.0d)).sum());
        }
        return Arrays.asList(intraClusterDistances, interClusterDistances, silhouetteCoefficients, clusterDistributions, sumSquares);
    }
}
