batch_norm_cpu_contiguous_impl — pytorch Architecture
Architecture documentation for the batch_norm_cpu_contiguous_impl function template in batch_norm_kernel.cpp from the pytorch codebase. (The std::is_same_v variable template appears only in the function's enable_if_t SFINAE constraint, which selects this overload for reduced-precision scalar types.)
Entity Profile
Source Code
aten/src/ATen/native/cpu/batch_norm_kernel.cpp lines 708–770
// Contiguous (NCHW) batch-norm inference/training apply step for reduced-precision
// scalar types (e.g. BFloat16 / Half). This overload is selected via SFINAE only
// when scalar_t differs from its opmath type, so per-element math is done in the
// wider opmath_t (float) while data is loaded/stored as scalar_t.
//
// Computes output = input * alpha + beta per channel, where alpha/beta fold the
// normalization (mean, invstd) and affine (weight, bias) terms; the folding is
// delegated to batch_norm_cpu_collect_linear_and_constant_terms.
template<typename scalar_t>
typename std::enable_if_t<!std::is_same_v<scalar_t, at::opmath_type<scalar_t>>, void>
batch_norm_cpu_contiguous_impl(Tensor& output, const Tensor& input,
    const Tensor& weight, const Tensor& bias, const Tensor& save_mean, const Tensor& save_invstd,
    const Tensor& running_mean, const Tensor& running_var, bool train, double eps) {
  using opmath_t = at::opmath_type<scalar_t>;
  using bVec = Vectorized<scalar_t>;
  using fVec = Vectorized<opmath_t>;
  int64_t n_batch = input.size(0);
  int64_t n_channel = input.size(1);
  int64_t image_size = input.numel() / n_batch / n_channel;

  // Use float as the accumulation type: for the reduced-precision scalar types
  // that reach this overload, opmath_t is float, so kFloat storage matches the
  // opmath_t pointers taken below.
  Tensor alpha = at::empty({n_channel}, input.options().dtype(kFloat));
  Tensor beta = at::empty({n_channel}, input.options().dtype(kFloat));
  // Both buffers are written by the collect step below, so take mutable pointers
  // for both (previously beta used data_ptr, inconsistent with alpha).
  opmath_t* alpha_data = alpha.mutable_data_ptr<opmath_t>();
  opmath_t* beta_data = beta.mutable_data_ptr<opmath_t>();

  // If any of the stat/affine tensors is of a wider dtype than the input,
  // fold the terms reading them as opmath_t; otherwise read them as scalar_t.
  const bool mixed_type = is_mixed_type(input, weight, bias, save_mean, save_invstd, running_mean, running_var);
  if (mixed_type) {
    batch_norm_cpu_collect_linear_and_constant_terms<opmath_t, opmath_t>(
        alpha_data, beta_data, n_channel, weight, bias,
        save_mean, save_invstd, running_mean, running_var, train, eps);
  } else {
    batch_norm_cpu_collect_linear_and_constant_terms<scalar_t, opmath_t>(
        alpha_data, beta_data, n_channel, weight, bias,
        save_mean, save_invstd, running_mean, running_var, train, eps);
  }

  scalar_t* output_data = output.data_ptr<scalar_t>();
  const scalar_t* input_data = input.const_data_ptr<scalar_t>();

  // Parallelize over (batch, channel) pairs; each task handles whole images so
  // alpha/beta are constant within the inner loops.
  const int64_t loop_size = image_size - (image_size % bVec::size());
  at::parallel_for(0, n_batch * n_channel, 1, [&](int64_t begin, int64_t end) {
    int64_t n = 0;
    int64_t c = 0;
    data_index_init(begin, n, n_batch, c, n_channel);

    for (const auto i : c10::irange(begin, end)) {
      const scalar_t* input_ptr = input_data + i * image_size;
      scalar_t* output_ptr = output_data + i * image_size;
      const opmath_t alpha_val = alpha_data[c];
      const opmath_t beta_val = beta_data[c];
      const fVec alpha_fvec(alpha_val);
      const fVec beta_fvec(beta_val);
      int64_t d = 0;
      // Vectorized main loop: one scalar_t vector widens into two opmath_t
      // vectors, which are fused back after the multiply-add.
      for (; d < loop_size; d += bVec::size()) {
        bVec data_bvec = bVec::loadu(input_ptr + d);
        auto [data_fvec0, data_fvec1] = convert_to_float<scalar_t>(data_bvec);

        fVec out_fvec0 = data_fvec0 * alpha_fvec + beta_fvec;
        fVec out_fvec1 = data_fvec1 * alpha_fvec + beta_fvec;

        bVec out_bvec = convert_from_float<scalar_t>(out_fvec0, out_fvec1);
        out_bvec.store(output_ptr + d);
      }
      // Scalar tail for the remaining image_size % bVec::size() elements,
      // still computed in opmath_t before narrowing back.
      for (; d < image_size; d++) {
        output_ptr[d] = scalar_t(opmath_t(input_ptr[d]) * alpha_val + beta_val);
      }
      // move on to next index
      data_index_step(n, n_batch, c, n_channel);
    }
  });
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free