_embedding_bag_dense_backward_cpu_sum_mean Function — pytorch Architecture
Architecture documentation for the _embedding_bag_dense_backward_cpu_sum_mean function template in EmbeddingBag.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/EmbeddingBag.cpp lines 1509–1594
template <typename scalar_t>
static void _embedding_bag_dense_backward_cpu_sum_mean(
const Tensor& grad,
const Tensor& indices_,
const Tensor& offset2bag_,
const Tensor& bag_size_,
int64_t num_weights,
bool scale_grad_by_freq,
int64_t mode,
const Tensor& per_sample_weights_,
Tensor& index_grad_weight,
int64_t padding_idx) {
auto ind_sort_ = indices_.sort();
auto const& indices = std::get<0>(ind_sort_);
auto const& ind_sort = std::get<1>(ind_sort_);
auto offset2bag = offset2bag_.index_select(0, ind_sort);
std::optional<Tensor> per_sample_weights;
const scalar_t* per_sample_weights_data = nullptr;
std::optional<int64_t> per_sample_weights_stride;
if (per_sample_weights_.defined()) {
per_sample_weights = per_sample_weights_.index_select(0, ind_sort);
per_sample_weights_data = per_sample_weights->const_data_ptr<scalar_t>();
per_sample_weights_stride = per_sample_weights->strides()[0];
}
int64_t numel = indices.numel();
// explicitly capture all required variables to work around windows build
// TODO: fix this when windows can correctly capture variables in nested lambda
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_dense_backward_cpu_sum_mean",
[&indices, &offset2bag, &bag_size_, &num_weights, &numel, &per_sample_weights,
&per_sample_weights_data, &per_sample_weights_stride, &mode, &scale_grad_by_freq,
&grad, &index_grad_weight, &padding_idx] {
auto* indices_data = indices.const_data_ptr<index_t>();
auto* offset2bag_data = offset2bag.const_data_ptr<index_t>();
auto* bag_size_data = bag_size_.const_data_ptr<index_t>();
auto counts = compute_counts(num_weights, indices_data, numel);
auto next_unique_index_idx =
compute_counts_uniq(num_weights, indices_data, numel, counts);
auto loop =
[&next_unique_index_idx, &indices_data, &offset2bag_data, &bag_size_data, &per_sample_weights,
&mode, &per_sample_weights_data, &per_sample_weights_stride, &scale_grad_by_freq,
&counts, &grad, &index_grad_weight, &padding_idx
](index_t start, index_t end) {
for (index_t i = start; i < end; i++) {
index_t indices_start = i == 0 ? 0 : next_unique_index_idx[i - 1];
index_t index = indices_data[indices_start];
if (index != static_cast<index_t>(padding_idx)) {
for (index_t j = indices_start; j < next_unique_index_idx[i]; j++) {
index_t source = offset2bag_data[j];
double scale = 1.0;
if (per_sample_weights) {
AT_ASSERT(mode == EmbeddingBagMode::SUM);
scale = per_sample_weights_data[*per_sample_weights_stride * j];
}
if (scale_grad_by_freq) {
scale /= counts[indices_data[i]];
}
if (mode == EmbeddingBagMode::MEAN) {
auto bag_size = bag_size_data[source];
if (bag_size != 0) {
scale /= bag_size;
}
}
int64_t ddim = grad.size(1);
auto igwd = index_grad_weight.data_ptr<scalar_t>();
auto gd = grad.const_data_ptr<scalar_t>();
at::native::cpublas::axpy<scalar_t>(ddim, (scalar_t)scale, gd + ddim * source, 1,
igwd + ddim * index, 1);
}
}
}
};
if (numel > 1000) {
at::parallel_for(0, (int64_t)next_unique_index_idx.size(), 0, loop);
} else {
loop(0, (int64_t)next_unique_index_idx.size());
}
});
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free