check_cudnn_depthwise_workload_with_filter Function Template — pytorch Architecture
Architecture documentation for the check_cudnn_depthwise_workload_with_filter static function template in Convolution.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/Convolution.cpp lines 220–253
// Heuristic predicate over an (input, stride, weight) workload; per the
// stride rule below, a `true` result means "use cudnn" for this depthwise
// convolution.
//
// T is the integer type used for sizes and for `stride`; every size is read
// through at::symint::size<T>, so the same code serves each instantiation.
//
// Dimensions consulted:
//   input : 0 = batch, 1 = channels, 2 = compared against 1 for the 1D case,
//           3 = width
//   weight: 2 and 3 = filter extents (must match)
//
// NOTE(review): the order of the comparisons is kept deliberately stable;
// if T is a symbolic integer, each comparison may be observable to tracing --
// confirm before reordering.
template <typename T>
static bool check_cudnn_depthwise_workload_with_filter(const at::Tensor& input, T stride, const at::Tensor& weight) {
  // 1D convolution: dim 2 of the input collapses to 1 and the stride is 1.
  const bool is_unit_stride_1d = at::symint::size<T>(input, 2) == 1 && stride == 1;
  if (is_unit_stride_1d) {
    return true;
  }

  // 2D convolution from here on.
  // Square filters only.
  const auto kernel_h = at::symint::size<T>(weight, 2);
  const auto kernel = at::symint::size<T>(weight, 3);
  if (kernel_h != kernel) {
    return false;
  }

  // Only filter sizes 1, 3 and 5 are covered by the heuristic.
  if (kernel != 1 && kernel != 3 && kernel != 5) {
    return false;
  }

  // The input is not required to be square; only the width is inspected to
  // keep the heuristic space small. Minimum accepted width is 7.
  const auto width = at::symint::size<T>(input, 3);
  if (width < 7) {
    return false;
  }

  // Strides 1 and 2 only; stride 1 always takes the cudnn path.
  if (stride == 1) {
    return true;
  }
  if (stride != 2) {
    return false;
  }

  // Stride-2 cases are decided by channel count, batch size, filter and width.
  const auto channels = at::symint::size<T>(input, 1);
  const auto batch = at::symint::size<T>(input, 0);

  // Batch size 1 is special-cased because it shows good perf in many cases.
  if (batch == 1) {
    if (kernel == 1 && width <= 28) {
      return true;
    }
    return kernel == 3 || kernel == 5;
  }

  // Batch size other than 1.
  if (kernel == 1 && batch <= 16 && channels >= 128 && width <= 7) {
    return true;
  }
  if (kernel == 3 || kernel == 5) {
    return (channels >= 512) || (channels >= 256 && width >= 28);
  }
  return false;
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free