Home / Class / index_select_add Class — pytorch Architecture

index_select_add Class — pytorch Architecture

Architecture documentation for the `index_select_add` function template (the overload enabled for `double` data) in `aten/src/ATen/native/EmbeddingBag.cpp` from the PyTorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/EmbeddingBag.cpp lines 110–156

// index_select_add (double overload)
//
// Sums selected rows of `src` into rows of `output`. For each position i,
// row select_indices[i] of `src` is accumulated (alpha = 1, via
// cpublas::axpy) onto row add_indices[i] of `output`. A position whose
// select index equals `padding_idx` contributes nothing to `output`. In that
// case, if `bag_size` is defined, bag_size[add_indices[i]] is decremented.
//
// This overload only participates in overload resolution when data_t is
// double (see the enable_if_t in the return type).
//
// @param select_indices  index_t tensor of row ids into `src`. Each value must
//                        lie in [0, src.size(0)), otherwise TORCH_CHECK fails.
// @param add_indices     index_t tensor with the same numel as select_indices
//                        (checked), giving the destination row in `output`
//                        for each position.
//                        NOTE(review): these values are not bounds-checked
//                        here; presumably the caller guarantees validity.
//                        Confirm at the call site.
// @param src             data_t tensor read through size(0), size(1) and
//                        strides()[0..1].
// @param output          data_t tensor, accumulated into in place.
// @param offsets, include_last_offset, fbgemm_kernel_cache
//                        unused by this overload.
// @param bag_size        optional (may be undefined) index_t tensor of
//                        per-destination counts, decremented for padded positions.
// @param padding_idx     select-index value whose rows are skipped.
template <typename data_t, typename index_t>
std::enable_if_t<std::is_same_v<data_t, double>, void>
index_select_add(
    const Tensor& select_indices,
    const Tensor& add_indices,
    const Tensor& src,
    Tensor& output,
    [[maybe_unused]] const Tensor& offsets,
    [[maybe_unused]] bool include_last_offset,
    Tensor& bag_size,
    index_t padding_idx,
    [[maybe_unused]] _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
  TORCH_CHECK(select_indices.numel() == add_indices.numel());

  const index_t* const dst_rows = add_indices.const_data_ptr<index_t>();
  const index_t* const src_rows = select_indices.const_data_ptr<index_t>();
  const data_t* const src_base = src.const_data_ptr<data_t>();
  data_t* const out_base = output.data_ptr<data_t>();
  // bag_size is optional; a null pointer means there are no counts to adjust.
  index_t* const counts =
      bag_size.defined() ? bag_size.data_ptr<index_t>() : nullptr;

  const auto total = add_indices.numel();
  const int64_t row_len = src.size(1);
  const auto num_rows = src.size(0);
  const auto src_row_step = src.strides()[0];
  const auto src_col_step = src.strides()[1];
  const auto out_row_step = output.strides()[0];
  const auto out_col_step = output.strides()[1];

  for (const auto pos : c10::irange(total)) {
    const auto row = src_rows[pos];
    TORCH_CHECK(
        row >= 0 && row < num_rows,
        "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
        row);
    // In range: pos < total and both index tensors have equal numel.
    const auto dst = dst_rows[pos];
    if (row == padding_idx) {
      // Padded positions are left out of the reduction, so the destination's
      // count shrinks by one instead.
      if (counts != nullptr) {
        --counts[dst];
      }
      continue;
    }
    at::native::cpublas::axpy<data_t>(
        row_len, 1,
        src_base + src_row_step * row, src_col_step,
        out_base + out_row_step * dst, out_col_step);
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free