index_select_scale_add Function Template — PyTorch Architecture
Architecture documentation for the index_select_scale_add function template in EmbeddingBag.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/EmbeddingBag.cpp lines 496–549
template <typename data_t, typename index_t>
std::enable_if_t<std::is_same_v<data_t, double>, void>
index_select_scale_add(
const Tensor& select_indices,
const Tensor& add_indices,
const Tensor& scale,
const Tensor& src,
Tensor& output,
[[maybe_unused]] const Tensor& offsets,
[[maybe_unused]] bool include_last_offset,
Tensor& bag_size,
index_t padding_idx,
[[maybe_unused]] _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
AT_ASSERT(select_indices.numel() == add_indices.numel());
auto* add_indices_data = add_indices.const_data_ptr<index_t>();
auto* select_indices_data = select_indices.const_data_ptr<index_t>();
auto* src_data = src.const_data_ptr<data_t>();
auto* output_data = output.data_ptr<data_t>();
index_t* bag_size_data = nullptr;
if (bag_size.defined()) {
bag_size_data = bag_size.data_ptr<index_t>();
}
auto numel = add_indices.numel();
int64_t ddim = src.size(1);
auto vocab_size = src.size(0);
auto src_stride0 = src.strides()[0];
auto src_stride1 = src.strides()[1];
auto output_stride0 = output.strides()[0];
auto output_stride1 = output.strides()[1];
auto* scale_data = scale.const_data_ptr<data_t>();
auto scale_stride = scale.strides()[0];
for (const auto i : c10::irange(numel)) {
// We can skip indices equal to padding_idx so they are not included in
// the reduction
auto idx = select_indices_data[i];
TORCH_CHECK(
idx >= 0 && idx < vocab_size,
"embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
idx);
if (idx != padding_idx) {
auto* src_base = src_data + src_stride0 * idx;
auto* output_base = output_data + output_stride0 * add_indices_data[i];
auto element_scale = scale_data[i * scale_stride];
for (const auto j : c10::irange(ddim)) {
output_base[j * output_stride1] += src_base[j * src_stride1] * element_scale;
}
} else if (bag_size_data) {
// Decrement bag_size to reflect that the index is padded
bag_size_data[add_indices_data[i]]--;
}
}
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free