scatter_batch_rule Function Template: pytorch Architecture
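Architecture documentation for the scatter_batch_rule function template in BatchRulesScatterOps.cpp from the pytorch codebase. It is a functorch (vmap) batching rule for scatter variants that take a scalar value. It works in five steps. It moves the batch dims of self and index to the front. It gives logical scalars a temporary size-1 dim. It expands an unbatched operand to the batch size. It calls f once with dim shifted past the batch dim. Finally, it returns the result with its batch dim at position 0.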
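A scatter batching rule relies on one equivalence. Once both operands carry their batch dim at position 0, a single scatter along the logical dim shifted by one does the same work as running the unbatched scatter once per batch element. The sketch below checks this in plain standard C++ on nested vectors, for a 1-D logical `self`. It is not ATen or functorch code. The names `per_sample_scatter`, `scatter_value_2d_dim1`, and `expand_index` are hypothetical. The 2-D rule `self[i][index[i][j]] = value` is assumed from the `dim == 1` case in the `torch.Tensor.scatter_` documentation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical illustration, not an ATen/functorch API.
using Row = std::vector<double>;
using Batch = std::vector<Row>;            // a [B, N] tensor, batch dim first
using IndexRow = std::vector<int64_t>;
using IndexBatch = std::vector<IndexRow>;  // a [B, K] index, batch dim first

// Unbatched 1-D scatter of a scalar: copy self, then out[index[j]] = value.
Row scatter_value_1d(Row self, const IndexRow& index, double value) {
  for (int64_t i : index) self.at(static_cast<std::size_t>(i)) = value;
  return self;
}

// Reference (vmap) semantics: one unbatched scatter per batch element b.
Batch per_sample_scatter(const Batch& self, const IndexBatch& index, double value) {
  Batch out;
  for (std::size_t b = 0; b < self.size(); ++b)
    out.push_back(scatter_value_1d(self[b], index.at(b), value));
  return out;
}

// One 2-D scatter along dim 1 (logical dim 0 shifted past the leading
// batch dim): self[i][index[i][j]] = value.
Batch scatter_value_2d_dim1(Batch self, const IndexBatch& index, double value) {
  for (std::size_t i = 0; i < index.size(); ++i)
    for (int64_t col : index[i]) self.at(i).at(static_cast<std::size_t>(col)) = value;
  return self;
}

// Hypothetical analogue of ensure_has_bdim for an unbatched index:
// repeat the same index row for every batch element.
IndexBatch expand_index(const IndexRow& index, std::size_t batch_size) {
  return IndexBatch(batch_size, index);
}
```

`expand_index` stands in for what `ensure_has_bdim` does in the source. An operand without a batch dim is broadcast so that both inputs share the same leading batch size before the single call.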
Source Code
aten/src/ATen/functorch/BatchRulesScatterOps.cpp lines 678–708
template<typename Func, typename ...Args>
std::tuple<Tensor, std::optional<int64_t>> scatter_batch_rule(
    Func f,
    const Tensor& self, std::optional<int64_t> self_bdim,
    int64_t dim,
    const Tensor& index, std::optional<int64_t> index_bdim,
    const Scalar& value, Args... args) {
  auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
  auto index_logical_rank = rankWithoutBatchDim(index, index_bdim);
  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);

  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto index_ = moveBatchDimToFront(index, index_bdim);

  if (self_logical_rank == 0) {
    self_ = self_.unsqueeze(-1);
  }
  if (index_logical_rank == 0) {
    index_ = index_.unsqueeze(-1);
  }
  self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
  index_ = ensure_has_bdim(index_, index_bdim.has_value(), batch_size);
  auto physical_dim = getPhysicalDim(self_, /*has_batch_dim*/true, dim);

  auto result = f(self_, physical_dim, index_, value, args...);
  // result should have same shape as self
  if (self_logical_rank == 0) {
    result = result.squeeze(-1);
  }
  return std::make_tuple(result, 0);
}
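The shape bookkeeping in this listing can be followed without any tensor data. What follows is a hypothetical shape-only model in standard C++, not ATen code. `Shape`, `ScatterPlan`, `move_bdim_to_front`, and `plan_scatter` are illustrative names. Two behaviors are assumptions. The batch size is read from whichever operand is batched, which is assumed to match `get_bdim_size2`. The dim handling assumes that `getPhysicalDim` wraps a negative dim against the logical rank and then adds one for the leading batch dim.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <vector>

// Hypothetical shape-only model of scatter_batch_rule; not an ATen/functorch API.
// A "tensor" is represented only by its list of sizes.
using Shape = std::vector<int64_t>;

struct ScatterPlan {
  Shape self_shape;    // shape of self_ passed to f
  Shape index_shape;   // shape of index_ passed to f
  int64_t dim;         // physical_dim passed to f
  Shape result_shape;  // shape returned, batch dim at 0
};

// Like moveBatchDimToFront: relocate an existing batch dim to position 0.
Shape move_bdim_to_front(Shape s, std::optional<int64_t> bdim) {
  if (!bdim) return s;
  const int64_t b = s.at(*bdim);
  s.erase(s.begin() + *bdim);
  s.insert(s.begin(), b);
  return s;
}

ScatterPlan plan_scatter(Shape self, std::optional<int64_t> self_bdim,
                         int64_t dim,
                         Shape index, std::optional<int64_t> index_bdim) {
  // Logical rank = physical rank minus the batch dim, if there is one.
  const auto self_rank = static_cast<int64_t>(self.size()) - (self_bdim ? 1 : 0);
  const auto index_rank = static_cast<int64_t>(index.size()) - (index_bdim ? 1 : 0);
  // Batch size comes from whichever operand is batched (throws if neither is).
  const int64_t batch_size = self_bdim ? self.at(*self_bdim) : index.at(index_bdim.value());

  self = move_bdim_to_front(self, self_bdim);
  index = move_bdim_to_front(index, index_bdim);
  // Logical scalars get a trailing size-1 dim, as unsqueeze(-1) does.
  if (self_rank == 0) self.push_back(1);
  if (index_rank == 0) index.push_back(1);
  // Like ensure_has_bdim: an unbatched operand gains a leading batch dim.
  if (!self_bdim) self.insert(self.begin(), batch_size);
  if (!index_bdim) index.insert(index.begin(), batch_size);

  // Assumed getPhysicalDim behavior: wrap a negative dim against the
  // logical rank of self_, then shift by one past the batch dim.
  const auto rank = static_cast<int64_t>(self.size()) - 1;
  if (dim < -rank || dim >= rank) throw std::out_of_range("dim out of range");
  const int64_t physical = (dim < 0 ? dim + rank : dim) + 1;

  // scatter keeps self_'s shape; a logical scalar drops its extra dim again.
  Shape result = self;
  if (self_rank == 0) result.pop_back();
  return {self, index, physical, result};
}
```

Here is an example with a batch of three length-4 `self` rows sharing one length-2 index: `plan_scatter({3, 4}, 0, 0, {2}, std::nullopt)` yields an index shape of `{3, 2}` and a physical dim of 1. A logical-scalar `self` is handled differently. It is `{B, 1}` during the call to `f` and becomes `{B}` again after the final `squeeze(-1)`.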