Home / Class / _embedding_bag_dense_backward_cpu_sum_mean Class — pytorch Architecture

_embedding_bag_dense_backward_cpu_sum_mean Class — pytorch Architecture

Architecture documentation for the _embedding_bag_dense_backward_cpu_sum_mean function template (a static, file-local helper) in EmbeddingBag.cpp from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/EmbeddingBag.cpp lines 1509–1594

template <typename scalar_t>
static void _embedding_bag_dense_backward_cpu_sum_mean(
    const Tensor& grad,
    const Tensor& indices_,
    const Tensor& offset2bag_,
    const Tensor& bag_size_,
    int64_t num_weights,
    bool scale_grad_by_freq,
    int64_t mode,
    const Tensor& per_sample_weights_,
    Tensor& index_grad_weight,
    int64_t padding_idx) {

  auto ind_sort_ = indices_.sort();
  auto const& indices = std::get<0>(ind_sort_);
  auto const& ind_sort = std::get<1>(ind_sort_);
  auto offset2bag = offset2bag_.index_select(0, ind_sort);

  std::optional<Tensor> per_sample_weights;
  const scalar_t* per_sample_weights_data = nullptr;
  std::optional<int64_t> per_sample_weights_stride;
  if (per_sample_weights_.defined()) {
    per_sample_weights = per_sample_weights_.index_select(0, ind_sort);
    per_sample_weights_data = per_sample_weights->const_data_ptr<scalar_t>();
    per_sample_weights_stride = per_sample_weights->strides()[0];
  }

  int64_t numel = indices.numel();

  // explicitly capture all required variables to work around windows build
  // TODO: fix this when windows can correctly capture variables in nested lambda
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_dense_backward_cpu_sum_mean",
    [&indices, &offset2bag, &bag_size_, &num_weights, &numel, &per_sample_weights,
      &per_sample_weights_data, &per_sample_weights_stride, &mode, &scale_grad_by_freq,
      &grad, &index_grad_weight, &padding_idx] {
    auto* indices_data = indices.const_data_ptr<index_t>();
    auto* offset2bag_data = offset2bag.const_data_ptr<index_t>();
    auto* bag_size_data = bag_size_.const_data_ptr<index_t>();

    auto counts = compute_counts(num_weights, indices_data, numel);
    auto next_unique_index_idx =
        compute_counts_uniq(num_weights, indices_data, numel, counts);

    auto loop =
      [&next_unique_index_idx, &indices_data, &offset2bag_data, &bag_size_data, &per_sample_weights,
        &mode, &per_sample_weights_data, &per_sample_weights_stride, &scale_grad_by_freq,
        &counts, &grad, &index_grad_weight, &padding_idx
      ](index_t start, index_t end) {
      for (index_t i = start; i < end; i++) {
        index_t indices_start = i == 0 ? 0 : next_unique_index_idx[i - 1];
        index_t index = indices_data[indices_start];

        if (index != static_cast<index_t>(padding_idx)) {
          for (index_t j = indices_start; j < next_unique_index_idx[i]; j++) {
            index_t source = offset2bag_data[j];
            double scale = 1.0;
            if (per_sample_weights) {
              AT_ASSERT(mode == EmbeddingBagMode::SUM);
              scale = per_sample_weights_data[*per_sample_weights_stride * j];
            }
            if (scale_grad_by_freq) {
              scale /= counts[indices_data[i]];
            }
            if (mode == EmbeddingBagMode::MEAN) {
              auto bag_size = bag_size_data[source];
              if (bag_size != 0) {
                scale /= bag_size;
              }
            }
            int64_t ddim = grad.size(1);
            auto igwd = index_grad_weight.data_ptr<scalar_t>();
            auto gd = grad.const_data_ptr<scalar_t>();
            at::native::cpublas::axpy<scalar_t>(ddim, (scalar_t)scale, gd + ddim * source, 1,
                        igwd + ddim * index, 1);
          }
        }
      }
    };

    if (numel > 1000) {
      at::parallel_for(0, (int64_t)next_unique_index_idx.size(), 0, loop);
    } else {
      loop(0, (int64_t)next_unique_index_idx.size());
    }
  });
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free