grid_sampler_3d_cpu_impl Function — PyTorch Architecture
Architecture documentation for the grid_sampler_3d_cpu_impl function template in GridSampler.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/GridSampler.cpp lines 42–203
// CPU kernel for 3-D (volumetric) grid sampling.
//
// For every output location (n, d, h, w) the kernel reads an (x, y, z) triple
// from `grid`, maps each component to a source index with
// grid_sampler_compute_source_index (driven by `padding_mode` and
// `align_corners`), and then fills all C channels of that location:
//   * Bilinear: weighted sum over the 8 integer corners surrounding the sample
//     point; a corner rejected by within_bounds_3d contributes nothing.
//   * Nearest:  the voxel at the std::nearbyint-rounded index, or 0 when that
//     index is rejected by within_bounds_3d.
// For any other interpolation mode nothing is written, so those output
// elements keep whatever at::empty produced.
//
// input  : indexed here as (N, C, D_in, H_in, W_in)
// grid   : indexed here as (N, D_out, H_out, W_out, 3); the last dimension
//          holds x, y, z in that order
// returns: a tensor of shape (N, C, D_out, H_out, W_out) created with
//          input.options(); returned immediately if it has no elements
template<typename scalar_t>
Tensor grid_sampler_3d_cpu_impl(const Tensor& input, const Tensor& grid,
                                GridSamplerInterpolation interpolation_mode,
                                GridSamplerPadding padding_mode,
                                bool align_corners) {
  // See NOTE [ grid_sampler Native Functions ].
  // The argument checks are repeated here because this function may be called
  // directly instead of through grid_sampler.
  check_grid_sampler_common(input, grid);
  check_grid_sampler_3d(input, grid, static_cast<int64_t>(interpolation_mode));

  const int64_t batch = input.size(0);
  const int64_t channels = input.size(1);
  const int64_t in_depth = input.size(2);
  const int64_t in_height = input.size(3);
  const int64_t in_width = input.size(4);
  const int64_t out_depth = grid.size(1);
  const int64_t out_height = grid.size(2);
  const int64_t out_width = grid.size(3);

  auto output = at::empty(
      {batch, channels, out_depth, out_height, out_width}, input.options());
  if (output.numel() == 0) {
    return output;
  }

  // Element strides; all addressing below goes through these, so no
  // particular memory layout is assumed for any of the three tensors.
  const int64_t in_sn = input.stride(0);
  const int64_t in_sc = input.stride(1);
  const int64_t in_sd = input.stride(2);
  const int64_t in_sh = input.stride(3);
  const int64_t in_sw = input.stride(4);
  const int64_t grid_sn = grid.stride(0);
  const int64_t grid_sd = grid.stride(1);
  const int64_t grid_sh = grid.stride(2);
  const int64_t grid_sw = grid.stride(3);
  const int64_t grid_scoord = grid.stride(4);
  const int64_t out_sn = output.stride(0);
  const int64_t out_sc = output.stride(1);
  const int64_t out_sd = output.stride(2);
  const int64_t out_sh = output.stride(3);
  const int64_t out_sw = output.stride(4);

  const scalar_t* const in_data = input.const_data_ptr<scalar_t>();
  scalar_t* const out_data = output.data_ptr<scalar_t>();
  const scalar_t* const grid_data = grid.const_data_ptr<scalar_t>();

  // Work is split across the batch dimension; each batch entry writes a
  // disjoint part of `output`.
  at::parallel_for(0, batch, 0, [&](int64_t begin, int64_t end) {
    for (const auto b : c10::irange(begin, end)) {
      const scalar_t* const grid_b = grid_data + b * grid_sn;
      const scalar_t* const in_b = in_data + b * in_sn;
      scalar_t* const out_b = out_data + b * out_sn;

      for (const auto od : c10::irange(out_depth)) {
        for (const auto oh : c10::irange(out_height)) {
          for (const auto ow : c10::irange(out_width)) {
            // Grid entry and output element (channel 0) for this location.
            const scalar_t* const loc =
                grid_b + od * grid_sd + oh * grid_sh + ow * grid_sw;
            scalar_t* const dst =
                out_b + od * out_sd + oh * out_sh + ow * out_sw;

            // Raw grid coordinates, then their source-index equivalents.
            scalar_t gx = loc[0];
            scalar_t gy = loc[grid_scoord];
            scalar_t gz = loc[2 * grid_scoord];
            const scalar_t ix = grid_sampler_compute_source_index(
                gx, in_width, padding_mode, align_corners);
            const scalar_t iy = grid_sampler_compute_source_index(
                gy, in_height, padding_mode, align_corners);
            const scalar_t iz = grid_sampler_compute_source_index(
                gz, in_depth, padding_mode, align_corners);

            if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
              // Corner naming follows the 2-D kernel's compass points
              // (n/s along y, w/e along x) with top/bottom added along z:
              // "t" is the z0 plane, "b" is the z1 plane.
              const int64_t x0 = static_cast<int64_t>(std::floor(ix));
              const int64_t y0 = static_cast<int64_t>(std::floor(iy));
              const int64_t z0 = static_cast<int64_t>(std::floor(iz));
              const int64_t x1 = x0 + 1;
              const int64_t y1 = y0 + 1;
              const int64_t z1 = z0 + 1;

              // Each corner is weighted by the volume of the box spanned by
              // the sample point and the diagonally opposite corner.
              const scalar_t w_tnw = (x1 - ix) * (y1 - iy) * (z1 - iz);
              const scalar_t w_tne = (ix - x0) * (y1 - iy) * (z1 - iz);
              const scalar_t w_tsw = (x1 - ix) * (iy - y0) * (z1 - iz);
              const scalar_t w_tse = (ix - x0) * (iy - y0) * (z1 - iz);
              const scalar_t w_bnw = (x1 - ix) * (y1 - iy) * (iz - z0);
              const scalar_t w_bne = (ix - x0) * (y1 - iy) * (iz - z0);
              const scalar_t w_bsw = (x1 - ix) * (iy - y0) * (iz - z0);
              const scalar_t w_bse = (ix - x0) * (iy - y0) * (iz - z0);

              for (const auto c : c10::irange(channels)) {
                const scalar_t* const src = in_b + c * in_sc;
                scalar_t& out_val = dst[c * out_sc];

                // Adds one corner's weighted value; corners rejected by
                // within_bounds_3d are skipped.
                const auto add_corner =
                    [&](int64_t z, int64_t y, int64_t x, scalar_t weight) {
                      if (within_bounds_3d(z, y, x, in_depth, in_height, in_width)) {
                        out_val += src[z * in_sd + y * in_sh + x * in_sw] * weight;
                      }
                    };

                out_val = static_cast<scalar_t>(0);
                // The summation order is fixed: tnw, tne, tsw, tse, bnw, bne,
                // bsw, bse.
                add_corner(z0, y0, x0, w_tnw);
                add_corner(z0, y0, x1, w_tne);
                add_corner(z0, y1, x0, w_tsw);
                add_corner(z0, y1, x1, w_tse);
                add_corner(z1, y0, x0, w_bnw);
                add_corner(z1, y0, x1, w_bne);
                add_corner(z1, y1, x0, w_bsw);
                add_corner(z1, y1, x1, w_bse);
              }
            } else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
              const int64_t xn = static_cast<int64_t>(std::nearbyint(ix));
              const int64_t yn = static_cast<int64_t>(std::nearbyint(iy));
              const int64_t zn = static_cast<int64_t>(std::nearbyint(iz));

              // Copy the nearest voxel for every channel, or write 0 when the
              // rounded index is rejected by within_bounds_3d.
              for (const auto c : c10::irange(channels)) {
                const scalar_t* const src = in_b + c * in_sc;
                dst[c * out_sc] =
                    within_bounds_3d(zn, yn, xn, in_depth, in_height, in_width)
                    ? src[zn * in_sd + yn * in_sh + xn * in_sw]
                    : static_cast<scalar_t>(0);
              }
            }
          }
        }
      }
    }
  });
  return output;
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free