Home / Class / TensorHybrid Class — PyTorch Architecture

TensorHybrid Class — PyTorch Architecture

Architecture documentation for the TensorHybrid class in pytorch_jni_common.cpp from the PyTorch codebase.

Entity Profile

Relationship Graph

Source Code

android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp lines 167–289

/// C++ peer of the Java class org.pytorch.Tensor (see kJavaDescriptor),
/// built on fbjni's HybridClass. Each instance owns one at::Tensor. The
/// static helpers convert between that tensor and its Java representation:
/// a raw data buffer plus a `shape` long[], a dtype JNI code and a
/// memory-format JNI code.
class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
 public:
  constexpr static const char* kJavaDescriptor = "Lorg/pytorch/Tensor;";

  /// Takes ownership of `tensor`. The parameter is a sink, so it is moved
  /// into the member rather than copied (a copy would bump the tensor's
  /// refcount only to drop it again when the parameter is destroyed).
  explicit TensorHybrid(at::Tensor tensor) : tensor_(std::move(tensor)) {}

  /// Native initializer for an org.pytorch.Tensor: rebuilds an at::Tensor
  /// from the Java object's fields and wraps it in hybrid data.
  ///
  /// Delegates to newAtTensorFromJTensor so both Java -> C++ conversion paths
  /// share one implementation. That helper also reads the Java accessors in a
  /// fixed order. Passing them all as arguments to a single call would leave
  /// the order of the JNI calls unspecified, because C++ does not define the
  /// evaluation order of function arguments.
  static facebook::jni::local_ref<TensorHybrid::jhybriddata> initHybrid(
      facebook::jni::alias_ref<TensorHybrid::javaobject> jTensorThis) {
    return makeCxxInstance(newAtTensorFromJTensor(jTensorThis));
  }

  /// Creates a Java org.pytorch.Tensor exposing `input_tensor`'s data.
  ///
  /// Memory format:
  /// - A tensor that is already ChannelsLast- or ChannelsLast3d-contiguous is
  ///   shared as-is and reported with the matching format code.
  /// - Any other tensor is passed through `.contiguous()`, which may copy.
  ///
  /// The data is handed to Java as a direct ByteBuffer in native byte order.
  /// The buffer wraps the tensor's data pointer (`data_ptr()`, `nbytes()`).
  /// The tensor itself is passed to Java as hybrid data, presumably so the
  /// wrapped memory lives as long as the Java object. TODO confirm against
  /// org.pytorch.Tensor.nativeNewTensor.
  ///
  /// Supported scalar types: float32, int32, uint8, int8, int64, float64.
  /// Any other type raises java.lang.IllegalArgumentException through
  /// fbjni's throwNewJavaException.
  static facebook::jni::local_ref<TensorHybrid::javaobject>
  newJTensorFromAtTensor(const at::Tensor& input_tensor) {
    // Java wrapper currently only supports contiguous tensors.
    int jmemoryFormat = 0;
    at::Tensor tensor{};
    if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast)) {
      tensor = input_tensor;
      jmemoryFormat = kTensorMemoryFormatChannelsLast;
    } else if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
      tensor = input_tensor;
      jmemoryFormat = kTensorMemoryFormatChannelsLast3d;
    } else {
      tensor = input_tensor.contiguous();
      jmemoryFormat = kTensorMemoryFormatContiguous;
    }

    // Map the ATen scalar type to the dtype code understood by the Java side.
    const auto scalarType = tensor.scalar_type();
    int jdtype = 0;
    if (at::kFloat == scalarType) {
      jdtype = kTensorDTypeFloat32;
    } else if (at::kInt == scalarType) {
      jdtype = kTensorDTypeInt32;
    } else if (at::kByte == scalarType) {
      jdtype = kTensorDTypeUInt8;
    } else if (at::kChar == scalarType) {
      jdtype = kTensorDTypeInt8;
    } else if (at::kLong == scalarType) {
      jdtype = kTensorDTypeInt64;
    } else if (at::kDouble == scalarType) {
      jdtype = kTensorDTypeFloat64;
    } else {
      facebook::jni::throwNewJavaException(
          facebook::jni::gJavaLangIllegalArgumentException,
          "at::Tensor scalar type %s is not supported on java side",
          c10::toString(scalarType));
    }

    // Copy the sizes into a jlong buffer in a single allocation, then into a
    // Java long[].
    const auto tensorShape = tensor.sizes();
    const std::vector<jlong> tensorShapeVec(
        tensorShape.begin(), tensorShape.end());
    facebook::jni::local_ref<jlongArray> jTensorShape =
        facebook::jni::make_long_array(tensorShapeVec.size());
    jTensorShape->setRegion(0, tensorShapeVec.size(), tensorShapeVec.data());

    static auto cls = TensorHybrid::javaClassStatic();
    // The ByteBuffer points into `tensor` (possibly the contiguous copy made
    // above), not into `input_tensor`.
    facebook::jni::local_ref<facebook::jni::JByteBuffer> jTensorBuffer =
        facebook::jni::JByteBuffer::wrapBytes(
            static_cast<uint8_t*>(tensor.data_ptr()), tensor.nbytes());
    jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());

    static const auto jMethodNewTensor =
        cls->getStaticMethod<facebook::jni::local_ref<TensorHybrid::javaobject>(
            facebook::jni::alias_ref<facebook::jni::JByteBuffer>,
            facebook::jni::alias_ref<jlongArray>,
            jint,
            jint,
            facebook::jni::alias_ref<jhybriddata>)>("nativeNewTensor");
    // Hand `tensor` itself (the owner of the wrapped memory) to the hybrid
    // instance. It is not used after this point, so move rather than copy.
    return jMethodNewTensor(
        cls,
        jTensorBuffer,
        jTensorShape,
        jdtype,
        jmemoryFormat,
        makeCxxInstance(std::move(tensor)));
  }

  /// Builds an at::Tensor from a Java org.pytorch.Tensor.
  ///
  /// Reads, in this order, the Java object's dtypeJniCode(),
  /// memoryFormatJniCode(), `shape` field and getRawDataBuffer(), then
  /// delegates to newAtTensor. newAtTensor is defined elsewhere, so whether
  /// the result copies or aliases the Java buffer is decided there.
  static at::Tensor newAtTensorFromJTensor(
      facebook::jni::alias_ref<TensorHybrid::javaobject> jtensor) {
    static auto cls = TensorHybrid::javaClassStatic();
    static const auto dtypeMethod = cls->getMethod<jint()>("dtypeJniCode");
    jint jdtype = dtypeMethod(jtensor);

    static const auto memoryFormatMethod =
        cls->getMethod<jint()>("memoryFormatJniCode");
    jint jmemoryFormat = memoryFormatMethod(jtensor);

    static const auto shapeField = cls->getField<jlongArray>("shape");
    auto jshape = jtensor->getFieldValue(shapeField);

    static const auto dataBufferMethod = cls->getMethod<
        facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
        "getRawDataBuffer");
    facebook::jni::local_ref<facebook::jni::JBuffer> jbuffer =
        dataBufferMethod(jtensor);
    return newAtTensor(jbuffer, jshape, jdtype, jmemoryFormat);
  }

  /// Returns the wrapped at::Tensor (by value).
  at::Tensor tensor() const {
    return tensor_;
  }

 private:
  // Grants fbjni's HybridClass machinery access to private members
  // (presumably the constructor, for makeCxxInstance).
  friend HybridBase;
  at::Tensor tensor_;
};

Domain

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free