/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.benchmark.jmh;

import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.IOContext;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.MMapDirectory;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.UpdateableRandomVectorScorer;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.TearDown;
import org.openjdk.jmh.annotations.Warmup;

@BenchmarkMode(value={Mode.AverageTime})
@OutputTimeUnit(value=TimeUnit.MILLISECONDS)
@State(value=Scope.Benchmark)
@Warmup(iterations=4, time=1)
@Measurement(iterations=5, time=1)
@Fork(value=3, jvmArgsAppend={"-Xmx2g", "-Xms2g", "-XX:+AlwaysPreTouch", "--add-modules=jdk.incubator.vector"})
public class VectorScorerFloat32Benchmark {
    @Param(value={"1024"})
    public int size;
    @Param(value={"true", "false"})
    public boolean pollute = false;
    public int numVectors = 128000;
    public int numVectorsToScore = 20000;
    float[] scores;
    int[] indices;
    Path path;
    Directory dir;
    IndexInput in;
    KnnVectorValues values;
    UpdateableRandomVectorScorer defDotScorer;
    UpdateableRandomVectorScorer defCosScorer;
    UpdateableRandomVectorScorer defEucScorer;
    UpdateableRandomVectorScorer defMipScorer;
    UpdateableRandomVectorScorer optDotScorer;
    UpdateableRandomVectorScorer optCosScorer;
    UpdateableRandomVectorScorer optEucScorer;
    UpdateableRandomVectorScorer optMipScorer;

    @Setup(value=Level.Trial)
    public void setup() throws IOException {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        this.path = Files.createTempDirectory("VectorScorerFloat32Benchmark", new FileAttribute[0]);
        this.dir = new MMapDirectory(this.path);
        try (IndexOutput out = this.dir.createOutput("vector.data", IOContext.DEFAULT);){
            byte[] ba = new byte[this.size * 4];
            FloatBuffer buf = ByteBuffer.wrap(ba).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer();
            for (int v = 0; v < this.numVectors; ++v) {
                buf.put(0, VectorScorerFloat32Benchmark.randomVector(this.size, random));
                out.writeBytes(ba, 0, ba.length);
            }
        }
    }

    @Setup(value=Level.Iteration)
    public void perIterationInit() throws IOException {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        this.scores = new float[this.numVectorsToScore];
        this.in = this.dir.openInput("vector.data", IOContext.DEFAULT);
        int targetOrd = random.nextInt(this.numVectors);
        this.values = VectorScorerFloat32Benchmark.vectorValues(this.size, this.numVectors, this.in, VectorSimilarityFunction.DOT_PRODUCT);
        DefaultFlatVectorScorer def = DefaultFlatVectorScorer.INSTANCE;
        this.defDotScorer = def.getRandomVectorScorerSupplier(VectorSimilarityFunction.DOT_PRODUCT, this.values.copy()).scorer();
        this.defCosScorer = def.getRandomVectorScorerSupplier(VectorSimilarityFunction.COSINE, this.values.copy()).scorer();
        this.defEucScorer = def.getRandomVectorScorerSupplier(VectorSimilarityFunction.EUCLIDEAN, this.values.copy()).scorer();
        this.defMipScorer = def.getRandomVectorScorerSupplier(VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, this.values.copy()).scorer();
        this.defDotScorer.setScoringOrdinal(targetOrd);
        this.defCosScorer.setScoringOrdinal(targetOrd);
        this.defEucScorer.setScoringOrdinal(targetOrd);
        this.defMipScorer.setScoringOrdinal(targetOrd);
        FlatVectorsScorer opt = FlatVectorScorerUtil.getLucene99FlatVectorsScorer();
        this.optDotScorer = opt.getRandomVectorScorerSupplier(VectorSimilarityFunction.DOT_PRODUCT, this.values.copy()).scorer();
        this.optCosScorer = opt.getRandomVectorScorerSupplier(VectorSimilarityFunction.COSINE, this.values.copy()).scorer();
        this.optEucScorer = opt.getRandomVectorScorerSupplier(VectorSimilarityFunction.EUCLIDEAN, this.values.copy()).scorer();
        this.optMipScorer = opt.getRandomVectorScorerSupplier(VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT, this.values.copy()).scorer();
        this.optDotScorer.setScoringOrdinal(targetOrd);
        this.optCosScorer.setScoringOrdinal(targetOrd);
        this.optEucScorer.setScoringOrdinal(targetOrd);
        this.optMipScorer.setScoringOrdinal(targetOrd);
        List list = IntStream.range(0, this.numVectors).boxed().collect(Collectors.toList());
        Collections.shuffle(list, random);
        this.indices = list.stream().limit(this.numVectorsToScore).mapToInt(i -> i).toArray();
        if (this.pollute) {
            this.pollute(random);
        }
    }

    @TearDown
    public void teardown() throws IOException {
        IOUtils.close((Closeable[])new Closeable[]{this.in});
        this.dir.deleteFile("vector.data");
        IOUtils.close((Closeable[])new Closeable[]{this.dir});
        Files.delete(this.path);
    }

    public void pollute(Random random) throws IOException {
        float[] vec = VectorScorerFloat32Benchmark.randomVector(this.size, random);
        FlatVectorsScorer opt = FlatVectorScorerUtil.getLucene99FlatVectorsScorer();
        for (int i = 0; i < 2; ++i) {
            this.dotProductOptScorer();
            this.dotProductOptBulkScore();
            this.cosineOptScorer();
            this.cosineDefaultBulk();
            this.euclideanOptScorer();
            this.euclideanOptBulkScore();
            this.mipOptScorer();
            this.mipOptBulkScore();
            for (VectorSimilarityFunction sim : List.of(VectorSimilarityFunction.COSINE, VectorSimilarityFunction.DOT_PRODUCT, VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT)) {
                RandomVectorScorer scorer = opt.getRandomVectorScorer(sim, this.values.copy(), vec);
                for (int v = 0; v < this.numVectorsToScore; ++v) {
                    this.scores[v] = scorer.score(this.indices[v]);
                }
            }
        }
    }

    @Benchmark
    public float[] dotProductDefault() throws IOException {
        for (int v = 0; v < this.numVectorsToScore; ++v) {
            this.scores[v] = this.defDotScorer.score(this.indices[v]);
        }
        return this.scores;
    }

    @Benchmark
    public float[] dotProductDefaultBulk() throws IOException {
        this.defDotScorer.bulkScore(this.indices, this.scores, this.indices.length);
        return this.scores;
    }

    @Benchmark
    public float[] dotProductOptScorer() throws IOException {
        for (int v = 0; v < this.numVectorsToScore; ++v) {
            this.scores[v] = this.optDotScorer.score(this.indices[v]);
        }
        return this.scores;
    }

    @Benchmark
    public float[] dotProductOptBulkScore() throws IOException {
        this.optDotScorer.bulkScore(this.indices, this.scores, this.indices.length);
        return this.scores;
    }

    @Benchmark
    public float[] euclideanDefault() throws IOException {
        for (int v = 0; v < this.numVectorsToScore; ++v) {
            this.scores[v] = this.defEucScorer.score(this.indices[v]);
        }
        return this.scores;
    }

    @Benchmark
    public float[] euclideanDefaultBulk() throws IOException {
        this.defEucScorer.bulkScore(this.indices, this.scores, this.indices.length);
        return this.scores;
    }

    @Benchmark
    public float[] euclideanOptScorer() throws IOException {
        for (int v = 0; v < this.numVectorsToScore; ++v) {
            this.scores[v] = this.optEucScorer.score(this.indices[v]);
        }
        return this.scores;
    }

    @Benchmark
    public float[] euclideanOptBulkScore() throws IOException {
        this.optEucScorer.bulkScore(this.indices, this.scores, this.indices.length);
        return this.scores;
    }

    @Benchmark
    public float[] cosineDefault() throws IOException {
        for (int v = 0; v < this.numVectorsToScore; ++v) {
            this.scores[v] = this.defCosScorer.score(this.indices[v]);
        }
        return this.scores;
    }

    @Benchmark
    public float[] cosineDefaultBulk() throws IOException {
        this.defCosScorer.bulkScore(this.indices, this.scores, this.indices.length);
        return this.scores;
    }

    @Benchmark
    public float[] cosineOptScorer() throws IOException {
        for (int v = 0; v < this.numVectorsToScore; ++v) {
            this.scores[v] = this.optCosScorer.score(this.indices[v]);
        }
        return this.scores;
    }

    @Benchmark
    public float[] cosineOptBulkScore() throws IOException {
        this.optCosScorer.bulkScore(this.indices, this.scores, this.indices.length);
        return this.scores;
    }

    @Benchmark
    public float[] mipDefault() throws IOException {
        for (int v = 0; v < this.numVectorsToScore; ++v) {
            this.scores[v] = this.defMipScorer.score(this.indices[v]);
        }
        return this.scores;
    }

    @Benchmark
    public float[] mipDefaultBulk() throws IOException {
        this.defMipScorer.bulkScore(this.indices, this.scores, this.indices.length);
        return this.scores;
    }

    @Benchmark
    public float[] mipOptScorer() throws IOException {
        for (int v = 0; v < this.numVectorsToScore; ++v) {
            this.scores[v] = this.optMipScorer.score(this.indices[v]);
        }
        return this.scores;
    }

    @Benchmark
    public float[] mipOptBulkScore() throws IOException {
        this.optMipScorer.bulkScore(this.indices, this.scores, this.indices.length);
        return this.scores;
    }

    static float[] randomVector(int dims, Random random) {
        float[] fa = new float[dims];
        for (int i = 0; i < dims; ++i) {
            fa[i] = random.nextFloat();
        }
        return fa;
    }

    static KnnVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException {
        int byteSize = dims * 4;
        return new OffHeapFloatVectorValues.DenseOffHeapVectorValues(dims, size, in.slice("test", 0L, in.length()), byteSize, (FlatVectorsScorer)new ThrowingFlatVectorScorer(), sim);
    }

    static final class ThrowingFlatVectorScorer
    implements FlatVectorsScorer {
        ThrowingFlatVectorScorer() {
        }

        public RandomVectorScorerSupplier getRandomVectorScorerSupplier(VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) {
            throw new UnsupportedOperationException();
        }

        public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) {
            throw new UnsupportedOperationException();
        }

        public RandomVectorScorer getRandomVectorScorer(VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) {
            throw new UnsupportedOperationException();
        }
    }
}

