ComputeLocationBase Class — pytorch Architecture
Architecture documentation for the ComputeLocationBase struct template (shown here: its align_corners=true partial specialization) in GridSamplerKernel.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/cpu/GridSamplerKernel.cpp lines 192–272
template<typename scalar_t>
struct ComputeLocationBase<scalar_t, /*align_corners=*/true> {
using Vec = Vectorized<scalar_t>;
// values are clipped to between 0 and max_val
const scalar_t max_val;
// unnormalization scaling factor
const scalar_t scaling_factor;
// reflection parameters: reflected coordinates land in [low, low+span] inclusive
const scalar_t low; // only used when align_corners=False
const scalar_t twice_span;
// if the reflecting span is empty, all reflected coords are set to 0
const bool empty;
ComputeLocationBase(int64_t size)
: max_val(static_cast<scalar_t>(size - 1))
, scaling_factor(static_cast<scalar_t>(size - 1) / 2)
, low(static_cast<scalar_t>(0))
, twice_span(static_cast<scalar_t>(size - 1) * 2)
, empty(size <= 1) {}
inline Vec unnormalize(const Vec &in) const {
return (in + Vec(1)) * Vec(scaling_factor);
}
inline Vec clip_coordinates(const Vec &in) const {
// Invert order of clamp_min operands in order to clamp Nans to zero
return clamp_max(Vec(max_val), clamp_min(Vec(0), in));
}
// same as clip_coordinates but also returns the gradient multiplier
inline std::pair<Vec, Vec> clip_coordinates_get_grad(const Vec &in) const {
using int_t = int_same_size_t<scalar_t>;
auto bounded_lo = maximum(in, Vec(0));
// Integral type equality comparison is very very fast because it just looks
// at the bits. Casting is free too. So we use the following pattern instead
// of comparison + blendv.
// Note that it is important for the gradient calculation that borders
// are considered out of bounds.
auto in_bound_lo = cast<scalar_t>(cast<int_t>(bounded_lo) != cast<int_t>(Vec(0)));
auto res = minimum(bounded_lo, Vec(max_val));
auto in_bound_hi = cast<scalar_t>(cast<int_t>(res) != cast<int_t>(Vec(max_val)));
return std::make_pair(res, in_bound_lo & in_bound_hi);
}
inline Vec reflect_coordinates(const Vec &in) const {
if (empty) {
return Vec(0);
}
Vec twice_span_vec(twice_span);
auto abs_in = in.abs();
auto fdouble_flips = abs_in / twice_span_vec;
auto double_flips = fdouble_flips.trunc();
auto extra = abs_in - double_flips * twice_span_vec;
// Now we need to test if extra > max_val to find out if another flip is
// needed. The following comparison does that and returns the correct
// flipped value.
return minimum(extra, twice_span_vec - extra);
}
// same as reflect_coordinates but also returns the gradient multiplier
inline std::pair<Vec, Vec> reflect_coordinates_get_grad(const Vec &in) const {
if (empty) {
return std::make_pair(Vec(0), Vec(0));
}
Vec twice_span_vec(twice_span);
auto neg_in = in < Vec(0);
auto abs_in = in.abs();
auto fdouble_flips = abs_in / twice_span_vec;
auto double_flips = fdouble_flips.trunc();
auto extra = abs_in - double_flips * twice_span_vec;
auto reflected_extra = twice_span_vec - extra;
auto one_more_flip = extra > reflected_extra;
return std::make_pair(
Vec::blendv(extra, reflected_extra, one_more_flip),
Vec::blendv(Vec(1), Vec(-1), one_more_flip ^ neg_in)
);
}
};
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free