package org.apache.mahout.math.neighborhood;

import com.google.common.base.Preconditions;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.random.RandomProjector;
import org.apache.mahout.math.random.WeightedThing;

/* loaded from: input_file:org/apache/mahout/math/neighborhood/FastProjectionSearch.class */
/**
 * Approximate nearest-neighbor {@link UpdatableSearcher} based on random scalar projections.
 * <p>
 * The first time a vector is added, a random basis with one row per projection is drawn via
 * {@link RandomProjector#generateBasisNormal}. For every basis row the searcher keeps a list of
 * {@code (vector, projection)} pairs sorted ascending by projection value. A query is projected the same
 * way. In each list only the {@code searchSize} entries on either side of the query's position are
 * examined. Exact distances are then computed for those candidates and for every vector not yet indexed.
 * <p>
 * Updates are batched. Additions are buffered, and a removal only leaves a null-valued placeholder in each
 * sorted list. The lists are rebuilt once buffered additions exceed {@code ADDITION_THRESHOLD} of the
 * indexed size or pending removals exceed {@code REMOVAL_THRESHOLD} of it. They are also always rebuilt
 * before iteration. Because of this, {@code search}/{@code searchFirst} may mutate internal state.
 * <p>
 * This class performs no synchronization.
 */
public class FastProjectionSearch extends UpdatableSearcher {
    /** Reindex once buffered additions exceed this fraction of the indexed vectors. */
    private static final double ADDITION_THRESHOLD = 0.05d;
    /** Reindex once pending removals exceed this fraction of the indexed vectors. */
    private static final double REMOVAL_THRESHOLD = 0.02d;

    /** Vectors added since the last reindex; they are searched exhaustively until indexed. */
    private final List<Vector> pendingAdditions = Lists.newArrayList();
    /** Random basis, one row per projection; null until the first vector is added. */
    private Matrix basisMatrix;
    /**
     * One list per basis row, each sorted ascending by projection value. A removed vector stays in place
     * as a null-valued placeholder until the next reindex. The placeholder keeps the same weight, so the
     * list stays sorted.
     */
    private List<List<WeightedThing<Vector>>> scalarProjections;
    private final int numProjections;
    /** Number of entries examined on each side of the query's position in every sorted list. */
    private final int searchSize;
    /** Number of null-valued placeholders currently present in each projection list. */
    private int numPendingRemovals;

    /**
     * @param distanceMeasure measure used to rank candidates
     * @param numProjections  number of random projections; must satisfy {@code 0 < numProjections < 100}
     * @param searchSize      entries examined on each side of the query in every projection list
     * @throws IllegalArgumentException if {@code numProjections} is out of range
     */
    public FastProjectionSearch(DistanceMeasure distanceMeasure, int numProjections, int searchSize) {
        super(distanceMeasure);
        Preconditions.checkArgument(numProjections > 0 && numProjections < 100,
                "Unreasonable value for number of projections. Must be: 0 < numProjections < 100");
        this.numProjections = numProjections;
        this.searchSize = searchSize;
        this.scalarProjections = Lists.newArrayListWithCapacity(numProjections);
        for (int p = 0; p < numProjections; p++) {
            this.scalarProjections.add(Lists.<WeightedThing<Vector>>newArrayList());
        }
    }

    /** Draws the random basis for vectors of the given dimension, once. */
    private void initialize(int dimension) {
        if (basisMatrix == null) {
            basisMatrix = RandomProjector.generateBasisNormal(numProjections, dimension);
        }
    }

    /**
     * Buffers {@code vector} for insertion. It is searchable immediately and is moved into the sorted
     * projection lists at the next reindex.
     *
     * @throws IllegalArgumentException if the vector's size differs from the dimension fixed by the first
     *                                  vector ever added
     */
    @Override
    public void add(Vector vector) {
        initialize(vector.size());
        // Fail fast. Otherwise a mismatched vector would only fail inside basisMatrix.times() during a
        // later reindex. It would then stay pending and make every subsequent search fail.
        Preconditions.checkArgument(vector.size() == basisMatrix.numCols(),
                "Vector size %s does not match searcher dimension %s", vector.size(), basisMatrix.numCols());
        pendingAdditions.add(vector);
    }

    /** Returns the number of live vectors: indexed plus buffered, minus those marked removed. */
    @Override
    public int size() {
        return pendingAdditions.size() + scalarProjections.get(0).size() - numPendingRemovals;
    }

    /**
     * Returns up to {@code limit} approximate nearest neighbors of {@code query}, sorted by ascending
     * distance. Returns an empty list if no vector has ever been added.
     */
    @Override
    public List<WeightedThing<Vector>> search(Vector query, int limit) {
        if (basisMatrix == null) {
            return Lists.newArrayList();
        }
        reindex(false);
        Set<Vector> candidates = Sets.newHashSet();
        Vector projection = basisMatrix.times(query);
        for (int p = 0; p < numProjections; p++) {
            List<WeightedThing<Vector>> sorted = scalarProjections.get(p);
            int middle = insertionPoint(sorted, projection.get(p));
            int end = Math.min(sorted.size(), middle + searchSize + 1);
            for (int j = Math.max(0, middle - searchSize); j < end; j++) {
                Vector candidate = sorted.get(j).getValue();
                if (candidate != null) {
                    candidates.add(candidate);
                }
            }
        }
        List<WeightedThing<Vector>> results =
                Lists.newArrayListWithCapacity(candidates.size() + pendingAdditions.size());
        for (Vector candidate : Iterables.concat(candidates, pendingAdditions)) {
            results.add(new WeightedThing<Vector>(candidate, distanceMeasure.distance(candidate, query)));
        }
        Collections.sort(results);
        return results.subList(0, Math.min(results.size(), limit));
    }

    /**
     * Returns the closest candidate to {@code query} and its distance.
     * <p>
     * If {@code differentThanQuery} is true, candidates equal to the query are skipped. If there is no
     * candidate, the result has a null value and weight {@code +Infinity}.
     */
    @Override
    public WeightedThing<Vector> searchFirst(Vector query, boolean differentThanQuery) {
        if (basisMatrix == null) {
            return new WeightedThing<Vector>(null, Double.POSITIVE_INFINITY);
        }
        reindex(false);
        double bestDistance = Double.POSITIVE_INFINITY;
        Vector best = null;
        Vector projection = basisMatrix.times(query);
        for (int p = 0; p < numProjections; p++) {
            List<WeightedThing<Vector>> sorted = scalarProjections.get(p);
            int middle = insertionPoint(sorted, projection.get(p));
            int end = Math.min(sorted.size(), middle + searchSize + 1);
            for (int j = Math.max(0, middle - searchSize); j < end; j++) {
                Vector candidate = sorted.get(j).getValue();
                if (candidate == null) {
                    continue;
                }
                double distance = distanceMeasure.distance(candidate, query);
                if (distance < bestDistance && (!differentThanQuery || !candidate.equals(query))) {
                    bestDistance = distance;
                    best = candidate;
                }
            }
        }
        for (Vector candidate : pendingAdditions) {
            double distance = distanceMeasure.distance(candidate, query);
            if (distance < bestDistance && (!differentThanQuery || !candidate.equals(query))) {
                bestDistance = distance;
                best = candidate;
            }
        }
        return new WeightedThing<Vector>(best, bestDistance);
    }

    /**
     * Removes the vector closest to {@code vector}, provided its distance is at most {@code epsilon}.
     *
     * @return true if a vector was removed; false if the searcher is empty, the closest match is farther
     *         than {@code epsilon}, or the match could not be located
     */
    @Override
    public boolean remove(Vector vector, double epsilon) {
        WeightedThing<Vector> closest = searchFirst(vector, false);
        Vector target = closest.getValue();
        if (target == null || closest.getWeight() > epsilon) {
            return false;
        }
        // Project the matched vector, not the query. With a non-zero epsilon the two differ, and only the
        // matched vector's projections are stored in the sorted lists.
        Vector projection = basisMatrix.times(target);
        int[] positions = new int[numProjections];
        for (int p = 0; p < numProjections; p++) {
            positions[p] = findEntry(scalarProjections.get(p), projection.get(p), target);
            if (positions[p] < 0) {
                return removePending(target);
            }
        }
        // Mark entries only once every list is known to contain the target, so a miss part-way through
        // can never leave the lists half-updated.
        for (int p = 0; p < numProjections; p++) {
            scalarProjections.get(p).set(positions[p], new WeightedThing<Vector>(projection.get(p)));
        }
        numPendingRemovals++;
        return true;
    }

    /** Removes the first buffered addition equal to {@code target}; returns whether one was found. */
    private boolean removePending(Vector target) {
        for (Iterator<Vector> it = pendingAdditions.iterator(); it.hasNext();) {
            if (it.next().equals(target)) {
                it.remove();
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the index of an entry with the given projection, or the index where one would be inserted.
     */
    private static int insertionPoint(List<WeightedThing<Vector>> sorted, double projection) {
        int index = Collections.binarySearch(sorted, new WeightedThing<Vector>(projection));
        return index < 0 ? -(index + 1) : index;
    }

    /**
     * Returns the index of a live entry with the given projection whose value equals {@code target},
     * or -1 if there is none.
     */
    private static int findEntry(List<WeightedThing<Vector>> sorted, double projection, Vector target) {
        int hit = Collections.binarySearch(sorted, new WeightedThing<Vector>(projection));
        if (hit < 0) {
            return -1;
        }
        // binarySearch may land on any member of a run of equal weights. That member could be a different
        // vector with the same projection or an already-removed (null) placeholder, so scan the whole run.
        for (int j = hit; j >= 0 && Double.compare(sorted.get(j).getWeight(), projection) == 0; j--) {
            if (target.equals(sorted.get(j).getValue())) {
                return j;
            }
        }
        for (int j = hit + 1;
                j < sorted.size() && Double.compare(sorted.get(j).getWeight(), projection) == 0; j++) {
            if (target.equals(sorted.get(j).getValue())) {
                return j;
            }
        }
        return -1;
    }

    /**
     * Rebuilds the sorted projection lists, merging buffered additions and dropping removal placeholders.
     * The rebuild runs if {@code force} is set or either pending-update threshold is exceeded.
     */
    private void reindex(boolean force) {
        int numIndexed = scalarProjections.get(0).size();
        if (!force
                && pendingAdditions.size() <= ADDITION_THRESHOLD * numIndexed
                && numPendingRemovals <= REMOVAL_THRESHOLD * numIndexed) {
            return;
        }
        // Build fresh lists instead of mutating the current ones. An iterator handed out earlier, which
        // walks the old list 0, therefore never sees a modification. A failure part-way through also
        // leaves the current state intact.
        List<List<WeightedThing<Vector>>> rebuilt = Lists.newArrayListWithCapacity(numProjections);
        for (List<WeightedThing<Vector>> current : scalarProjections) {
            List<WeightedThing<Vector>> kept =
                    Lists.newArrayListWithCapacity(current.size() + pendingAdditions.size());
            for (WeightedThing<Vector> entry : current) {
                // Null-valued entries are placeholders left by remove(); drop them here.
                if (entry.getValue() != null) {
                    kept.add(entry);
                }
            }
            rebuilt.add(kept);
        }
        for (Vector pending : pendingAdditions) {
            Vector projection = basisMatrix.times(pending);
            for (int p = 0; p < numProjections; p++) {
                rebuilt.get(p).add(new WeightedThing<Vector>(pending, projection.get(p)));
            }
        }
        for (List<WeightedThing<Vector>> list : rebuilt) {
            Collections.sort(list);
        }
        pendingAdditions.clear();
        numPendingRemovals = 0;
        scalarProjections = rebuilt;
    }

    /** Removes every vector. The random basis, and therefore the dimension, is kept. */
    @Override
    public void clear() {
        pendingAdditions.clear();
        for (List<WeightedThing<Vector>> list : scalarProjections) {
            list.clear();
        }
        numPendingRemovals = 0;
    }

    /** Iterates over all live vectors after forcing a reindex. */
    @Override
    public Iterator<Vector> iterator() {
        reindex(true);
        final Iterator<WeightedThing<Vector>> entries = scalarProjections.get(0).iterator();
        return new AbstractIterator<Vector>() {
            @Override
            protected Vector computeNext() {
                while (entries.hasNext()) {
                    Vector value = entries.next().getValue();
                    if (value != null) {
                        return value;
                    }
                }
                return endOfData();
            }
        };
    }
}
