check_cudnn_depthwise_workload Class — pytorch Architecture
Architecture documentation for the check_cudnn_depthwise_workload class in Convolution.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/Convolution.cpp lines 97–217
template <typename T>
static bool check_cudnn_depthwise_workload(const at::Tensor& input, T stride) {
auto w = at::symint::size<T>(input, 3); // same as h
auto ch = at::symint::size<T>(input, 1);
auto bs = at::symint::size<T>(input, 0);
if (stride==1) {
if (w >= 7) {
// All batch sizes and nb_channels
if (w >= 112) {
return true;
}
// large nb_channels
if (ch >= 1024) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if (w >= 56) {
return true;
} else if (bs >= 32) {
return true;
}
}
// batch_size specific
if (bs >= 128) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if (ch >= 512) {
return true;
} else if (ch >= 64) {
if (w >= 14) {
return true;
}
} else if ((ch >= 32) && (w >=28)) {
return true;
}
} else if (bs >= 64) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if ((ch >= 256) && (w >= 14)) {
return true;
} else if ((ch >= 32) && (w >= 28)) {
return true;
}
} else if (bs >= 32) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if ((ch >= 256) && (w >= 14)) {
return true;
} else if ((ch >= 128) && (w >= 28)) {
return true;
} else if ((ch >= 32) && (w >= 56)) {
return true;
}
} else if (bs >= 16) {
if ((ch >= 1024) && (w >= 14)) {
return true;
}
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if ((ch >= 256) && (w >= 28)) {
return true;
} else if ((ch >= 32) && (w >= 56)) {
return true;
}
} else if (bs >= 8) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if ((ch >= 512) && (w >= 28)) {
return true;
} else if ((ch >= 64) && (w >= 56)) {
return true;
}
}
}
} else if (stride==2) {
if (ch < 256) {
return false;
}
if (w >= 7) {
if (bs >= 128) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if (ch >= 1024) {
return true;
} else if ((ch >= 512) && (w >= 14)) {
return true;
} else if (w >= 28) {
return true;
}
} else if (bs >= 64) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if ((ch >= 512) && (w >= 14)) {
return true;
} else if (w >= 28) {
return true;
}
} else if (bs >= 32) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if ((ch >= 1024) && (w >= 14)) {
return true;
} else if (w >= 28) {
return true;
}
} else if (bs >= 16) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if ((ch >= 512) && (w >= 28)) {
return true;
} else if (w >= 56) {
return true;
}
} else if (bs >= 8) {
// NOLINTNEXTLINE(bugprone-branch-clone,cppcoreguidelines-avoid-magic-numbers)
if ((ch >= 1024) && (w >= 28)) {
return true;
} else if (w >= 56) {
return true;
}
} else if (bs >= 1) {
if ((ch >= 512) && (w >=112)) {
return true;
}
}
}
}
return false;
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free