package com.ibm.research.time_series.ml.clustering.k_means;

import com.ibm.research.time_series.core.constants.Padding;
import com.ibm.research.time_series.core.timeseries.MultiTimeSeries;
import com.ibm.research.time_series.core.timeseries.TimeSeries;
import com.ibm.research.time_series.core.transform.UnaryReducer;
import com.ibm.research.time_series.core.utils.ObservationCollection;
import com.ibm.research.time_series.core.utils.Pair;
import com.ibm.research.time_series.core.utils.Segment;
import com.ibm.research.time_series.ml.clustering.k_means.functions.DistanceComputer;
import com.ibm.research.time_series.ml.clustering.k_means.functions.WeightedSumFunction;
import com.ibm.research.time_series.transforms.reducers.math.MathReducers;
import java.lang.invoke.SerializedLambda;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/* loaded from: input_file:com/ibm/research/time_series/ml/clustering/k_means/KMeansComputer.class */
/**
 * Base class for k-means style clustering of time series.
 *
 * <p>Subclasses supply {@link #compute}, inherited from {@link DistanceComputer}; this class
 * uses the value it returns as the score by which a series is matched to a centroid (lowest
 * score wins) and by which centroid movement between iterations is measured.
 *
 * @param <V> observation value type of the series being clustered
 */
abstract class KMeansComputer<V> implements DistanceComputer<V> {

    /**
     * Builds a reducer that assigns a whole series to the centroid with the lowest average
     * {@code compute} score.
     *
     * <p>For every centroid the series is cut into windows whose length (and step) equals the
     * centroid's time span, {@code last tick - first tick + 1}; each window is scored against the
     * centroid with {@code compute} and the scores are averaged.
     *
     * @param map current centroids keyed by cluster id
     * @return a reducer producing {@code (winning cluster id, the reduced segment)}; the reducer
     *         fails with {@code NoSuchElementException} (from {@code Optional.get()}) when
     *         {@code map} is empty
     */
    private UnaryReducer<V, Pair<Integer, ObservationCollection<V>>> minDistanceCluster(final Map<Integer, ObservationCollection<V>> map) {
        return new UnaryReducer<V, Pair<Integer, ObservationCollection<V>>>() {
            @Override
            public Pair<Integer, ObservationCollection<V>> reduceSegment(Segment<V> segment) {
                final TimeSeries<V> series = segment.toTimeSeriesStream();
                Integer nearestCluster = map.entrySet().stream()
                        .map(entry -> {
                            final ObservationCollection<V> centroid = entry.getValue();
                            // Window length and step both equal the centroid's span, so the
                            // series is scored in consecutive centroid-sized pieces.
                            final long span = centroid.last().getTimeTick() - centroid.first().getTimeTick() + 1;
                            Double meanScore = series
                                    .segmentByTime(span, span, Padding.NONE)
                                    .map(window -> KMeansComputer.this.compute(window, centroid))
                                    .reduce(MathReducers.average());
                            return new Pair<Integer, Double>(entry.getKey(), meanScore);
                        })
                        .min(Comparator.comparing((Pair<Integer, Double> scored) -> scored.right))
                        .get()
                        .left;
                return new Pair<>(nearestCluster, segment);
            }
            // NOTE: no hand-written $deserializeLambda$ here. That method is synthesized by javac
            // for serializable lambdas; a decompiled copy is not valid source.
        };
    }

    /**
     * Runs the iterative centroid refinement.
     *
     * <p>Each iteration assigns every series to its best-scoring centroid, folds the series of
     * each cluster into a new centroid one at a time via
     * {@code KMeansUtils.computeNewCentroid} (running weight = number of series folded so far,
     * incoming weight = 1), and then sums {@code compute(previous centroid, new centroid)} over
     * the clusters that received at least one series. Iteration stops early once that sum is
     * {@code <= d}.
     *
     * <p>A cluster that receives no series in an iteration is not carried over to the next
     * centroid map, so the returned list may be shorter than {@code list}. The order of the
     * returned list follows the iteration order of the internal map and is not tied to the order
     * of {@code list}.
     *
     * @param multiTimeSeries     the series to cluster
     * @param list                initial centroids; position in the list is used as cluster id
     * @param weightedSumFunction combiner handed to {@code KMeansUtils.computeNewCentroid}
     * @param i                   maximum number of iterations; with {@code i <= 0} the initial
     *                            centroids are returned unchanged
     * @param d                   convergence threshold on the summed centroid movement
     * @return the centroids after the last executed iteration
     */
    public List<ObservationCollection<V>> performTrain(MultiTimeSeries<?, V> multiTimeSeries, List<ObservationCollection<V>> list, WeightedSumFunction<V> weightedSumFunction, int i, double d) {
        Map<Integer, ObservationCollection<V>> centroids = IntStream.range(0, list.size())
                .boxed()
                .collect(Collectors.toMap(index -> index, list::get));

        for (int iteration = 0; iteration < i; iteration++) {
            // Stable alias: 'centroids' is reassigned below, lambdas need an effectively final reference.
            final Map<Integer, ObservationCollection<V>> previous = centroids;

            // cluster id -> (running centroid, number of series folded into it so far)
            final Map<Integer, Pair<ObservationCollection<V>, Double>> accumulated = new HashMap<>();
            for (Pair<Integer, ObservationCollection<V>> assignment
                    : multiTimeSeries.reduceSeries(minDistanceCluster(previous)).values()) {
                Pair<ObservationCollection<V>, Double> running = accumulated.get(assignment.left);
                if (running == null) {
                    // First series seen for this cluster becomes the initial running centroid.
                    accumulated.put(assignment.left, new Pair<>(assignment.right, 1.0d));
                } else {
                    ObservationCollection<V> merged = KMeansUtils.computeNewCentroid(
                            running.left, assignment.right, running.right, 1.0d, weightedSumFunction);
                    accumulated.put(assignment.left, new Pair<>(merged, running.right + 1.0d));
                }
            }

            // Total movement of the centroids that were updated in this iteration.
            double shift = accumulated.entrySet().stream()
                    .mapToDouble(entry -> compute(previous.get(entry.getKey()), entry.getValue().left))
                    .sum();

            centroids = accumulated.entrySet().stream()
                    .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().left));

            if (shift <= d) {
                break;
            }
        }
        return new ArrayList<>(centroids.values());
    }
}
