Home / Class/ is_same_v Class — pytorch Architecture

is_same_v Class — pytorch Architecture

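Architecture documentation for the Half/BFloat16 overload of the `index_select_add` function template in EmbeddingBag.cpp from the pytorch codebase. (It is indexed here as the "is_same_v class" because its signature is constrained with `std::is_same_v`; it is a function, not a class.)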

Entity Profile

Source Code

aten/src/ATen/native/EmbeddingBag.cpp lines 188–368

template <typename data_t, typename index_t>
std::enable_if_t<
    std::is_same_v<data_t, at::Half> || std::is_same_v<data_t, at::BFloat16>,
    void>
index_select_add(
    const Tensor& select_indices,
    const Tensor& add_indices,
    const Tensor& src,
    Tensor& output,
    const Tensor& offsets,
    bool include_last_offset,
    Tensor& bag_size,
    index_t padding_idx,
    _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
  int64_t ddim = src.size(1);
  auto* select_indices_data = select_indices.const_data_ptr<index_t>();
  auto* output_data = output.data_ptr<data_t>();

  if (is_fast_path_index_select(src, output, padding_idx)) {
    auto src_contig = src.contiguous();
    auto* src_data = src_contig.const_data_ptr<data_t>();
    int64_t output_size = offsets.numel() - 1;
    auto* offsets_data = offsets.const_data_ptr<index_t>();
    std::vector<index_t> offsets_include_last;

    if (include_last_offset) {
      output_size = offsets.numel() - 1;
    } else {
      output_size = offsets.numel();
      offsets_include_last.resize(offsets.numel() + 1);
      if (offsets.numel() > 0) {
        std::memcpy(
            offsets_include_last.data(),
            offsets.const_data_ptr<index_t>(),
            sizeof(index_t) * offsets.numel());
      }
      offsets_include_last[offsets.numel()] = select_indices.numel();
      offsets_data = offsets_include_last.data();
    }
#if defined(USE_FBGEMM)
    constexpr bool isbf16 = std::is_same_v<data_t, at::Half> ? false : true;
    auto kernel_16bit_index_t = fbgemm_kernel_cache
        ? fbgemm_kernel_cache
              ->getCallback</* has_weight */ false, index_t, uint16_t>(ddim)
        : fbgemm::GenerateEmbeddingSpMDM<uint16_t, index_t, index_t, uint16_t>(
              /* block_size */ ddim,
              /* has_weight */ false,
              /* normalize_by_lengths */ false,
              /* prefetch */ 16,
              /* is_weight_positional */ false,
              /* use_offsets */ true,
              /* is_bf16_out */ isbf16,
              /* is_bf16_in */ isbf16);
    at::parallel_for(
        0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
          bool success = kernel_16bit_index_t(
              /* output_size */ end_idx - start_idx,
              /* index_size */ offsets_data[end_idx] - offsets_data[start_idx],
              /* data_size */ src.size(0),
              /* input */ reinterpret_cast<const uint16_t*>(src_data),
              /* indices */ select_indices_data + offsets_data[start_idx],
              /* offsets_or_lengths */ offsets_data + start_idx,
              /* weights */ nullptr,
              /* output */
              reinterpret_cast<uint16_t*>(output_data + start_idx * ddim));
          if (!success) {
            fbgemm_spmdm_report_error_(
                end_idx - start_idx,
                offsets_data[end_idx] - offsets_data[start_idx],
                src.size(0),
                offsets_data + start_idx,
                select_indices_data + offsets_data[start_idx]);
          }
        });
#else
    // Initialize the intermediate output buffer to be 0.
    Tensor output_fp32 = at::zeros({output_size, ddim}, output.options().dtype(at::kFloat));
    auto* output_data_fp32 = output_fp32.data_ptr<float>();
    using bVec = vec::Vectorized<BFloat16>;
    using fVec = vec::Vectorized<float>;
    at::parallel_for(
        0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
          caffe2::EmbeddingLookupIdx(
              /*block_size=*/ddim,
              /*output_size=*/end_idx - start_idx,
              /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
              /*data_size=*/src.size(0),
              /*input=*/src_data,
              /*indices=*/select_indices_data + offsets_data[start_idx],
              /*offsets=*/offsets_data + start_idx,
              /*weights=*/nullptr,
              /*scale_bias=*/nullptr,
              /*normalize_by_lengths=*/false,
              /*out=*/output_data_fp32 + start_idx * ddim);
          for (int64_t i = start_idx; i < end_idx; i++) {
            // Convert FP32 intermediate buffer result back to 16 bit for
            // output dtype
            if constexpr (std::is_same_v<data_t, at::Half>) {
              // FP16
              for (const auto d : c10::irange(ddim)) {
                (output_data + i * ddim)[d] =
                    static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
              }
            } else {
              // BF16
              int64_t d = 0;
              for (; d < ddim - (ddim % bVec::size()); d += bVec::size()) {
                fVec temp_fp32_0 = fVec::loadu(output_data_fp32 + ddim * i + d);
                fVec temp_fp32_1 =
                    fVec::loadu(output_data_fp32 + ddim * i + d + fVec::size());
                convert_float_bfloat16(temp_fp32_0, temp_fp32_1)
                    .store(output_data + i * ddim + d);
              }
              for (; d < ddim; d++) {
                (output_data + i * ddim)[d] =
                    static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
              }
            }
          }
        });
#endif
  } else {
    TORCH_CHECK(select_indices.numel() == add_indices.numel());
    auto* src_data = src.const_data_ptr<data_t>();
    auto* add_indices_data = add_indices.const_data_ptr<index_t>();
    index_t* bag_size_data = nullptr;
    if (bag_size.defined()) {
      bag_size_data = bag_size.data_ptr<index_t>();
    }
    auto vocab_size = src.size(0);
    auto src_stride0 = src.strides()[0];
    auto src_stride1 = src.strides()[1];
    auto output_stride0 = output.strides()[0];
    auto output_stride1 = output.strides()[1];
    auto numel = add_indices.numel();

    Tensor src_fp32 = at::empty({ddim}, src.options().dtype(at::kFloat));
    auto* src_data_fp32 = src_fp32.mutable_data_ptr<float>();

    // Initialize the intermediate output buffer to be 0.
    Tensor output_fp32 =
        at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat));
    auto* output_data_fp32 = output_fp32.data_ptr<float>();

    for (const auto i : c10::irange(numel)) {
      // We can skip indices equal to padding_idx so they are not included in
      // the reduction
      auto idx = select_indices_data[i];
      TORCH_CHECK(
          idx >= 0 && idx < vocab_size,
          "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
          idx);
      if (idx != padding_idx) {
        // Copy src_data + src_stride0 * idx to src_data_fp32
        for (const auto d : c10::irange(ddim)) {
          src_data_fp32[d] = static_cast<float>(
              (src_data + src_stride0 * idx)[d * src_stride1]);
        }
        at::native::cpublas::axpy<float>(
            ddim,
            1,
            src_data_fp32,
            1,
            output_data_fp32 + ddim * add_indices_data[i],
            1);

      } else if (bag_size_data) {
        // Decrement bag_size to reflect that the index is padded
        bag_size_data[add_indices_data[i]]--;
      }
    }
    for (const auto i : c10::irange(output.size(0))) {
      // Convert FP32 intermediate buffer result back to 16 bit for output
      // dtype
      for (const auto d : c10::irange(ddim)) {
        (output_data + output_stride0 * i)[d * output_stride1] =
            static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
      }
    }
  }
}

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free