/*
 * Decompiled with CFR 0.152.
 */
package org.elasticsearch.xpack.vectors.mapper;

import java.nio.ByteBuffer;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.InPlaceMergeSorter;
import org.elasticsearch.Version;

public final class VectorEncoderDecoder {
    static final byte INT_BYTES = 4;
    static final byte SHORT_BYTES = 2;

    private VectorEncoderDecoder() {
    }

    public static BytesRef encodeSparseVector(Version indexVersion, int[] dims, float[] values, int dimCount) {
        VectorEncoderDecoder.sortSparseDimsValues(dims, values, dimCount);
        byte[] bytes = indexVersion.onOrAfter(Version.V_7_5_0) ? new byte[dimCount * 6 + 4] : new byte[dimCount * 6];
        ByteBuffer byteBuffer = ByteBuffer.wrap(bytes);
        for (int dim = 0; dim < dimCount; ++dim) {
            int dimValue = dims[dim];
            byteBuffer.put((byte)(dimValue >> 8));
            byteBuffer.put((byte)dimValue);
        }
        double dotProduct = 0.0;
        for (int dim = 0; dim < dimCount; ++dim) {
            float value = values[dim];
            byteBuffer.putFloat(value);
            dotProduct += (double)(value * value);
        }
        if (indexVersion.onOrAfter(Version.V_7_5_0)) {
            float vectorMagnitude = (float)Math.sqrt(dotProduct);
            byteBuffer.putFloat(vectorMagnitude);
        }
        return new BytesRef(bytes);
    }

    public static int[] decodeSparseVectorDims(Version indexVersion, BytesRef vectorBR) {
        int dimCount = indexVersion.onOrAfter(Version.V_7_5_0) ? (vectorBR.length - 4) / 6 : vectorBR.length / 6;
        ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, dimCount * 2);
        int[] dims = new int[dimCount];
        for (int dim = 0; dim < dimCount; ++dim) {
            dims[dim] = (byteBuffer.get() & 0xFF) << 8 | byteBuffer.get() & 0xFF;
        }
        return dims;
    }

    public static float[] decodeSparseVector(Version indexVersion, BytesRef vectorBR) {
        int dimCount = indexVersion.onOrAfter(Version.V_7_5_0) ? (vectorBR.length - 4) / 6 : vectorBR.length / 6;
        int offset = vectorBR.offset + 2 * dimCount;
        float[] vector = new float[dimCount];
        ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, offset, dimCount * 4);
        for (int dim = 0; dim < dimCount; ++dim) {
            vector[dim] = byteBuffer.getFloat();
        }
        return vector;
    }

    public static void sortSparseDimsValues(final int[] dims, final float[] values, int n) {
        new InPlaceMergeSorter(){

            public int compare(int i, int j) {
                return Integer.compare(dims[i], dims[j]);
            }

            public void swap(int i, int j) {
                int tempDim = dims[i];
                dims[i] = dims[j];
                dims[j] = tempDim;
                float tempValue = values[j];
                values[j] = values[i];
                values[i] = tempValue;
            }
        }.sort(0, n);
    }

    public static void sortSparseDimsFloatValues(final int[] dims, final float[] values, int n) {
        new InPlaceMergeSorter(){

            public int compare(int i, int j) {
                return Integer.compare(dims[i], dims[j]);
            }

            public void swap(int i, int j) {
                int tempDim = dims[i];
                dims[i] = dims[j];
                dims[j] = tempDim;
                float tempValue = values[j];
                values[j] = values[i];
                values[i] = tempValue;
            }
        }.sort(0, n);
    }

    public static int denseVectorLength(Version indexVersion, BytesRef vectorBR) {
        return indexVersion.onOrAfter(Version.V_7_5_0) ? (vectorBR.length - 4) / 4 : vectorBR.length / 4;
    }

    public static float decodeMagnitude(Version indexVersion, BytesRef vectorBR) {
        assert (indexVersion.onOrAfter(Version.V_7_5_0));
        ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
        return byteBuffer.getFloat(vectorBR.offset + vectorBR.length - 4);
    }

    private static float calculateMagnitude(Version indexVersion, BytesRef vectorBR) {
        int length = VectorEncoderDecoder.denseVectorLength(indexVersion, vectorBR);
        ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
        double magnitude = 0.0;
        for (int i = 0; i < length; ++i) {
            float value = byteBuffer.getFloat();
            magnitude += (double)(value * value);
        }
        magnitude = Math.sqrt(magnitude);
        return (float)magnitude;
    }

    public static float getMagnitude(Version indexVersion, BytesRef vectorBR) {
        if (vectorBR == null) {
            throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
        }
        if (indexVersion.onOrAfter(Version.V_7_5_0)) {
            return VectorEncoderDecoder.decodeMagnitude(indexVersion, vectorBR);
        }
        return VectorEncoderDecoder.calculateMagnitude(indexVersion, vectorBR);
    }

    public static void decodeDenseVector(BytesRef vectorBR, float[] vector) {
        if (vectorBR == null) {
            throw new IllegalArgumentException("A document doesn't have a value for a vector field!");
        }
        ByteBuffer byteBuffer = ByteBuffer.wrap(vectorBR.bytes, vectorBR.offset, vectorBR.length);
        for (int dim = 0; dim < vector.length; ++dim) {
            vector[dim] = byteBuffer.getFloat();
        }
    }
}

