avg_pool3d_out_frame Function Template — PyTorch Architecture
Architecture documentation for the avg_pool3d_out_frame function template in AveragePool3d.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/AveragePool3d.cpp lines 155–244
/**
 * Computes 3-D average pooling for `nslices` independent volumes.
 *
 * Slice `k` of the input begins at `input_p + k * itime * iheight * iwidth`
 * and is indexed as (time, height, width) with width varying fastest. The
 * matching output slice begins at `output_p + k * otime * oheight * owidth`
 * and is written in the same order. The slice range [0, nslices) is handed to
 * at::parallel_for; each slice only writes its own output region.
 *
 * For every output position the pooling window starts at
 * `index * stride - pad` along each axis and spans `kernel` elements, clipped
 * first to the padded extent (`size + pad`) and then to the real input
 * extent. Only in-bounds input elements contribute to the sum.
 *
 * @param input_p            Input buffer (nslices * itime * iheight * iwidth elements).
 * @param output_p           Output buffer (nslices * otime * oheight * owidth elements);
 *                           every element is overwritten.
 * @param nslices            Number of independent volumes to pool.
 * @param itime,iwidth,iheight  Input extent along time / width / height.
 * @param otime,owidth,oheight  Output extent along time / width / height.
 * @param kT,kW,kH           Pooling window extent along time / width / height.
 * @param dT,dW,dH           Window stride along time / width / height.
 * @param padT,padW,padH     Padding applied on each side along time / width / height.
 * @param count_include_pad  When true the divisor is the window size measured
 *                           against the padded extent; otherwise it is the
 *                           number of in-bounds elements in the window.
 * @param divisor_override   When set, used as the divisor for every window,
 *                           taking precedence over `count_include_pad`.
 */
template <typename scalar_t>
void avg_pool3d_out_frame(
    const scalar_t *input_p,
    scalar_t *output_p,
    int64_t nslices,
    int64_t itime,
    int64_t iwidth,
    int64_t iheight,
    int64_t otime,
    int64_t owidth,
    int64_t oheight,
    int kT,
    int kW,
    int kH,
    int dT,
    int dW,
    int dH,
    int padT,
    int padW,
    int padH,
    bool count_include_pad,
    std::optional<int64_t> divisor_override)
{
  // Element strides of the dense volumes; invariant across every loop below.
  const int64_t in_plane = iwidth * iheight;
  const int64_t in_slice = itime * in_plane;
  const int64_t out_slice = otime * owidth * oheight;

  at::parallel_for(0, nslices, 0, [&](int64_t start, int64_t end) {
    for (const auto k : c10::irange(start, end)) {
      /* local pointers. */
      const scalar_t *ip = input_p + k * in_slice;
      scalar_t *op = output_p + k * out_slice;

      // Every output element is assigned exactly once below (including windows
      // that fall entirely in padding), so no separate zero-fill pass over the
      // output slice is needed.
      for (int64_t ti = 0; ti < otime; ++ti) {
        // Window along time: clipped to the padded extent first (this is what
        // count_include_pad measures), then to the real input extent.
        const int64_t t_lo = ti * dT - padT;
        const int64_t t_hi = std::min(t_lo + kT, itime + padT);
        const int64_t tstart = std::max(t_lo, static_cast<int64_t>(0));
        const int64_t tend = std::min(t_hi, itime);

        for (int64_t i = 0; i < oheight; ++i) {
          const int64_t h_lo = i * dH - padH;
          const int64_t h_hi = std::min(h_lo + kH, iheight + padH);
          const int64_t hstart = std::max(h_lo, static_cast<int64_t>(0));
          const int64_t hend = std::min(h_hi, iheight);

          for (int64_t j = 0; j < owidth; ++j) {
            const int64_t w_lo = j * dW - padW;
            const int64_t w_hi = std::min(w_lo + kW, iwidth + padW);
            const int64_t wstart = std::max(w_lo, static_cast<int64_t>(0));
            const int64_t wend = std::min(w_hi, iwidth);

            // Window contains no real input element: the result is zero.
            if (tstart >= tend || hstart >= hend || wstart >= wend) {
              *op++ = 0;
              continue;
            }

            // Window size including padding.
            const int64_t pool_size =
                (t_hi - t_lo) * (h_hi - h_lo) * (w_hi - w_lo);

            // NOTE(review): a divisor_override of 0 would divide by zero
            // below; presumably rejected by the caller -- confirm.
            const int64_t divide_factor = divisor_override.has_value()
                ? divisor_override.value()
                : (count_include_pad
                       ? pool_size
                       : (tend - tstart) * (hend - hstart) * (wend - wstart));

            /* compute local sum over the in-bounds part of the window */
            scalar_t sum = 0.0;
            for (int64_t z = tstart; z < tend; ++z) {
              const scalar_t *plane = ip + z * in_plane;
              for (int64_t y = hstart; y < hend; ++y) {
                const scalar_t *row = plane + y * iwidth;
                for (int64_t x = wstart; x < wend; ++x) {
                  sum += row[x];
                }
              }
            }

            /* set output to local average */
            *op++ = sum / divide_factor;
          }
        }
      }
    }
  });
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free