package org.apache.lucene.codecs.lucene99;

import java.io.IOException;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.ScalarQuantizedVectorScorer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
import org.apache.lucene.util.quantization.ScalarQuantizer;

/* loaded from: input_file:WEB-INF/lib/lucene-core-9.11.0.jar:org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.class */
/**
 * A {@link FlatVectorsScorer} that scores against scalar-quantized byte vectors.
 *
 * <p>When the supplied vector values are {@link RandomAccessQuantizedByteVectorValues}, float
 * queries are quantized with the values' {@link ScalarQuantizer} and scored with one of the nested
 * quantized scorers below. Any other vector values, and every {@code byte[]} query, are forwarded
 * unchanged to the non-quantized delegate passed to the constructor.
 */
public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {

    /** Scorer used for non-quantized vector values and for all {@code byte[]} queries. */
    private final FlatVectorsScorer nonQuantizedDelegate;

    /**
     * Creates a quantized scorer.
     *
     * @param flatVectorsScorer scorer to fall back to when the vector values are not quantized
     */
    public Lucene99ScalarQuantizedVectorScorer(FlatVectorsScorer flatVectorsScorer) {
        this.nonQuantizedDelegate = flatVectorsScorer;
    }

    /**
     * Returns a supplier of scorers that compare stored vectors against each other.
     *
     * @param vectorSimilarityFunction similarity used to score
     * @param randomAccessVectorValues vectors to score; quantized values get a quantized supplier,
     *     anything else is handed to the non-quantized delegate
     * @throws IOException if the underlying values cannot be copied or read
     */
    @Override
    public RandomVectorScorerSupplier getRandomVectorScorerSupplier(
            VectorSimilarityFunction vectorSimilarityFunction,
            RandomAccessVectorValues randomAccessVectorValues) throws IOException {
        if (randomAccessVectorValues instanceof RandomAccessQuantizedByteVectorValues) {
            return new ScalarQuantizedRandomVectorScorerSupplier(
                    (RandomAccessQuantizedByteVectorValues) randomAccessVectorValues, vectorSimilarityFunction);
        }
        return this.nonQuantizedDelegate.getRandomVectorScorerSupplier(
                vectorSimilarityFunction, randomAccessVectorValues);
    }

    /**
     * Returns a scorer comparing a float query against the stored vectors.
     *
     * <p>For quantized values the query is first quantized into a fresh byte array using the
     * values' {@link ScalarQuantizer}; the correction constant returned by the quantization step is
     * folded into every score.
     *
     * @throws IOException if the underlying values cannot be read
     */
    @Override
    public RandomVectorScorer getRandomVectorScorer(
            VectorSimilarityFunction vectorSimilarityFunction,
            RandomAccessVectorValues randomAccessVectorValues,
            float[] fArr) throws IOException {
        if (!(randomAccessVectorValues instanceof RandomAccessQuantizedByteVectorValues)) {
            return this.nonQuantizedDelegate.getRandomVectorScorer(
                    vectorSimilarityFunction, randomAccessVectorValues, fArr);
        }
        RandomAccessQuantizedByteVectorValues quantizedValues =
                (RandomAccessQuantizedByteVectorValues) randomAccessVectorValues;
        ScalarQuantizer quantizer = quantizedValues.getScalarQuantizer();
        byte[] quantizedQuery = new byte[fArr.length];
        float queryOffset =
                ScalarQuantizedVectorScorer.quantizeQuery(fArr, quantizedQuery, vectorSimilarityFunction, quantizer);
        return fromVectorSimilarity(
                quantizedQuery, queryOffset, vectorSimilarityFunction, quantizer.getConstantMultiplier(), quantizedValues);
    }

    /**
     * Byte queries are never quantized here; they are always forwarded to the non-quantized
     * delegate, regardless of the type of {@code randomAccessVectorValues}.
     */
    @Override
    public RandomVectorScorer getRandomVectorScorer(
            VectorSimilarityFunction vectorSimilarityFunction,
            RandomAccessVectorValues randomAccessVectorValues,
            byte[] bArr) throws IOException {
        return this.nonQuantizedDelegate.getRandomVectorScorer(vectorSimilarityFunction, randomAccessVectorValues, bArr);
    }

    @Override
    public String toString() {
        return "ScalarQuantizedVectorScorer(nonQuantizedDelegate=" + this.nonQuantizedDelegate + ")";
    }

    /**
     * Builds the quantized scorer matching {@code vectorSimilarityFunction}.
     *
     * @param bArr already-quantized query (or stored) vector to score against
     * @param f correction constant associated with {@code bArr}
     * @param vectorSimilarityFunction similarity to implement
     * @param f2 the quantizer's constant multiplier
     * @param randomAccessQuantizedByteVectorValues stored vectors to score
     * @throws IllegalArgumentException if the similarity function has no quantized scorer
     */
    static RandomVectorScorer fromVectorSimilarity(
            byte[] bArr,
            float f,
            VectorSimilarityFunction vectorSimilarityFunction,
            float f2,
            RandomAccessQuantizedByteVectorValues randomAccessQuantizedByteVectorValues) {
        switch (vectorSimilarityFunction) {
            case EUCLIDEAN:
                return new Euclidean(randomAccessQuantizedByteVectorValues, f2, bArr);
            case COSINE:
            case DOT_PRODUCT:
                // Affine map (1 + x) / 2, clamped at 0 so the resulting score is never negative.
                return dotProductFactory(
                        bArr, f, f2, randomAccessQuantizedByteVectorValues,
                        adjusted -> Math.max((1.0f + adjusted) / 2.0f, 0.0f));
            case MAXIMUM_INNER_PRODUCT:
                return dotProductFactory(
                        bArr, f, f2, randomAccessQuantizedByteVectorValues, VectorUtil::scaleMaxInnerProductScore);
            default:
                throw new IllegalArgumentException("Unsupported similarity function: " + vectorSimilarityFunction);
        }
    }

    /**
     * Chooses the dot-product scorer implementation for the quantizer's bit width and storage
     * layout.
     *
     * <ul>
     *   <li>more than 4 bits: {@link DotProduct};
     *   <li>at most 4 bits, with a stored byte length different from the dimension and a readable
     *       slice: {@link CompressedInt4DotProduct} (reads the stored bytes straight from the slice);
     *   <li>otherwise: {@link Int4DotProduct}.
     * </ul>
     */
    private static RandomVectorScorer.AbstractRandomVectorScorer dotProductFactory(
            byte[] targetBytes,
            float offsetCorrection,
            float constMultiplier,
            RandomAccessQuantizedByteVectorValues values,
            FloatToFloatFunction scoreAdjustmentFunction) {
        if (values.getScalarQuantizer().getBits() > 4) {
            return new DotProduct(values, constMultiplier, targetBytes, offsetCorrection, scoreAdjustmentFunction);
        }
        // getSlice() is only consulted when the byte length differs from the dimension.
        if (values.getVectorByteLength() != values.dimension() && values.getSlice() != null) {
            return new CompressedInt4DotProduct(
                    values, constMultiplier, targetBytes, offsetCorrection, scoreAdjustmentFunction);
        }
        return new Int4DotProduct(values, constMultiplier, targetBytes, offsetCorrection, scoreAdjustmentFunction);
    }

    /**
     * Int4 dot-product scorer that reads each stored vector directly from the values' slice and
     * scores it with {@link VectorUtil#int4DotProductPacked}.
     *
     * <p>NOTE(review): the stored bytes are presumably two 4-bit values per byte ("packed"), based
     * on the helper name and on {@code getVectorByteLength() != dimension()} at the factory; confirm
     * against the writer.
     */
    public static class CompressedInt4DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer {
        private final float constMultiplier;
        private final RandomAccessQuantizedByteVectorValues values;
        /** Reusable buffer holding the stored bytes of the ordinal currently being scored. */
        private final byte[] compressedVector;
        private final byte[] targetBytes;
        private final float offsetCorrection;
        private final FloatToFloatFunction scoreAdjustmentFunction;

        private CompressedInt4DotProduct(
                RandomAccessQuantizedByteVectorValues values,
                float constMultiplier,
                byte[] targetBytes,
                float offsetCorrection,
                FloatToFloatFunction scoreAdjustmentFunction) {
            super(values);
            this.constMultiplier = constMultiplier;
            this.values = values;
            this.compressedVector = new byte[values.getVectorByteLength()];
            this.targetBytes = targetBytes;
            this.offsetCorrection = offsetCorrection;
            this.scoreAdjustmentFunction = scoreAdjustmentFunction;
        }

        @Override
        public float score(int i) throws IOException {
            // Each ordinal occupies getVectorByteLength() bytes followed by a Float.BYTES correction
            // constant. Widen to long before multiplying: the int product overflows for large ordinals.
            long offset = (long) i * (this.values.getVectorByteLength() + Float.BYTES);
            this.values.getSlice().seek(offset);
            this.values.getSlice().readBytes(this.compressedVector, 0, this.compressedVector.length);
            float vectorOffset = this.values.getScoreCorrectionConstant(i);
            int dotProduct = VectorUtil.int4DotProductPacked(this.targetBytes, this.compressedVector);
            // Quantized components are expected to be non-negative, so the dot product must be too.
            assert dotProduct >= 0 : "quantized dot product must be non-negative, got " + dotProduct;
            return this.scoreAdjustmentFunction.apply(dotProduct * this.constMultiplier + this.offsetCorrection + vectorOffset);
        }
    }

    /** Dot-product scorer for quantizers wider than 4 bits, using {@link VectorUtil#dotProduct}. */
    public static class DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer {
        private final float constMultiplier;
        private final RandomAccessQuantizedByteVectorValues values;
        private final byte[] targetBytes;
        private final float offsetCorrection;
        private final FloatToFloatFunction scoreAdjustmentFunction;

        public DotProduct(
                RandomAccessQuantizedByteVectorValues values,
                float constMultiplier,
                byte[] targetBytes,
                float offsetCorrection,
                FloatToFloatFunction scoreAdjustmentFunction) {
            super(values);
            this.constMultiplier = constMultiplier;
            this.values = values;
            this.targetBytes = targetBytes;
            this.offsetCorrection = offsetCorrection;
            this.scoreAdjustmentFunction = scoreAdjustmentFunction;
        }

        @Override
        public float score(int i) throws IOException {
            byte[] storedVector = this.values.vectorValue(i);
            float vectorOffset = this.values.getScoreCorrectionConstant(i);
            int dotProduct = VectorUtil.dotProduct(storedVector, this.targetBytes);
            // Quantized components are expected to be non-negative, so the dot product must be too.
            assert dotProduct >= 0 : "quantized dot product must be non-negative, got " + dotProduct;
            return this.scoreAdjustmentFunction.apply(dotProduct * this.constMultiplier + this.offsetCorrection + vectorOffset);
        }
    }

    /**
     * Euclidean scorer: {@code 1 / (1 + squareDistance * constMultiplier)}, so identical vectors
     * score 1 and the score decreases as the scaled squared distance grows.
     */
    public static class Euclidean extends RandomVectorScorer.AbstractRandomVectorScorer {
        private final float constMultiplier;
        private final byte[] targetBytes;
        private final RandomAccessQuantizedByteVectorValues values;

        private Euclidean(RandomAccessQuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes) {
            super(values);
            this.values = values;
            this.constMultiplier = constMultiplier;
            this.targetBytes = targetBytes;
        }

        @Override
        public float score(int i) throws IOException {
            int squareDistance = VectorUtil.squareDistance(this.values.vectorValue(i), this.targetBytes);
            return 1.0f / (1.0f + (squareDistance * this.constMultiplier));
        }
    }

    /** Primitive {@code float -> float} function, used to map a raw adjusted score to the final one. */
    @FunctionalInterface
    public interface FloatToFloatFunction {
        float apply(float f);
    }

    /**
     * Int4 dot-product scorer that reads stored vectors through
     * {@link RandomAccessQuantizedByteVectorValues#vectorValue} and scores with
     * {@link VectorUtil#int4DotProduct}.
     */
    public static class Int4DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer {
        private final float constMultiplier;
        private final RandomAccessQuantizedByteVectorValues values;
        private final byte[] targetBytes;
        private final float offsetCorrection;
        private final FloatToFloatFunction scoreAdjustmentFunction;

        public Int4DotProduct(
                RandomAccessQuantizedByteVectorValues values,
                float constMultiplier,
                byte[] targetBytes,
                float offsetCorrection,
                FloatToFloatFunction scoreAdjustmentFunction) {
            super(values);
            this.constMultiplier = constMultiplier;
            this.values = values;
            this.targetBytes = targetBytes;
            this.offsetCorrection = offsetCorrection;
            this.scoreAdjustmentFunction = scoreAdjustmentFunction;
        }

        @Override
        public float score(int i) throws IOException {
            byte[] storedVector = this.values.vectorValue(i);
            float vectorOffset = this.values.getScoreCorrectionConstant(i);
            int dotProduct = VectorUtil.int4DotProduct(storedVector, this.targetBytes);
            // Quantized components are expected to be non-negative, so the dot product must be too.
            assert dotProduct >= 0 : "quantized dot product must be non-negative, got " + dotProduct;
            return this.scoreAdjustmentFunction.apply(dotProduct * this.constMultiplier + this.offsetCorrection + vectorOffset);
        }
    }

    /**
     * Supplier of scorers that compare one stored vector (the "query" ordinal) against the others.
     *
     * <p>Three independent views of the values are kept: {@code values} (source for the quantizer
     * and for {@link #copy()}), {@code queryValues} (reads the query ordinal) and
     * {@code targetValues} (read by the returned scorers).
     */
    private static final class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
        private final VectorSimilarityFunction vectorSimilarityFunction;
        private final RandomAccessQuantizedByteVectorValues values;
        private final RandomAccessQuantizedByteVectorValues queryValues;
        private final RandomAccessQuantizedByteVectorValues targetValues;

        public ScalarQuantizedRandomVectorScorerSupplier(
                RandomAccessQuantizedByteVectorValues values,
                VectorSimilarityFunction vectorSimilarityFunction) throws IOException {
            this.values = values;
            this.queryValues = values.copy();
            this.targetValues = values.copy();
            this.vectorSimilarityFunction = vectorSimilarityFunction;
        }

        /**
         * Returns a scorer comparing ordinal {@code i} against the stored vectors.
         *
         * <p>NOTE(review): the query bytes come straight from {@code queryValues.vectorValue(i)} and
         * are not copied; if that method reuses an internal buffer, a later {@code scorer(...)} call
         * would change the query of earlier scorers. Confirm callers use one scorer at a time.
         */
        @Override
        public RandomVectorScorer scorer(int i) throws IOException {
            byte[] queryVector = this.queryValues.vectorValue(i);
            float queryOffset = this.queryValues.getScoreCorrectionConstant(i);
            return fromVectorSimilarity(
                    queryVector,
                    queryOffset,
                    this.vectorSimilarityFunction,
                    this.values.getScalarQuantizer().getConstantMultiplier(),
                    this.targetValues);
        }

        @Override
        public ScalarQuantizedRandomVectorScorerSupplier copy() throws IOException {
            return new ScalarQuantizedRandomVectorScorerSupplier(this.values.copy(), this.vectorSimilarityFunction);
        }
    }
}
