TensorHybrid Class — PyTorch Architecture
Architecture documentation for the TensorHybrid class in pytorch_jni_common.cpp from the PyTorch codebase.
Entity Profile
Relationship Graph
Source Code
android/pytorch_android/src/main/cpp/pytorch_jni_common.cpp lines 167–289
/// fbjni hybrid peer of the Java class `org.pytorch.Tensor`.
///
/// Each instance holds an `at::Tensor` in `tensor_`. The static helpers
/// convert between the Java-side representation (raw data buffer, `shape`
/// field, dtype JNI code, memory-format JNI code) and `at::Tensor`.
class TensorHybrid : public facebook::jni::HybridClass<TensorHybrid> {
 public:
  constexpr static const char* kJavaDescriptor = "Lorg/pytorch/Tensor;";

  /// Wraps `tensor`. The by-value argument is moved into the member, so the
  /// caller's copy is the only refcount bump.
  explicit TensorHybrid(at::Tensor tensor) : tensor_(std::move(tensor)) {}

  /// Builds an `at::Tensor` from the fields of the Java Tensor `jTensorThis`
  /// and returns it wrapped in a new hybrid instance.
  static facebook::jni::local_ref<TensorHybrid::jhybriddata> initHybrid(
      facebook::jni::alias_ref<TensorHybrid::javaobject> jTensorThis) {
    // Performs exactly the same Java-side reads as newAtTensorFromJTensor;
    // delegate instead of duplicating the method/field lookups.
    return makeCxxInstance(newAtTensorFromJTensor(jTensorThis));
  }

  /// Creates a Java `org.pytorch.Tensor` for `input_tensor`.
  ///
  /// If `input_tensor` is channels-last (2d or 3d) contiguous, it is used
  /// as-is. Otherwise a default-contiguous tensor is obtained via
  /// `contiguous()`. The Java object receives a direct ByteBuffer over that
  /// tensor's bytes (no data copy) plus a hybrid instance holding the tensor.
  /// NOTE(review): the hybrid instance is presumably what keeps the buffer's
  /// memory alive on the Java side; confirm in Tensor.nativeNewTensor.
  ///
  /// Throws java.lang.IllegalArgumentException if the scalar type has no
  /// Java-side equivalent.
  static facebook::jni::local_ref<TensorHybrid::javaobject>
  newJTensorFromAtTensor(const at::Tensor& input_tensor) {
    // Java wrapper currently only supports contiguous tensors.
    // ChannelsLast is tested first, so a tensor that satisfies more than one
    // layout check is reported as channels-last.
    int jmemoryFormat = 0;
    at::Tensor tensor{};
    if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast)) {
      tensor = input_tensor;
      jmemoryFormat = kTensorMemoryFormatChannelsLast;
    } else if (input_tensor.is_contiguous(at::MemoryFormat::ChannelsLast3d)) {
      tensor = input_tensor;
      jmemoryFormat = kTensorMemoryFormatChannelsLast3d;
    } else {
      // contiguous() materializes a default-contiguous copy only when needed.
      tensor = input_tensor.contiguous();
      jmemoryFormat = kTensorMemoryFormatContiguous;
    }

    // Map the ATen scalar type to the JNI dtype code understood by Java.
    const auto scalarType = tensor.scalar_type();
    int jdtype = 0;
    if (at::kFloat == scalarType) {
      jdtype = kTensorDTypeFloat32;
    } else if (at::kInt == scalarType) {
      jdtype = kTensorDTypeInt32;
    } else if (at::kByte == scalarType) {
      jdtype = kTensorDTypeUInt8;
    } else if (at::kChar == scalarType) {
      jdtype = kTensorDTypeInt8;
    } else if (at::kLong == scalarType) {
      jdtype = kTensorDTypeInt64;
    } else if (at::kDouble == scalarType) {
      jdtype = kTensorDTypeFloat64;
    } else {
      facebook::jni::throwNewJavaException(
          facebook::jni::gJavaLangIllegalArgumentException,
          "at::Tensor scalar type %s is not supported on java side",
          c10::toString(scalarType));
    }

    // Copy the sizes into a Java long[] for the `shape` argument.
    const auto tensorShape = tensor.sizes();
    std::vector<jlong> tensorShapeVec(tensorShape.begin(), tensorShape.end());
    const auto rank = static_cast<jsize>(tensorShapeVec.size());
    facebook::jni::local_ref<jlongArray> jTensorShape =
        facebook::jni::make_long_array(rank);
    jTensorShape->setRegion(0, rank, tensorShapeVec.data());

    static auto cls = TensorHybrid::javaClassStatic();
    // Direct buffer aliasing the tensor's memory; no element copy is made.
    facebook::jni::local_ref<facebook::jni::JByteBuffer> jTensorBuffer =
        facebook::jni::JByteBuffer::wrapBytes(
            static_cast<uint8_t*>(tensor.data_ptr()), tensor.nbytes());
    jTensorBuffer->order(facebook::jni::JByteOrder::nativeOrder());

    static const auto jMethodNewTensor =
        cls->getStaticMethod<facebook::jni::local_ref<TensorHybrid::javaobject>(
            facebook::jni::alias_ref<facebook::jni::JByteBuffer>,
            facebook::jni::alias_ref<jlongArray>,
            jint,
            jint,
            facebook::jni::alias_ref<jhybriddata>)>("nativeNewTensor");
    // `tensor` is not used after this point (its data pointer is already
    // captured in jTensorBuffer), so hand it to the hybrid instance by move.
    return jMethodNewTensor(
        cls,
        jTensorBuffer,
        jTensorShape,
        jdtype,
        jmemoryFormat,
        makeCxxInstance(std::move(tensor)));
  }

  /// Builds an `at::Tensor` from a Java Tensor by reading its dtype code
  /// (`dtypeJniCode()`), memory-format code (`memoryFormatJniCode()`),
  /// `shape` field and raw data buffer (`getRawDataBuffer()`), then passing
  /// them to `newAtTensor`.
  static at::Tensor newAtTensorFromJTensor(
      facebook::jni::alias_ref<TensorHybrid::javaobject> jtensor) {
    static auto cls = TensorHybrid::javaClassStatic();
    static const auto dtypeMethod = cls->getMethod<jint()>("dtypeJniCode");
    jint jdtype = dtypeMethod(jtensor);

    static const auto memoryFormatMethod =
        cls->getMethod<jint()>("memoryFormatJniCode");
    jint jmemoryFormat = memoryFormatMethod(jtensor);

    static const auto shapeField = cls->getField<jlongArray>("shape");
    auto jshape = jtensor->getFieldValue(shapeField);

    static const auto dataBufferMethod = cls->getMethod<
        facebook::jni::local_ref<facebook::jni::JBuffer::javaobject>()>(
        "getRawDataBuffer");
    facebook::jni::local_ref<facebook::jni::JBuffer> jbuffer =
        dataBufferMethod(jtensor);
    return newAtTensor(jbuffer, jshape, jdtype, jmemoryFormat);
  }

  /// Returns (a reference-counted handle to) the wrapped tensor.
  at::Tensor tensor() const {
    return tensor_;
  }

 private:
  friend HybridBase;
  at::Tensor tensor_;
};
Domain
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free