infer_size_impl Class — pytorch Architecture

Architecture documentation for the infer_size_impl function template in ExpandUtils.cpp from the pytorch codebase.

Source Code

aten/src/ATen/ExpandUtils.cpp lines 16–42

template <typename Container, typename ArrayType>
Container infer_size_impl(ArrayType a, ArrayType b) {
  // Use ptrdiff_t to ensure signed comparison.
  auto dimsA = static_cast<ptrdiff_t>(a.size());
  auto dimsB = static_cast<ptrdiff_t>(b.size());
  auto ndim = dimsA > dimsB ? dimsA : dimsB;
  Container expandedSizes(ndim);

  for (ptrdiff_t i = ndim - 1; i >= 0; --i) {
    ptrdiff_t offset = ndim - 1 - i;
    ptrdiff_t dimA = dimsA - 1 - offset;
    ptrdiff_t dimB = dimsB - 1 - offset;
    auto sizeA = (dimA >= 0) ? a[dimA] : 1;
    auto sizeB = (dimB >= 0) ? b[dimB] : 1;

    TORCH_CHECK(
        sizeA == sizeB || sizeA == 1 || sizeB == 1,
        "The size of tensor a (", sizeA,
        ") must match the size of tensor b (", sizeB,
        ") at non-singleton dimension ", i);

    // 1s map to the other size (even 0).
    expandedSizes[i] = sizeA == 1 ? sizeB : sizeA;
  }

  return expandedSizes;
}
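
The listing above right-aligns the two shapes, treats missing leading dimensions as size 1, and lets a size of 1 stretch to match the other operand. The block below is a minimal standalone sketch of that same rule, written so it can be compiled without ATen. The name broadcast_sizes is hypothetical and not part of pytorch, std::vector<std::int64_t> stands in for the Container and ArrayType template parameters, and std::invalid_argument stands in for TORCH_CHECK; these are assumptions made for illustration only.

```cpp
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not part of ATen): applies the same right-aligned
// broadcasting rule as infer_size_impl using only the standard library.
std::vector<std::int64_t> broadcast_sizes(
    const std::vector<std::int64_t>& a,
    const std::vector<std::int64_t>& b) {
  // Signed dimension counts, so index arithmetic below may go negative.
  const auto dimsA = static_cast<std::ptrdiff_t>(a.size());
  const auto dimsB = static_cast<std::ptrdiff_t>(b.size());
  const auto ndim = dimsA > dimsB ? dimsA : dimsB;
  std::vector<std::int64_t> out(static_cast<std::size_t>(ndim));

  // Walk from the trailing dimension towards the leading one.
  for (std::ptrdiff_t i = ndim - 1; i >= 0; --i) {
    const std::ptrdiff_t offset = ndim - 1 - i;
    const std::ptrdiff_t dimA = dimsA - 1 - offset;
    const std::ptrdiff_t dimB = dimsB - 1 - offset;
    // A shape that has run out of dimensions contributes size 1.
    const std::int64_t sizeA =
        dimA >= 0 ? a[static_cast<std::size_t>(dimA)] : 1;
    const std::int64_t sizeB =
        dimB >= 0 ? b[static_cast<std::size_t>(dimB)] : 1;

    // Assumption: an exception replaces TORCH_CHECK in this sketch.
    if (!(sizeA == sizeB || sizeA == 1 || sizeB == 1)) {
      std::ostringstream msg;
      msg << "The size of tensor a (" << sizeA
          << ") must match the size of tensor b (" << sizeB
          << ") at non-singleton dimension " << i;
      throw std::invalid_argument(msg.str());
    }

    // A size of 1 maps to the other size, even when that size is 0.
    out[static_cast<std::size_t>(i)] = sizeA == 1 ? sizeB : sizeA;
  }
  return out;
}
```

Under these assumptions, broadcast_sizes({3, 1}, {4}) yields {3, 4}, broadcast_sizes({2, 1, 5}, {3, 1}) yields {2, 3, 5}, and broadcast_sizes({0}, {1}) yields {0}. Calling it with {2} and {3} throws, because neither size is 1 and the two differ.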