embedding_bag_cpu_max_out Function — pytorch Architecture
Architecture documentation for the embedding_bag_cpu_max_out function template in EmbeddingBag.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/EmbeddingBag.cpp lines 1063–1126
// Max-mode reduction for embedding_bag on CPU.
//
// Walks `indices` in order; entry i belongs to bag `offset2bag[i]`. For each
// non-padding entry, every feature of that bag's row in `output` is replaced
// by the matching feature of the selected `weight` row when the latter is
// strictly greater. The first non-padding entry seen for a bag overwrites the
// row unconditionally, so the prior contents of `output` never take part in
// the comparison. Bags that receive no non-padding entry are left untouched
// in both `output` and `max_indices`.
//
// Because the comparison is a strict `>`, ties keep the earlier index.
//
// Parameters:
//   max_indices         Optional (may be nullptr). When given, element
//                       (bag, dim) receives the index whose weight row
//                       supplied the value currently stored in `output`.
//   weight              Embedding table; size(0) is the number of embeddings
//                       and size(1) the feature width. Both strides are
//                       honoured.
//   indices             Indices into `weight`; its dtype selects `index_t`.
//   offset2bag          Bag id for each entry of `indices` (read only).
//   output              Written in place. Rows are addressed with stride(0);
//                       the feature dimension is addressed with unit stride,
//                       as is the feature dimension of `max_indices`.
//   include_last_offset Unused.
//   bag_size            One entry per bag; size(0) is taken as the number of
//                       bags. Decremented once for every padding entry.
//   padding_idx         Entries of `indices` equal to this value are skipped.
//
// Raises (via TORCH_CHECK) if any index is outside [0, weight.size(0)). The
// check runs before the padding test, so it applies to padding entries too.
template <typename scalar_t>
static void embedding_bag_cpu_max_out(
    Tensor* max_indices,
    const Tensor& weight,
    const Tensor& indices,
    const Tensor& offset2bag,
    const Tensor& output,
    [[maybe_unused]] bool include_last_offset,
    Tensor& bag_size,
    int64_t padding_idx) {
  const int64_t numIndices = indices.numel();
  const int64_t featureSize = weight.size(1);
  const int64_t vocab_size = weight.size(0);

  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cpu_max_out", [&] {
    const auto* indices_data = indices.const_data_ptr<index_t>();
    // offset2bag is only read, so it goes through the const accessor like the
    // other read-only inputs instead of requesting mutable access.
    const auto* offset2bag_data = offset2bag.const_data_ptr<index_t>();

    index_t* max_indices_data = nullptr;
    int64_t max_indices_stride = 0;
    if (max_indices) {
      max_indices_data = max_indices->data_ptr<index_t>();
      max_indices_stride = max_indices->strides()[0];
    }

    const auto* weight_data = weight.const_data_ptr<scalar_t>();
    auto* output_data = output.data_ptr<scalar_t>();
    auto* bag_size_data = bag_size.data_ptr<index_t>();

    const int64_t weight_stride0 = weight.strides()[0];
    const int64_t weight_stride1 = weight.strides()[1];
    const int64_t output_stride = output.strides()[0];
    const int64_t numBags = bag_size.size(0);

    // Loop-invariant: convert the padding index to the dispatched index type
    // once rather than on every iteration.
    const auto padding_word = static_cast<index_t>(padding_idx);

    // bag_empty[b] stays true until bag b has seen its first non-padding
    // entry; that entry seeds the running maximum.
    std::vector<bool> bag_empty(numBags, true);

    for (const auto i : c10::irange(numIndices)) {
      // NOTE(review): `bag` indexes output, max_indices, bag_size and
      // bag_empty without a range check here; presumably the caller
      // guarantees offset2bag values lie in [0, numBags) -- confirm.
      const auto bag = offset2bag_data[i];
      const auto word_idx = indices_data[i];
      TORCH_CHECK(
          word_idx >= 0 && word_idx < vocab_size,
          "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
          word_idx);

      if (word_idx == padding_word) {
        // Decrement bag_size to reflect that the index is padded
        bag_size_data[bag]--;
        continue;
      }

      const bool is_first_for_bag = bag_empty[bag];
      for (const auto dim : c10::irange(featureSize)) {
        auto& current_item = output_data[output_stride * bag + dim];
        const auto weight_item =
            weight_data[weight_stride0 * word_idx + dim * weight_stride1];
        if (is_first_for_bag || (weight_item > current_item)) {
          current_item = weight_item;
          if (max_indices_data) {
            max_indices_data[max_indices_stride * bag + dim] = word_idx;
          }
        }
      }
      // The bag now holds at least one real entry.
      bag_empty[bag] = false;
    }
  });
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free