infer_size_impl Function Template — PyTorch Architecture
Architecture documentation for the infer_size_impl function template in ExpandUtils.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/ExpandUtils.cpp lines 16–42
/// Computes the shape produced by broadcasting shape `a` against shape `b`.
///
/// The two shapes are aligned at their trailing dimension; an operand with
/// fewer dimensions behaves as if it had extra size-1 dimensions in front.
/// At each output position the two sizes must be equal or one of them must
/// be 1, and the output takes the non-1 size (a size of 1 adopts the other
/// operand's size, including 0).
///
/// @tparam Container  Result type; constructed from the output rank and
///                    filled through `operator[]`.
/// @tparam ArrayType  Shape type; must provide `size()` and `operator[]`.
/// @param a  First shape.
/// @param b  Second shape.
/// @return   The broadcast shape, whose length is the larger of the two ranks.
/// Fails a TORCH_CHECK (naming both sizes and the output dimension) when the
/// shapes are not broadcast-compatible. Dimensions are checked from the last
/// one to the first.
template <typename Container, typename ArrayType>
Container infer_size_impl(ArrayType a, ArrayType b) {
  // Signed lengths let a negative index mean "this operand has no
  // dimension at this position".
  const auto rankA = static_cast<ptrdiff_t>(a.size());
  const auto rankB = static_cast<ptrdiff_t>(b.size());
  const ptrdiff_t outRank = (rankA < rankB) ? rankB : rankA;

  Container result(outRank);

  // `back` counts positions from the trailing end, so output dimensions are
  // visited in the order outRank-1, outRank-2, ..., 0.
  for (ptrdiff_t back = 0; back < outRank; ++back) {
    const ptrdiff_t outDim = outRank - 1 - back;
    const ptrdiff_t idxA = rankA - 1 - back;
    const ptrdiff_t idxB = rankB - 1 - back;

    // A missing leading dimension behaves as size 1.
    auto sizeA = (idxA >= 0) ? a[idxA] : 1;
    auto sizeB = (idxB >= 0) ? b[idxB] : 1;

    TORCH_CHECK(
        sizeA == sizeB || sizeA == 1 || sizeB == 1,
        "The size of tensor a (", sizeA,
        ") must match the size of tensor b (", sizeB,
        ") at non-singleton dimension ", outDim);

    // A size of 1 adopts the other operand's size (even 0).
    if (sizeA == 1) {
      result[outDim] = sizeB;
    } else {
      result[outDim] = sizeA;
    }
  }

  return result;
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free