_embedding_bag_per_sample_weights_backward_cpu_template Function — pytorch Architecture
Architecture documentation for the _embedding_bag_per_sample_weights_backward_cpu_template function template in EmbeddingBag.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/EmbeddingBag.cpp lines 1645–1729
// Backward pass of embedding_bag w.r.t. per_sample_weights (CPU).
// For mode='sum' (the only supported mode), each sample's output is
//   output[i] = dot(grad[bag(i), :], weight[indices[i], :])
// i.e. the derivative of the weighted-sum reduction w.r.t. that sample's
// scalar weight.
//
// grad:        2-D gradient of the bag outputs, [num_bags, embedding_features]
// weight:      2-D embedding table (NOT per_sample_weights); its second dim
//              must equal grad's second dim
// indices_:    1-D sample-to-embedding-row indices (kLong or kInt)
// offsets_:    bag offsets; only consulted when offset2bag must be rebuilt
// offset2bag:  per-sample bag index; may be empty, in which case it is
//              reconstructed here from offsets
// mode:        must be EmbeddingBagMode::SUM (checked below)
// padding_idx: samples whose embedding index equals this value get a zero
//              gradient (their output slot is left at its zero initialization)
//
// Returns a 1-D tensor of length indices.size(0) with grad.options().
template<typename scalar_t>
static Tensor _embedding_bag_per_sample_weights_backward_cpu_template(
const Tensor& grad,
const Tensor& weight, // NB: embedding table, not per_sample_weights
const Tensor& indices_,
const Tensor& offsets_,
const Tensor& offset2bag,
int64_t mode,
int64_t padding_idx) {
TORCH_CHECK(
mode == EmbeddingBagMode::SUM,
"embedding_bag_backward: per_sample_weights only supported for mode='sum'");
AT_ASSERT(grad.dim() == 2);
auto embedding_features = grad.sizes()[1];
// Promote indices/offsets to a common dtype; MaybeOwned avoids a copy when
// no promotion is needed.
auto [indicesMaybeOwned, offsetsMaybeOwned] = promoteIndicesAndOffsets(indices_, offsets_);
const auto& indices = *indicesMaybeOwned;
const auto& offsets = *offsetsMaybeOwned;
AT_ASSERT(indices.dim() == 1);
auto num_samples = indices.size(0);
AT_ASSERT(weight.dim() == 2);
AT_ASSERT(weight.sizes()[1] == embedding_features);
// Zero-initialized so padded samples (skipped in the loop below) stay 0.
auto output = at::zeros({num_samples}, grad.options());
auto indices_arg = TensorArg(indices, "indices", 1);
checkScalarTypes("embedding_bag", indices_arg, {kLong, kInt});
checkContiguous("embedding_bag", indices_arg);
Tensor offset2bag_;
if (indices.numel() != 0 && offset2bag.numel() == 0) {
// No precomputed offset2bag: rebuild it from offsets. The buffer is
// allocated one element longer than needed because make_offset2bag
// writes markers at bag-start positions (including one past the last
// sample) before the cumulative sum; the extra slot is trimmed below.
offset2bag_ = at::zeros(
{indices.size(0) + 1}, offset2bag.options()); // offset2bag = [0 0 0 0 0]
make_offset2bag(offsets, offset2bag_);
at::native::resize_(offset2bag_, {indices.size(0)}, std::nullopt);
} else {
auto offset2bag_arg = TensorArg(offset2bag, "offset2bag", 1);
checkScalarTypes("embedding_bag", offset2bag_arg, {kLong, kInt});
checkContiguous("embedding_bag", offset2bag_arg);
offset2bag_ = offset2bag;
}
// Raw pointers + explicit strides so the inner loop can index arbitrary
// (possibly non-contiguous) 2-D layouts of grad and weight.
auto* grad_data = grad.const_data_ptr<scalar_t>();
auto grad_stride0 = grad.strides()[0];
auto grad_stride1 = grad.strides()[1];
auto* weight_data = weight.const_data_ptr<scalar_t>();
auto weight_stride0 = weight.strides()[0];
auto weight_stride1 = weight.strides()[1];
// explicitly capture all required variables to work around windows build
// TODO: fix this when windows can correctly capture variables in nested lambda
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_cpu_template",
[&indices, &output, &offset2bag_, &num_samples, &embedding_features,
&grad_data, &grad_stride0, &grad_stride1, &weight_data, &weight_stride0, &weight_stride1,
&padding_idx] () {
auto* indices_data = indices.const_data_ptr<index_t>();
// The following are contiguous
auto* output_data = output.data_ptr<scalar_t>();
auto* offset2bag_data = offset2bag_.const_data_ptr<index_t>();
// XXX: 64 was arbitrarily chosen. There is probably a sweet spot for this number.
parallel_for(0, num_samples, 64,
[&embedding_features, &grad_data, &grad_stride0, &grad_stride1, &weight_data, &weight_stride0,
&weight_stride1, &offset2bag_data, &indices_data, &output_data, &padding_idx](index_t begin, index_t end) {
// Each sample's gradient is independent, so samples can be
// processed in parallel chunks with no synchronization.
for (index_t sample_idx = begin; sample_idx < end; sample_idx++) {
auto bag_idx = offset2bag_data[sample_idx];
auto embedding_idx = indices_data[sample_idx];
// Padded samples contribute no gradient; their output stays 0.
if (embedding_idx != static_cast<index_t>(padding_idx)) {
output_data[sample_idx] = dot_impl<scalar_t>(
embedding_features, grad_data + grad_stride0 * bag_idx, grad_stride1,
weight_data + weight_stride0 * embedding_idx, weight_stride1);
}
}
});
});
return output;
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free