split_batch_dim_to_32bit_out Function Template — PyTorch Architecture
Architecture documentation for the split_batch_dim_to_32bit_out function template in Conv_miopen.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/miopen/Conv_miopen.cpp lines 648–723
// Invokes `func_indexing` on (output, input, weight) plus the convolution
// parameters, splitting the work along dimension 0 when the tensors are too
// large for 32-bit element indexing.
//
// Strategy:
//   1. If both input.numel() and output.numel() are <= INT_MAX, call
//      `func_indexing` once on the full tensors.
//   2. Otherwise, choose a chunk size along dim 0 (never less than 1) from
//      `max_worksize` and the larger per-sample element count. If one chunk
//      holds strictly fewer than INT_MAX elements, call `func_indexing` once
//      per chunk on matching narrow()ed slices of output and input; `weight`
//      is passed through unchanged.
//   3. If even that chunk is too large, call `func_indexing` once on the full
//      tensors; MIOpen supports 64-bit indexing via the
//      miopenSetTensorDescriptorV2 API.
//
// padding, stride, dilation, groups, benchmark, deterministic and depthwise
// are forwarded verbatim to every call.
// max_worksize: element budget used to size each chunk. It is intentionally
//   kept well below the full 2^31 address space, because a chunk that large
//   is very likely to run out of memory.
// NOTE(review): the chunk range is computed from output.size(0) and applied to
//   input as well, so input.size(0) is presumably equal to output.size(0) —
//   confirm at call sites.
template <typename func_t>
static inline void split_batch_dim_to_32bit_out(
    const at::Tensor& output,
    const at::Tensor& input,
    const at::Tensor& weight,
    IntArrayRef padding,
    IntArrayRef stride,
    IntArrayRef dilation,
    int64_t groups,
    bool benchmark,
    bool deterministic,
    bool depthwise,
    int64_t max_worksize,
    func_t func_indexing) {
  // Forwards one (output, input) pair together with the fixed arguments.
  auto run = [&](const at::Tensor& out, const at::Tensor& in) {
    func_indexing(
        out,
        in,
        weight,
        padding,
        stride,
        dilation,
        groups,
        benchmark,
        deterministic,
        depthwise);
  };

  constexpr int64_t kIntMax = std::numeric_limits<int>::max();
  const int64_t input_numel = input.numel();
  const int64_t output_numel = output.numel();

  // Viewing the tensors as (N, C, D1, D2, ...): splitting is only needed when
  // N * C * D1 * D2 * ... exceeds int range for either of them.
  const bool needs_split = input_numel > kIntMax || output_numel > kIntMax;

  if (needs_split) {
    // Split across N so that each chunk covers chunk * (C * D1 * D2 * ...)
    // elements, sized by a simple heuristic driven by max_worksize.
    const int64_t batch = output.size(0);
    const int64_t per_sample =
        std::max<int64_t>(input_numel, output_numel) / batch;
    const int64_t chunk = std::max<int64_t>(max_worksize / per_sample, 1L);

    if (chunk * per_sample < kIntMax) {
      for (int64_t start = 0; start < batch; start += chunk) {
        // The final chunk may be shorter than `chunk`.
        const int64_t len = std::min<int64_t>(chunk, batch - start);
        Tensor input_slice = input.narrow(0, start, len);
        Tensor output_slice = output.narrow(0, start, len);
        run(output_slice, input_slice);
      }
      return;
    }
    // Even the smallest allowed chunk is too large for 32-bit indexing: fall
    // through and rely on MIOpen's 64-bit indexing support
    // (miopenSetTensorDescriptorV2).
  }

  run(output, input);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free