_select_conv_backend Function — pytorch Architecture
Architecture documentation for the _select_conv_backend template function in Convolution.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/Convolution.cpp lines 1189–1296
// Chooses the convolution backend to dispatch to for the given input/weight
// tensors and convolution parameters. The order of the checks below encodes
// the dispatch priority (cuDNN/MIOpen > MKLDNN > XNNPACK > Winograd >
// slow fallbacks); falling out of every branch is an error.
template <typename T>
static ConvBackend _select_conv_backend(
    const Tensor& input,
    const Tensor& weight,
    const std::optional<Tensor>& bias,
    const at::OptionalArrayRef<T> bias_sizes_opt,
    const bool need_backward,
    const ConvParams<T>& params) {
  // Inputs with an empty batch or channel dimension are routed to the
  // dedicated "empty" backends instead of the real kernels.
  if (at::symint::size<T>(input, 0) == 0 || at::symint::size<T>(input, 1) == 0) {
    return input.is_mkldnn() ? ConvBackend::MkldnnEmpty : ConvBackend::Empty;
  }
  // Any other zero-element input (e.g. a zero spatial dim) is rejected.
  if (at::symint::numel<T>(input) == 0) {
    TORCH_CHECK(false, "Only zero batch or zero channel inputs are supported, but got input shape: ", at::symint::sizes<T>(input));
  }

  const auto ndim = input.ndimension();
  const bool has_bias_sizes = bias_sizes_opt.has_value();

  if (params.is_depthwise(input, weight)) {
    if (params.use_cudnn_depthwise(input, weight)) {
      return ConvBackend::Cudnn;
    }
    if (params.use_miopen(input, weight, has_bias_sizes)) {
      return ConvBackend::MiopenDepthwise;
    }
    if (ndim == 4) {
      return ConvBackend::CudaDepthwise2d;
    }
    if (ndim == 5) {
      return ConvBackend::CudaDepthwise3d;
    }
    // Any other rank is unsupported: fall through to the error below.
  } else if (params.use_cudnn(input, weight)) {
    return params.transposed ? ConvBackend::CudnnTranspose : ConvBackend::Cudnn;
  } else if (params.use_miopen(input, weight, has_bias_sizes)) {
    return params.transposed ? ConvBackend::MiopenTranspose : ConvBackend::Miopen;
  } else if (params.use_mkldnn(input, weight)) {
    return params.transposed ? ConvBackend::MkldnnTranspose : ConvBackend::Mkldnn;
  } else if (!need_backward && params.use_xnnpack(input, weight, bias_sizes_opt)) {
    // Using prepacked conv is preferred, but XNNPACK is still the fastest
    // option for NHWC.
    return ConvBackend::Xnnpack2d;
  } else if (!need_backward && params.use_cpu_depthwise3x3_winograd(input, weight, bias)) {
    // The 3x3 depthwise Winograd implementation is inference-only.
    return ConvBackend::Winograd3x3Depthwise;
  } else if (
      !params.transposed && (ndim == 5) &&
      (input.device().is_cpu()) &&
      !params.is_dilated()) {
    // Fast path for grouped conv3d on CPU.
    return ConvBackend::Slow3d;
  } else if (input.device().is_cpu() || input.is_cuda()) {
    // Fallback backends without support for groups.
    if (params.transposed) {
      if (ndim == 4) {
        return ConvBackend::SlowTranspose2d;
      }
      if (ndim == 5) {
        return ConvBackend::SlowTranspose3d;
      }
      // Other ranks are unsupported: fall through to the error below.
    } else { /* not transposed */
      if (ndim == 4) {
        if (params.is_dilated()) {
          return ConvBackend::SlowDilated2d;
        }
        // dim == 4, non-dilated: NNPACK when usable, otherwise the CPU
        // implementation's specialized MM kernels in Slow2d.
        return params.use_nnpack(input, weight) ? ConvBackend::NnpackSpatial
                                                : ConvBackend::Slow2d;
      }
      if (ndim == 5) {
        if (input.is_cuda() || params.is_dilated()) {
          return ConvBackend::SlowDilated3d;
        }
        // dim == 5, CPU, non-dilated: specialized MM kernels in Slow3d.
        return ConvBackend::Slow3d;
      }
      // Other ranks are unsupported: fall through to the error below.
    }
  } else if (params.use_mps(input, weight)) {
    return params.transposed ? ConvBackend::MpsTranspose : ConvBackend::Mps;
  } else {
    // Only reached when the input is on a backend with an out-of-source
    // implementation.
    return ConvBackend::Overrideable;
  }
  // Error out if no suitable backend was found.
  TORCH_CHECK(false, "unsupported ConvNd parameters");
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free