index_select_add (Half/BFloat16 overload) — pytorch Architecture
Architecture documentation for the index_select_add overload in EmbeddingBag.cpp from the pytorch codebase, the overload selected via std::enable_if_t and std::is_same_v when the weight dtype is at::Half or at::BFloat16.
Entity Profile
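This entity is a SFINAE-constrained overload of the index_select_add helper: the std::enable_if_t / std::is_same_v condition in its return type enables it only for data_t = at::Half or data_t = at::BFloat16, so it is instantiated when EmbeddingBag's CPU forward accumulates 16-bit floating-point weights. The function has two code paths: a fast path that hands whole bags to a fused FBGEMM kernel (or a caffe2 lookup when FBGEMM is unavailable), and a general path that accumulates each selected row into an FP32 buffer and converts the result back to the 16-bit output dtype at the end.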
Source Code
aten/src/ATen/native/EmbeddingBag.cpp lines 188–368
template <typename data_t, typename index_t>
std::enable_if_t<
    std::is_same_v<data_t, at::Half> || std::is_same_v<data_t, at::BFloat16>,
    void>
index_select_add(
    const Tensor& select_indices,
    const Tensor& add_indices,
    const Tensor& src,
    Tensor& output,
    const Tensor& offsets,
    bool include_last_offset,
    Tensor& bag_size,
    index_t padding_idx,
    _EmbeddingBagKernelCache* fbgemm_kernel_cache) {
  int64_t ddim = src.size(1);
  auto* select_indices_data = select_indices.const_data_ptr<index_t>();
  auto* output_data = output.data_ptr<data_t>();
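  // Fast path: the whole reduction can be handed to a fused FBGEMM / caffe2
  // kernel. The else branch further below handles arbitrary strides and skips
  // padding_idx entries with an explicit per-index accumulation loop.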
  if (is_fast_path_index_select(src, output, padding_idx)) {
    auto src_contig = src.contiguous();
    auto* src_data = src_contig.const_data_ptr<data_t>();
    int64_t output_size = offsets.numel() - 1;
    auto* offsets_data = offsets.const_data_ptr<index_t>();
    std::vector<index_t> offsets_include_last;
    if (include_last_offset) {
      output_size = offsets.numel() - 1;
    } else {
      output_size = offsets.numel();
      offsets_include_last.resize(offsets.numel() + 1);
      if (offsets.numel() > 0) {
        std::memcpy(
            offsets_include_last.data(),
            offsets.const_data_ptr<index_t>(),
            sizeof(index_t) * offsets.numel());
      }
      offsets_include_last[offsets.numel()] = select_indices.numel();
      offsets_data = offsets_include_last.data();
    }
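    // With FBGEMM available, a JIT-generated EmbeddingSpMDM kernel consumes
    // the 16-bit rows directly; otherwise caffe2::EmbeddingLookupIdx sums into
    // an FP32 scratch buffer that is converted back to Half/BFloat16 below.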
#if defined(USE_FBGEMM)
    constexpr bool isbf16 = std::is_same_v<data_t, at::Half> ? false : true;
    auto kernel_16bit_index_t = fbgemm_kernel_cache
        ? fbgemm_kernel_cache
              ->getCallback</* has_weight */ false, index_t, uint16_t>(ddim)
        : fbgemm::GenerateEmbeddingSpMDM<uint16_t, index_t, index_t, uint16_t>(
              /* block_size */ ddim,
              /* has_weight */ false,
              /* normalize_by_lengths */ false,
              /* prefetch */ 16,
              /* is_weight_positional */ false,
              /* use_offsets */ true,
              /* is_bf16_out */ isbf16,
              /* is_bf16_in */ isbf16);
    at::parallel_for(
        0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
          bool success = kernel_16bit_index_t(
              /* output_size */ end_idx - start_idx,
              /* index_size */ offsets_data[end_idx] - offsets_data[start_idx],
              /* data_size */ src.size(0),
              /* input */ reinterpret_cast<const uint16_t*>(src_data),
              /* indices */ select_indices_data + offsets_data[start_idx],
              /* offsets_or_lengths */ offsets_data + start_idx,
              /* weights */ nullptr,
              /* output */
              reinterpret_cast<uint16_t*>(output_data + start_idx * ddim));
          if (!success) {
            fbgemm_spmdm_report_error_(
                end_idx - start_idx,
                offsets_data[end_idx] - offsets_data[start_idx],
                src.size(0),
                offsets_data + start_idx,
                select_indices_data + offsets_data[start_idx]);
          }
        });
#else
    // Initialize the intermediate output buffer to be 0.
    Tensor output_fp32 =
        at::zeros({output_size, ddim}, output.options().dtype(at::kFloat));
    auto* output_data_fp32 = output_fp32.data_ptr<float>();
    using bVec = vec::Vectorized<BFloat16>;
    using fVec = vec::Vectorized<float>;
    at::parallel_for(
        0, output_size, 1, [&](index_t start_idx, index_t end_idx) {
          caffe2::EmbeddingLookupIdx(
              /*block_size=*/ddim,
              /*output_size=*/end_idx - start_idx,
              /*index_size=*/offsets_data[end_idx] - offsets_data[start_idx],
              /*data_size=*/src.size(0),
              /*input=*/src_data,
              /*indices=*/select_indices_data + offsets_data[start_idx],
              /*offsets=*/offsets_data + start_idx,
              /*weights=*/nullptr,
              /*scale_bias=*/nullptr,
              /*normalize_by_lengths=*/false,
              /*out=*/output_data_fp32 + start_idx * ddim);
          for (int64_t i = start_idx; i < end_idx; i++) {
            // Convert FP32 intermediate buffer result back to 16 bit for
            // output dtype
            if constexpr (std::is_same_v<data_t, at::Half>) {
              // FP16
              for (const auto d : c10::irange(ddim)) {
                (output_data + i * ddim)[d] =
                    static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
              }
            } else {
              // BF16
              int64_t d = 0;
              for (; d < ddim - (ddim % bVec::size()); d += bVec::size()) {
                fVec temp_fp32_0 = fVec::loadu(output_data_fp32 + ddim * i + d);
                fVec temp_fp32_1 =
                    fVec::loadu(output_data_fp32 + ddim * i + d + fVec::size());
                convert_float_bfloat16(temp_fp32_0, temp_fp32_1)
                    .store(output_data + i * ddim + d);
              }
              for (; d < ddim; d++) {
                (output_data + i * ddim)[d] =
                    static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
              }
            }
          }
        });
#endif
  } else {
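    // Slow path: honor arbitrary src/output strides and skip entries equal to
    // padding_idx, accumulating each selected row into an FP32 buffer via
    // cpublas::axpy before converting back to the 16-bit output dtype.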
    TORCH_CHECK(select_indices.numel() == add_indices.numel());
    auto* src_data = src.const_data_ptr<data_t>();
    auto* add_indices_data = add_indices.const_data_ptr<index_t>();
    index_t* bag_size_data = nullptr;
    if (bag_size.defined()) {
      bag_size_data = bag_size.data_ptr<index_t>();
    }
    auto vocab_size = src.size(0);
    auto src_stride0 = src.strides()[0];
    auto src_stride1 = src.strides()[1];
    auto output_stride0 = output.strides()[0];
    auto output_stride1 = output.strides()[1];
    auto numel = add_indices.numel();
    Tensor src_fp32 = at::empty({ddim}, src.options().dtype(at::kFloat));
    auto* src_data_fp32 = src_fp32.mutable_data_ptr<float>();
    // Initialize the intermediate output buffer to be 0.
    Tensor output_fp32 =
        at::zeros({output.size(0), ddim}, output.options().dtype(at::kFloat));
    auto* output_data_fp32 = output_fp32.data_ptr<float>();
    for (const auto i : c10::irange(numel)) {
      // We can skip indices equal to padding_idx so they are not included in
      // the reduction
      auto idx = select_indices_data[i];
      TORCH_CHECK(
          idx >= 0 && idx < vocab_size,
          "embedding_bag: Expected idx >= 0 && idx < num_embeddings but found idx to be ",
          idx);
      if (idx != padding_idx) {
        // Copy src_data + src_stride0 * idx to src_data_fp32
        for (const auto d : c10::irange(ddim)) {
          src_data_fp32[d] = static_cast<float>(
              (src_data + src_stride0 * idx)[d * src_stride1]);
        }
        at::native::cpublas::axpy<float>(
            ddim,
            1,
            src_data_fp32,
            1,
            output_data_fp32 + ddim * add_indices_data[i],
            1);
      } else if (bag_size_data) {
        // Decrement bag_size to reflect that the index is padded
        bag_size_data[add_indices_data[i]]--;
      }
    }
    for (const auto i : c10::irange(output.size(0))) {
      // Convert FP32 intermediate buffer result back to 16 bit for output
      // dtype
      for (const auto d : c10::irange(ddim)) {
        (output_data + output_stride0 * i)[d * output_stride1] =
            static_cast<data_t>((output_data_fp32 + ddim * i)[d]);
      }
    }
  }
}
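EmbeddingBag.cpp defines sibling index_select_add overloads for other weight dtypes (for example float), and the std::enable_if_t / std::is_same_v condition in the return type decides at compile time which overload the caller's dtype dispatch instantiates. Below is a minimal, self-contained sketch of that selection pattern; the half_t and bfloat16_t tag types and the process() helpers are illustrative stand-ins, not part of the pytorch API.

// Minimal sketch of the enable_if_t / is_same_v gating used above (C++17).
// half_t / bfloat16_t stand in for at::Half / at::BFloat16.
#include <iostream>
#include <type_traits>

struct half_t {};
struct bfloat16_t {};

// Enabled only for the 16-bit stand-ins, mirroring the reduced-precision
// index_select_add overload.
template <typename T>
std::enable_if_t<
    std::is_same_v<T, half_t> || std::is_same_v<T, bfloat16_t>,
    void>
process() {
  std::cout << "16-bit path: accumulate in FP32, convert back on store\n";
}

// Enabled for every other type, mirroring the full-precision overload.
template <typename T>
std::enable_if_t<
    !(std::is_same_v<T, half_t> || std::is_same_v<T, bfloat16_t>),
    void>
process() {
  std::cout << "full-precision path\n";
}

int main() {
  process<half_t>();  // selects the 16-bit path
  process<float>();   // selects the full-precision path
}

Because the gate lives in the function's return type, an overload whose condition is false simply drops out of overload resolution, so the dtype choice costs nothing at runtime inside the kernels.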