package ai.onnxruntime;

import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.TensorInfo;
import java.io.IOException;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.ShortBuffer;

/* loaded from: input_file:ai/onnxruntime/OnnxTensor.class */
/**
 * A Java-side handle to a native ONNX Runtime tensor.
 *
 * <p>Instances are created through the static {@code createTensor} factories. The tensor wraps a
 * native handle which is released by {@link #close()}; values are read back either as Java
 * objects via {@link #getValue()} or as copies of the raw data via the {@code get*Buffer}
 * methods.
 */
public class OnnxTensor implements OnnxValue {
    /** Bit mask selecting the sign bit of an IEEE 754 half-precision value. */
    private static final int FP16_SIGN_MASK = 0x8000;

    /** Bit mask selecting the 10 mantissa bits of an IEEE 754 half-precision value. */
    private static final int FP16_MANTISSA_MASK = 0x03FF;

    /** All-ones exponent field of a half-precision value (marks infinity / NaN). */
    private static final int FP16_EXPONENT_SPECIAL = 0x1F;

    /** Difference between the single-precision (127) and half-precision (15) exponent biases. */
    private static final int FP16_TO_FP32_BIAS_DELTA = 112;

    /** Exponent field of a single-precision infinity / NaN. */
    private static final int FP32_EXPONENT_SPECIAL = 0x7F800000;

    private final long nativeHandle;
    private final long allocatorHandle;
    private final TensorInfo info;

    /**
     * The buffer this tensor was created from, or {@code null}. It is never read by this class;
     * presumably it is retained so the backing memory stays reachable while the native tensor
     * refers to it.
     */
    private final Buffer buffer;

    /** Set once {@link #close()} has released the native handle, so it is released only once. */
    private boolean closed;

    /**
     * Wraps a native tensor which is not backed by a Java buffer.
     *
     * @param j the native tensor handle.
     * @param j2 the handle of the allocator associated with the tensor.
     * @param tensorInfo the type and shape information of the tensor.
     */
    OnnxTensor(long j, long j2, TensorInfo tensorInfo) {
        this(j, j2, tensorInfo, null);
    }

    /**
     * Wraps a native tensor, optionally retaining the buffer it was created from.
     *
     * @param j the native tensor handle.
     * @param j2 the handle of the allocator associated with the tensor.
     * @param tensorInfo the type and shape information of the tensor.
     * @param buffer the buffer supplied at creation time, may be {@code null}.
     */
    OnnxTensor(long j, long j2, TensorInfo tensorInfo, Buffer buffer) {
        this.nativeHandle = j;
        this.allocatorHandle = j2;
        this.info = tensorInfo;
        this.buffer = buffer;
    }

    /**
     * Returns the kind of ONNX value this object represents.
     *
     * @return always {@link OnnxValue.OnnxValueType#ONNX_TYPE_TENSOR}.
     */
    @Override // ai.onnxruntime.OnnxValue
    public OnnxValue.OnnxValueType getType() {
        return OnnxValue.OnnxValueType.ONNX_TYPE_TENSOR;
    }

    /**
     * Returns the native handle of this tensor.
     *
     * @return the native tensor handle.
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public long getNativeHandle() {
        return this.nativeHandle;
    }

    /**
     * Reads the tensor's contents into a Java object.
     *
     * <p>Scalar tensors are returned as the matching boxed primitive (or a {@link String});
     * any other tensor is copied into the carrier object produced by
     * {@code TensorInfo.makeCarrier()}.
     *
     * @return the tensor value as a Java object.
     * @throws OrtException if the native read fails or the tensor has an unsupported type.
     */
    @Override // ai.onnxruntime.OnnxValue
    public Object getValue() throws OrtException {
        if (!this.info.isScalar()) {
            Object carrier = this.info.makeCarrier();
            getArray(OnnxRuntime.ortApiHandle, this.nativeHandle, this.allocatorHandle, carrier);
            return carrier;
        }
        switch (this.info.type) {
            case FLOAT:
                return Float.valueOf(getFloat(OnnxRuntime.ortApiHandle, this.nativeHandle, this.info.onnxType.value));
            case DOUBLE:
                return Double.valueOf(getDouble(OnnxRuntime.ortApiHandle, this.nativeHandle));
            case INT8:
                return Byte.valueOf(getByte(OnnxRuntime.ortApiHandle, this.nativeHandle, this.info.onnxType.value));
            case INT16:
                return Short.valueOf(getShort(OnnxRuntime.ortApiHandle, this.nativeHandle, this.info.onnxType.value));
            case INT32:
                return Integer.valueOf(getInt(OnnxRuntime.ortApiHandle, this.nativeHandle, this.info.onnxType.value));
            case INT64:
                return Long.valueOf(getLong(OnnxRuntime.ortApiHandle, this.nativeHandle, this.info.onnxType.value));
            case BOOL:
                return Boolean.valueOf(getBool(OnnxRuntime.ortApiHandle, this.nativeHandle));
            case STRING:
                return getString(OnnxRuntime.ortApiHandle, this.nativeHandle, this.allocatorHandle);
            case UNKNOWN:
            default:
                throw new OrtException("Extracting the value of an invalid Tensor.");
        }
    }

    /**
     * Returns the type and shape information of this tensor.
     *
     * @return the tensor info supplied at construction.
     */
    @Override // ai.onnxruntime.OnnxValue
    public TensorInfo getInfo() {
        return this.info;
    }

    @Override
    public String toString() {
        return "OnnxTensor(info=" + this.info.toString() + ")";
    }

    /**
     * Releases the native tensor.
     *
     * <p>Only the first call releases the handle; later calls are no-ops so that closing a
     * tensor twice does not pass an already released handle back to native code.
     */
    @Override // ai.onnxruntime.OnnxValue, java.lang.AutoCloseable
    public void close() {
        if (this.closed) {
            return;
        }
        // Mark first so a failure inside the native call cannot lead to a second release attempt.
        this.closed = true;
        close(OnnxRuntime.ortApiHandle, this.nativeHandle);
    }

    /**
     * Copies the tensor's raw bytes into a new heap {@link ByteBuffer}.
     *
     * <p>The returned buffer is set to the platform's native byte order, matching the layout
     * of the copied bytes, so multi-byte reads on it decode the values correctly.
     *
     * @return a copy of the tensor data, or {@code null} if this is a string tensor.
     */
    public ByteBuffer getByteBuffer() {
        if (this.info.type == OnnxJavaType.STRING) {
            return null;
        }
        ByteBuffer nativeView = getBuffer();
        ByteBuffer copy = ByteBuffer.allocate(nativeView.capacity()).order(ByteOrder.nativeOrder());
        copy.put(nativeView);
        copy.rewind();
        return copy;
    }

    /**
     * Copies the tensor's data into a new heap {@link FloatBuffer}.
     *
     * <p>Tensors whose ONNX element type is FLOAT16 are widened to single precision element by
     * element.
     *
     * @return a copy of the tensor data, or {@code null} if the Java type is not FLOAT.
     */
    public FloatBuffer getFloatBuffer() {
        if (this.info.type != OnnxJavaType.FLOAT) {
            return null;
        }
        if (this.info.onnxType != TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
            FloatBuffer nativeView = getBuffer().asFloatBuffer();
            FloatBuffer copy = FloatBuffer.allocate(nativeView.capacity());
            copy.put(nativeView);
            copy.rewind();
            return copy;
        }
        // fp16 storage: read the raw 16-bit patterns and widen each one.
        ShortBuffer halfView = getBuffer().asShortBuffer();
        int elementCount = halfView.capacity();
        FloatBuffer widened = FloatBuffer.allocate(elementCount);
        for (int i = 0; i < elementCount; i++) {
            widened.put(fp16ToFloat(halfView.get(i)));
        }
        widened.rewind();
        return widened;
    }

    /**
     * Copies the tensor's data into a new heap {@link DoubleBuffer}.
     *
     * @return a copy of the tensor data, or {@code null} if the Java type is not DOUBLE.
     */
    public DoubleBuffer getDoubleBuffer() {
        if (this.info.type != OnnxJavaType.DOUBLE) {
            return null;
        }
        DoubleBuffer nativeView = getBuffer().asDoubleBuffer();
        DoubleBuffer copy = DoubleBuffer.allocate(nativeView.capacity());
        copy.put(nativeView);
        copy.rewind();
        return copy;
    }

    /**
     * Copies the tensor's data into a new heap {@link ShortBuffer}.
     *
     * @return a copy of the tensor data, or {@code null} if the Java type is not INT16.
     */
    public ShortBuffer getShortBuffer() {
        if (this.info.type != OnnxJavaType.INT16) {
            return null;
        }
        ShortBuffer nativeView = getBuffer().asShortBuffer();
        ShortBuffer copy = ShortBuffer.allocate(nativeView.capacity());
        copy.put(nativeView);
        copy.rewind();
        return copy;
    }

    /**
     * Copies the tensor's data into a new heap {@link IntBuffer}.
     *
     * @return a copy of the tensor data, or {@code null} if the Java type is not INT32.
     */
    public IntBuffer getIntBuffer() {
        if (this.info.type != OnnxJavaType.INT32) {
            return null;
        }
        IntBuffer nativeView = getBuffer().asIntBuffer();
        IntBuffer copy = IntBuffer.allocate(nativeView.capacity());
        copy.put(nativeView);
        copy.rewind();
        return copy;
    }

    /**
     * Copies the tensor's data into a new heap {@link LongBuffer}.
     *
     * @return a copy of the tensor data, or {@code null} if the Java type is not INT64.
     */
    public LongBuffer getLongBuffer() {
        if (this.info.type != OnnxJavaType.INT64) {
            return null;
        }
        LongBuffer nativeView = getBuffer().asLongBuffer();
        LongBuffer copy = LongBuffer.allocate(nativeView.capacity());
        copy.put(nativeView);
        copy.rewind();
        return copy;
    }

    /**
     * Fetches the buffer exposed by the native side and sets it to native byte order so typed
     * views of it decode values correctly.
     *
     * @return the native-order buffer over the tensor data.
     */
    private ByteBuffer getBuffer() {
        return getBuffer(OnnxRuntime.ortApiHandle, this.nativeHandle).order(ByteOrder.nativeOrder());
    }

    private native ByteBuffer getBuffer(long j, long j2);

    private native float getFloat(long j, long j2, int i) throws OrtException;

    private native double getDouble(long j, long j2) throws OrtException;

    private native byte getByte(long j, long j2, int i) throws OrtException;

    private native short getShort(long j, long j2, int i) throws OrtException;

    private native int getInt(long j, long j2, int i) throws OrtException;

    private native long getLong(long j, long j2, int i) throws OrtException;

    private native String getString(long j, long j2, long j3) throws OrtException;

    private native boolean getBool(long j, long j2) throws OrtException;

    private native void getArray(long j, long j2, long j3, Object obj) throws OrtException;

    private native void close(long j, long j2);

    /**
     * Converts an IEEE 754 half-precision bit pattern to a {@code float}.
     *
     * <p>Handles every class of half-precision value: signed zero, subnormals, normal numbers,
     * infinities and NaNs (the NaN payload is preserved in the upper mantissa bits). Every
     * half-precision value is exactly representable in single precision, so no rounding occurs.
     *
     * @param s the 16 raw bits of the half-precision value.
     * @return the equivalent single-precision value.
     */
    private static float fp16ToFloat(short s) {
        int bits = s & 0xFFFF; // drop the sign extension introduced by short -> int promotion
        int sign = (bits & FP16_SIGN_MASK) << 16;
        int exponent = (bits >>> 10) & FP16_EXPONENT_SPECIAL;
        int mantissa = bits & FP16_MANTISSA_MASK;

        if (exponent == FP16_EXPONENT_SPECIAL) {
            // Infinity (mantissa == 0) or NaN (mantissa != 0).
            return Float.intBitsToFloat(sign | FP32_EXPONENT_SPECIAL | (mantissa << 13));
        }
        if (exponent == 0) {
            if (mantissa == 0) {
                // Signed zero.
                return Float.intBitsToFloat(sign);
            }
            // Subnormal half: shift the mantissa left until its leading one reaches the implicit
            // bit position (bit 10), then drop that bit and lower the exponent to compensate.
            int shift = Integer.numberOfLeadingZeros(mantissa) - 21;
            mantissa = (mantissa << shift) & FP16_MANTISSA_MASK;
            exponent = 1 - shift;
        }
        return Float.intBitsToFloat(sign | ((exponent + FP16_TO_FP32_BIAS_DELTA) << 23) | (mantissa << 13));
    }

    /**
     * Creates a tensor from a Java object using the environment's default allocator.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param obj the data; its shape and type are derived by
     *     {@code TensorInfo.constructFromJavaArray}.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, Object obj) throws OrtException {
        return createTensor(ortEnvironment, ortEnvironment.defaultAllocator, obj);
    }

    /**
     * Creates a tensor from a Java object using the supplied allocator.
     *
     * <p>String data is routed to the string-specific native constructors (a zero-dimensional
     * shape is treated as a single {@link String}); a zero-dimensional non-string value is first
     * converted to an array via {@code OrtUtil.convertBoxedPrimitiveToArray}.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param ortAllocator the allocator to use.
     * @param obj the data.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     * @throws IllegalStateException if the environment or allocator is closed.
     */
    static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, Object obj) throws OrtException {
        if (ortEnvironment.isClosed() || ortAllocator.isClosed()) {
            throw new IllegalStateException("Trying to create an OnnxTensor with a closed OrtAllocator.");
        }
        TensorInfo tensorInfo = TensorInfo.constructFromJavaArray(obj);
        boolean zeroDimensional = tensorInfo.shape.length == 0;
        if (tensorInfo.type == OnnxJavaType.STRING) {
            long stringHandle;
            if (zeroDimensional) {
                stringHandle = createString(OnnxRuntime.ortApiHandle, ortAllocator.handle, (String) obj);
            } else {
                stringHandle = createStringTensor(OnnxRuntime.ortApiHandle, ortAllocator.handle, OrtUtil.flattenString(obj), tensorInfo.shape);
            }
            return new OnnxTensor(stringHandle, ortAllocator.handle, tensorInfo);
        }
        Object data = zeroDimensional ? OrtUtil.convertBoxedPrimitiveToArray(obj) : obj;
        long handle = createTensor(OnnxRuntime.ortApiHandle, ortAllocator.handle, data, tensorInfo.shape, tensorInfo.onnxType.value);
        return new OnnxTensor(handle, ortAllocator.handle, tensorInfo);
    }

    /**
     * Creates a string tensor of the given shape using the environment's default allocator.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param strArr the flattened string data.
     * @param jArr the shape of the tensor.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, String[] strArr, long[] jArr) throws OrtException {
        return createTensor(ortEnvironment, ortEnvironment.defaultAllocator, strArr, jArr);
    }

    /**
     * Creates a string tensor of the given shape using the supplied allocator.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param ortAllocator the allocator to use.
     * @param strArr the flattened string data.
     * @param jArr the shape of the tensor.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     * @throws IllegalStateException if the environment or allocator is closed.
     */
    static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, String[] strArr, long[] jArr) throws OrtException {
        if (ortEnvironment.isClosed() || ortAllocator.isClosed()) {
            throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
        }
        TensorInfo tensorInfo = new TensorInfo(jArr, OnnxJavaType.STRING, TensorInfo.OnnxTensorType.ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING);
        long handle = createStringTensor(OnnxRuntime.ortApiHandle, ortAllocator.handle, strArr, jArr);
        return new OnnxTensor(handle, ortAllocator.handle, tensorInfo);
    }

    /**
     * Creates a float tensor from a buffer using the environment's default allocator.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param floatBuffer the tensor data.
     * @param jArr the shape of the tensor.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, FloatBuffer floatBuffer, long[] jArr) throws OrtException {
        return createTensor(ortEnvironment, ortEnvironment.defaultAllocator, floatBuffer, jArr);
    }

    /**
     * Creates a float tensor from a buffer using the supplied allocator.
     *
     * <p>A direct buffer is handed to native code as-is. A non-direct buffer is copied into a
     * freshly allocated native-order direct buffer sized from the source's capacity; the copy
     * reads the source's remaining elements and therefore advances its position.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param ortAllocator the allocator to use.
     * @param floatBuffer the tensor data.
     * @param jArr the shape of the tensor.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     * @throws IllegalStateException if the environment or allocator is closed.
     */
    static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, FloatBuffer floatBuffer, long[] jArr) throws OrtException {
        if (ortEnvironment.isClosed() || ortAllocator.isClosed()) {
            throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
        }
        OnnxJavaType javaType = OnnxJavaType.FLOAT;
        int sizeInBytes = floatBuffer.capacity() * javaType.size;
        FloatBuffer directData;
        if (floatBuffer.isDirect()) {
            directData = floatBuffer;
        } else {
            directData = ByteBuffer.allocateDirect(sizeInBytes).order(ByteOrder.nativeOrder()).asFloatBuffer();
            directData.put(floatBuffer);
        }
        TensorInfo tensorInfo = TensorInfo.constructFromBuffer(directData, jArr, javaType);
        long handle = createTensorFromBuffer(OnnxRuntime.ortApiHandle, ortAllocator.handle, directData, sizeInBytes, jArr, tensorInfo.onnxType.value);
        return new OnnxTensor(handle, ortAllocator.handle, tensorInfo, directData);
    }

    /**
     * Creates a double tensor from a buffer using the environment's default allocator.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param doubleBuffer the tensor data.
     * @param jArr the shape of the tensor.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, DoubleBuffer doubleBuffer, long[] jArr) throws OrtException {
        return createTensor(ortEnvironment, ortEnvironment.defaultAllocator, doubleBuffer, jArr);
    }

    /**
     * Creates a double tensor from a buffer using the supplied allocator.
     *
     * <p>A direct buffer is handed to native code as-is. A non-direct buffer is copied into a
     * freshly allocated native-order direct buffer sized from the source's capacity; the copy
     * reads the source's remaining elements and therefore advances its position.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param ortAllocator the allocator to use.
     * @param doubleBuffer the tensor data.
     * @param jArr the shape of the tensor.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     * @throws IllegalStateException if the environment or allocator is closed.
     */
    static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, DoubleBuffer doubleBuffer, long[] jArr) throws OrtException {
        if (ortEnvironment.isClosed() || ortAllocator.isClosed()) {
            throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
        }
        OnnxJavaType javaType = OnnxJavaType.DOUBLE;
        int sizeInBytes = doubleBuffer.capacity() * javaType.size;
        DoubleBuffer directData;
        if (doubleBuffer.isDirect()) {
            directData = doubleBuffer;
        } else {
            directData = ByteBuffer.allocateDirect(sizeInBytes).order(ByteOrder.nativeOrder()).asDoubleBuffer();
            directData.put(doubleBuffer);
        }
        TensorInfo tensorInfo = TensorInfo.constructFromBuffer(directData, jArr, javaType);
        long handle = createTensorFromBuffer(OnnxRuntime.ortApiHandle, ortAllocator.handle, directData, sizeInBytes, jArr, tensorInfo.onnxType.value);
        return new OnnxTensor(handle, ortAllocator.handle, tensorInfo, directData);
    }

    /**
     * Creates an INT8 tensor from a byte buffer using the environment's default allocator.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param byteBuffer the tensor data.
     * @param jArr the shape of the tensor.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, ByteBuffer byteBuffer, long[] jArr) throws OrtException {
        return createTensor(ortEnvironment, ortEnvironment.defaultAllocator, byteBuffer, jArr);
    }

    /**
     * Creates an INT8 tensor from a byte buffer using the supplied allocator.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param ortAllocator the allocator to use.
     * @param byteBuffer the tensor data.
     * @param jArr the shape of the tensor.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     */
    static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, ByteBuffer byteBuffer, long[] jArr) throws OrtException {
        return createTensor(ortEnvironment, ortAllocator, byteBuffer, jArr, OnnxJavaType.INT8);
    }

    /**
     * Creates a tensor of the requested element type from raw bytes using the environment's
     * default allocator.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param byteBuffer the raw tensor data.
     * @param jArr the shape of the tensor.
     * @param onnxJavaType the element type the bytes should be interpreted as.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, ByteBuffer byteBuffer, long[] jArr, OnnxJavaType onnxJavaType) throws OrtException {
        return createTensor(ortEnvironment, ortEnvironment.defaultAllocator, byteBuffer, jArr, onnxJavaType);
    }

    /**
     * Creates a tensor of the requested element type from raw bytes using the supplied allocator.
     *
     * <p>A direct buffer is handed to native code as-is. A non-direct buffer is copied into a
     * freshly allocated native-order direct buffer sized from the source's capacity; the copy
     * reads the source's remaining bytes and therefore advances its position.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param ortAllocator the allocator to use.
     * @param byteBuffer the raw tensor data.
     * @param jArr the shape of the tensor.
     * @param onnxJavaType the element type the bytes should be interpreted as.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     * @throws IllegalStateException if the environment or allocator is closed.
     */
    static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, ByteBuffer byteBuffer, long[] jArr, OnnxJavaType onnxJavaType) throws OrtException {
        if (ortEnvironment.isClosed() || ortAllocator.isClosed()) {
            throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
        }
        int sizeInBytes = byteBuffer.capacity();
        ByteBuffer directData;
        if (byteBuffer.isDirect()) {
            directData = byteBuffer;
        } else {
            directData = ByteBuffer.allocateDirect(sizeInBytes).order(ByteOrder.nativeOrder());
            directData.put(byteBuffer);
        }
        TensorInfo tensorInfo = TensorInfo.constructFromBuffer(directData, jArr, onnxJavaType);
        long handle = createTensorFromBuffer(OnnxRuntime.ortApiHandle, ortAllocator.handle, directData, sizeInBytes, jArr, tensorInfo.onnxType.value);
        return new OnnxTensor(handle, ortAllocator.handle, tensorInfo, directData);
    }

    /**
     * Creates an INT16 tensor from a buffer using the environment's default allocator.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param shortBuffer the tensor data.
     * @param jArr the shape of the tensor.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, ShortBuffer shortBuffer, long[] jArr) throws OrtException {
        return createTensor(ortEnvironment, ortEnvironment.defaultAllocator, shortBuffer, jArr);
    }

    /**
     * Creates an INT16 tensor from a buffer using the supplied allocator.
     *
     * <p>A direct buffer is handed to native code as-is. A non-direct buffer is copied into a
     * freshly allocated native-order direct buffer sized from the source's capacity; the copy
     * reads the source's remaining elements and therefore advances its position.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param ortAllocator the allocator to use.
     * @param shortBuffer the tensor data.
     * @param jArr the shape of the tensor.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     * @throws IllegalStateException if the environment or allocator is closed.
     */
    static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, ShortBuffer shortBuffer, long[] jArr) throws OrtException {
        if (ortEnvironment.isClosed() || ortAllocator.isClosed()) {
            throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
        }
        OnnxJavaType javaType = OnnxJavaType.INT16;
        int sizeInBytes = shortBuffer.capacity() * javaType.size;
        ShortBuffer directData;
        if (shortBuffer.isDirect()) {
            directData = shortBuffer;
        } else {
            directData = ByteBuffer.allocateDirect(sizeInBytes).order(ByteOrder.nativeOrder()).asShortBuffer();
            directData.put(shortBuffer);
        }
        TensorInfo tensorInfo = TensorInfo.constructFromBuffer(directData, jArr, javaType);
        long handle = createTensorFromBuffer(OnnxRuntime.ortApiHandle, ortAllocator.handle, directData, sizeInBytes, jArr, tensorInfo.onnxType.value);
        return new OnnxTensor(handle, ortAllocator.handle, tensorInfo, directData);
    }

    /**
     * Creates an INT32 tensor from a buffer using the environment's default allocator.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param intBuffer the tensor data.
     * @param jArr the shape of the tensor.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, IntBuffer intBuffer, long[] jArr) throws OrtException {
        return createTensor(ortEnvironment, ortEnvironment.defaultAllocator, intBuffer, jArr);
    }

    /**
     * Creates an INT32 tensor from a buffer using the supplied allocator.
     *
     * <p>A direct buffer is handed to native code as-is. A non-direct buffer is copied into a
     * freshly allocated native-order direct buffer sized from the source's capacity; the copy
     * reads the source's remaining elements and therefore advances its position.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param ortAllocator the allocator to use.
     * @param intBuffer the tensor data.
     * @param jArr the shape of the tensor.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     * @throws IllegalStateException if the environment or allocator is closed.
     */
    static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, IntBuffer intBuffer, long[] jArr) throws OrtException {
        if (ortEnvironment.isClosed() || ortAllocator.isClosed()) {
            throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
        }
        OnnxJavaType javaType = OnnxJavaType.INT32;
        int sizeInBytes = intBuffer.capacity() * javaType.size;
        IntBuffer directData;
        if (intBuffer.isDirect()) {
            directData = intBuffer;
        } else {
            directData = ByteBuffer.allocateDirect(sizeInBytes).order(ByteOrder.nativeOrder()).asIntBuffer();
            directData.put(intBuffer);
        }
        TensorInfo tensorInfo = TensorInfo.constructFromBuffer(directData, jArr, javaType);
        long handle = createTensorFromBuffer(OnnxRuntime.ortApiHandle, ortAllocator.handle, directData, sizeInBytes, jArr, tensorInfo.onnxType.value);
        return new OnnxTensor(handle, ortAllocator.handle, tensorInfo, directData);
    }

    /**
     * Creates an INT64 tensor from a buffer using the environment's default allocator.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param longBuffer the tensor data.
     * @param jArr the shape of the tensor.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     */
    public static OnnxTensor createTensor(OrtEnvironment ortEnvironment, LongBuffer longBuffer, long[] jArr) throws OrtException {
        return createTensor(ortEnvironment, ortEnvironment.defaultAllocator, longBuffer, jArr);
    }

    /**
     * Creates an INT64 tensor from a buffer using the supplied allocator.
     *
     * <p>A direct buffer is handed to native code as-is. A non-direct buffer is copied into a
     * freshly allocated native-order direct buffer sized from the source's capacity; the copy
     * reads the source's remaining elements and therefore advances its position.
     *
     * @param ortEnvironment the environment to create the tensor in.
     * @param ortAllocator the allocator to use.
     * @param longBuffer the tensor data.
     * @param jArr the shape of the tensor.
     * @return the new tensor.
     * @throws OrtException if tensor creation fails.
     * @throws IllegalStateException if the environment or allocator is closed.
     */
    static OnnxTensor createTensor(OrtEnvironment ortEnvironment, OrtAllocator ortAllocator, LongBuffer longBuffer, long[] jArr) throws OrtException {
        if (ortEnvironment.isClosed() || ortAllocator.isClosed()) {
            throw new IllegalStateException("Trying to create an OnnxTensor on a closed OrtAllocator.");
        }
        OnnxJavaType javaType = OnnxJavaType.INT64;
        int sizeInBytes = longBuffer.capacity() * javaType.size;
        LongBuffer directData;
        if (longBuffer.isDirect()) {
            directData = longBuffer;
        } else {
            directData = ByteBuffer.allocateDirect(sizeInBytes).order(ByteOrder.nativeOrder()).asLongBuffer();
            directData.put(longBuffer);
        }
        TensorInfo tensorInfo = TensorInfo.constructFromBuffer(directData, jArr, javaType);
        long handle = createTensorFromBuffer(OnnxRuntime.ortApiHandle, ortAllocator.handle, directData, sizeInBytes, jArr, tensorInfo.onnxType.value);
        return new OnnxTensor(handle, ortAllocator.handle, tensorInfo, directData);
    }

    private static native long createTensor(long j, long j2, Object obj, long[] jArr, int i) throws OrtException;

    private static native long createTensorFromBuffer(long j, long j2, Buffer buffer, long j3, long[] jArr, int i) throws OrtException;

    private static native long createString(long j, long j2, String str) throws OrtException;

    private static native long createStringTensor(long j, long j2, Object[] objArr, long[] jArr) throws OrtException;

    // Initialise the runtime when the class is first used; the native methods above depend on it.
    static {
        try {
            OnnxRuntime.init();
        } catch (IOException e) {
            throw new RuntimeException("Failed to load onnx-runtime library", e);
        }
    }
}
