package org.apache.commons.math4.neuralnet.sofm;

import org.apache.commons.math4.neuralnet.DistanceMeasure;
import org.apache.commons.math4.neuralnet.EuclideanDistance;
import org.apache.commons.math4.neuralnet.FeatureInitializer;
import org.apache.commons.math4.neuralnet.FeatureInitializerFactory;
import org.apache.commons.math4.neuralnet.MapRanking;
import org.apache.commons.math4.neuralnet.Network;
import org.apache.commons.math4.neuralnet.Neuron;
import org.apache.commons.math4.neuralnet.OffsetFeatureInitializer;
import org.apache.commons.math4.neuralnet.oned.NeuronString;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.simple.RandomSource;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests for {@link KohonenUpdateAction}.
 */
public class KohonenUpdateActionTest {
    private final UniformRandomProvider rng = RandomSource.SPLIT_MIX_64.create();

    /**
     * Performs a single update on a small one-dimensional network and checks that:
     * <ol>
     *  <li>with an initial learning rate of 1, the best matching neuron's features
     *      are mapped exactly onto the input's features;</li>
     *  <li>with an initial neighbourhood size (3) not smaller than the network size,
     *      every neuron's features get strictly closer to the input.</li>
     * </ol>
     */
    @Test
    public void testUpdate() {
        // Single source of truth for the network size: it drives both the
        // network construction and the per-neuron verification loop below.
        final int netSize = 3;

        final FeatureInitializer[] initArray = {
            new OffsetFeatureInitializer(FeatureInitializerFactory.uniform(rng, 0d, 0.1d))
        };
        final Network net = new NeuronString(netSize, false, initArray).getNetwork();
        final EuclideanDistance dist = new EuclideanDistance();
        final KohonenUpdateAction update =
            new KohonenUpdateAction(dist,
                                    LearningFactorFunctionFactory.exponentialDecay(1d, 0.1d, 100L),
                                    NeighbourhoodSizeFunctionFactory.exponentialDecay(3d, 1d, 100L));
        final MapRanking rank = new MapRanking(net, dist);

        final double[] features = {0.3d};
        final double[] distancesBefore = getDistances(net, dist, features);
        Assert.assertEquals(netSize, distancesBefore.length);

        final Neuron bestBefore = rank.rank(features, 1).get(0);
        // Sanity check: the winner must not already sit on the input, otherwise
        // the "distance becomes zero" check below would be vacuous.
        Assert.assertTrue(dist.applyAsDouble(bestBefore.getFeatures(), features) >= 0.2d);

        update.update(net, features);

        final double[] distancesAfter = getDistances(net, dist, features);
        Assert.assertEquals(netSize, distancesAfter.length);

        final Neuron bestAfter = rank.rank(features, 1).get(0);
        // The update must not change which neuron wins.
        Assert.assertEquals(bestBefore, bestAfter);
        // Learning rate 1 at the first call: the winner is moved onto the input.
        Assert.assertEquals(0d, dist.applyAsDouble(bestAfter.getFeatures(), features), 1e-16d);

        for (int i = 0; i < netSize; i++) {
            // Neighbourhood covers the whole network: every distance has decreased.
            Assert.assertTrue("neuron " + i + " did not move closer to the input",
                              distancesAfter[i] < distancesBefore[i]);
        }
    }

    /**
     * Computes the distance from each neuron of {@code net} to {@code features}.
     *
     * @param net Network whose neurons are measured.
     * @param dist Distance function.
     * @param features Input features to measure against.
     * @return the distances, ordered by ascending neuron identifier so that
     * successive calls on the same network can be compared index by index.
     */
    private static double[] getDistances(Network net, DistanceMeasure dist, double[] features) {
        return net.getNeurons()
            .stream()
            .sorted((a, b) -> Long.compare(a.getIdentifier(), b.getIdentifier()))
            .mapToDouble(n -> dist.applyAsDouble(n.getFeatures(), features))
            .toArray();
    }
}
