avg_pool3d_backward_out_frame Function Template — PyTorch Architecture
Architecture documentation for the avg_pool3d_backward_out_frame function template in AveragePool3d.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/AveragePool3d.cpp lines 335–416
/*
 * Backward pass of 3-D average pooling over `nslices` independent slices.
 *
 * For every output cell (ti, i, j) of a slice, the incoming gradient is divided
 * by that pooling window's divisor and added to every input cell covered by
 * the window (after clipping the window to the real input extent).
 *
 * gradInput_p  : destination; each slice is laid out as [itime][iheight][iwidth]
 *                and slices are consecutive. Every slice is zeroed and then
 *                accumulated into, so prior contents are discarded.
 * gradOutput_p : source gradients; each slice is consumed in
 *                [otime][oheight][owidth] order, slices consecutive.
 * kT/kW/kH     : kernel extent along time / width / height.
 * dT/dW/dH     : stride along time / width / height.
 * padT/padW/padH : padding along time / width / height.
 * count_include_pad : when true (and no override), the divisor is the window
 *                size clipped only to the padded extent; when false it is the
 *                number of real input cells inside the window.
 * divisor_override  : if set, used as the divisor for every window.
 *                NOTE(review): a zero override would divide by zero here;
 *                presumably the caller rejects it — confirm.
 */
template <typename scalar_t>
void avg_pool3d_backward_out_frame(
    scalar_t *gradInput_p,
    const scalar_t *gradOutput_p,
    int64_t nslices,
    int64_t itime,
    int64_t iwidth,
    int64_t iheight,
    int64_t otime,
    int64_t owidth,
    int64_t oheight,
    int kT,
    int kW,
    int kH,
    int dT,
    int dW,
    int dH,
    int padT,
    int padW,
    int padH,
    bool count_include_pad,
    std::optional<int64_t> divisor_override)
{
  // Number of elements in one input slice / one output slice.
  const int64_t islice_size = itime * iwidth * iheight;
  const int64_t oslice_size = otime * owidth * oheight;

  // Slice k only reads its own output slice and writes its own input slice,
  // so slices can be processed in parallel without synchronization.
  at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
    for (const auto k : c10::irange(start, end)) {
      scalar_t *ip = gradInput_p + k * islice_size;
      const scalar_t *op = gradOutput_p + k * oslice_size;

      // Gradients are accumulated with +=, so the slice must start at zero.
      for (int64_t n = 0; n < islice_size; n++) {
        ip[n] = 0;
      }

      /* loop over output */
      for (int64_t ti = 0; ti < otime; ti++) {
        for (int64_t i = 0; i < oheight; i++) {
          for (int64_t j = 0; j < owidth; j++) {
            // Pooling window in padded coordinates (may reach into padding).
            int64_t tstart = ti * dT - padT;
            int64_t hstart = i * dH - padH;
            int64_t wstart = j * dW - padW;
            int64_t tend = std::min(tstart + kT, itime + padT);
            int64_t hend = std::min(hstart + kH, iheight + padH);
            int64_t wend = std::min(wstart + kW, iwidth + padW);
            // Window size counting padded cells; the count_include_pad divisor.
            const int64_t pool_size =
                (tend - tstart) * (hend - hstart) * (wend - wstart);

            // Clip the window to the real input.
            tstart = std::max(tstart, static_cast<int64_t>(0));
            hstart = std::max(hstart, static_cast<int64_t>(0));
            wstart = std::max(wstart, static_cast<int64_t>(0));
            tend = std::min(tend, itime);
            hend = std::min(hend, iheight);
            wend = std::min(wend, iwidth);

            // Consume this cell's gradient unconditionally so `op` stays
            // aligned with (ti, i, j) even when the window is skipped below.
            const scalar_t val = *op++;

            // A window that covers no real input cell scatters nothing.
            // Skipping it also avoids evaluating the (hoisted) division below
            // with a divisor that may be zero for such a window.
            if (tstart >= tend || hstart >= hend || wstart >= wend) {
              continue;
            }

            int64_t divide_factor = 0;
            if (divisor_override.has_value()) {
              divide_factor = divisor_override.value();
            } else if (count_include_pad) {
              divide_factor = pool_size;
            } else {
              divide_factor = (tend - tstart) * (hend - hstart) * (wend - wstart);
            }

            // Every covered input cell receives the same share, so compute it
            // once per window. `auto` keeps the exact type of the expression
            // (matters for reduced-precision scalar types).
            const auto grad = val / divide_factor;

            /* scatter gradients out to footprint: */
            for (auto z = tstart; z < tend; z++) {
              for (auto y = hstart; y < hend; y++) {
                scalar_t *row = ip + z * iheight * iwidth + y * iwidth;
                for (auto x = wstart; x < wend; x++) {
                  row[x] += grad;
                }
              }
            }
          }
        }
      }
    }
  });
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free