collapse_dims Function — PyTorch Architecture
Architecture documentation for the collapse_dims function template in CollapseDims.h from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/CollapseDims.h lines 22–92
/**
 * Collapses, in place, the dimension description held in `sizes` / `strides`
 * into as few dimensions as this routine can produce.
 *
 * Size-1 dimensions are dropped. A dimension is folded into the previously
 * emitted one whenever that one's stride equals `size * stride` of the
 * current dimension; otherwise it is emitted as a new output dimension.
 * The dimension at `excludeDim` (if not -1) is always emitted on its own and
 * never merged with its neighbours.
 *
 * @param sizes       Array of at least `dims` sizes (and at least one slot);
 *                    overwritten with the collapsed sizes.
 * @param strides     Array of matching strides; overwritten likewise.
 * @param dims        Number of input dimensions.
 * @param excludeDim  Index of a dimension to preserve, or -1 for none.
 *                    Must satisfy -1 <= excludeDim < dims (TORCH_CHECK).
 * @return `{remapped excludeDim, new dim count}`. The first element is -1
 *         when no dimension was excluded. When nothing but a single
 *         size-1 dimension (or nothing at all) remains, `sizes[0]` and
 *         `strides[0]` are set to 1 and `{0, 1}` is returned.
 */
template <typename T>
inline std::pair<int64_t, int64_t> collapse_dims(
    T* sizes,
    T* strides,
    int64_t dims,
    const int excludeDim = -1) {
  TORCH_CHECK(
      excludeDim >= -1 && excludeDim < dims,
      "expected excluded dim between -1 and dims - 1");

  using Result = std::pair<int64_t, int64_t>;

  int64_t writePos = -1; // last output slot written so far (-1: none yet)
  int64_t readPos = 0; // next input dimension to examine
  int64_t excludedOut = -1; // output slot that received excludeDim

  // Copies input dimension `readPos` into a fresh output slot.
  // writePos never overtakes readPos, so unread input is never clobbered.
  const auto emitDim = [&]() {
    ++writePos;
    sizes[writePos] = sizes[readPos];
    strides[writePos] = strides[readPos];
  };

  // Collapses input dimensions [readPos, limit) into the output.
  // The first non-trivial dimension of the range always opens a new slot,
  // so nothing is ever merged into a slot emitted before this range.
  const auto collapseUntil = [&](const int64_t limit) {
    bool haveTarget = false;
    for (; readPos < limit; ++readPos) {
      if (sizes[readPos] == 1) {
        continue; // size-1 dims contribute nothing
      }

      if (haveTarget &&
          strides[writePos] == sizes[readPos] * strides[readPos]) {
        // Current dim continues the target: fold it in.
        sizes[writePos] *= sizes[readPos];
        strides[writePos] = strides[readPos];
      } else {
        emitDim();
        haveTarget = true;
      }
    }
  };

  if (excludeDim != -1) {
    // Everything before the excluded dim, then the excluded dim itself.
    collapseUntil(excludeDim);
    emitDim();
    excludedOut = writePos;
    ++readPos;
  }
  // Remaining dims (all of them when nothing is excluded).
  collapseUntil(dims);

  // Either no output was produced, or the sole output dim has size 1.
  const bool collapsedToPoint =
      writePos == -1 || (writePos == 0 && sizes[0] == 1);
  if (collapsedToPoint) {
    sizes[0] = 1;
    strides[0] = 1;
    return Result(0, 1);
  }

  return Result(excludedOut, writePos + 1);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free