loss_batch_rule_helper Class — pytorch Architecture
Architecture documentation for the loss_batch_rule_helper class in BatchRulesLoss.cpp from the pytorch codebase.
Entity Profile
Source Code
aten/src/ATen/functorch/BatchRulesLoss.cpp lines 27–48
// Shared batch-rule implementation for elementwise loss ops under vmap.
//
// Strategy: flatten each input's logical (non-batch) dims via
// flatten_logical, compute the un-reduced loss, then apply the requested
// reduction per batch element so the result carries one batch dim at
// position 0.
//
// Params:
//   self, target:           inputs to the loss op.
//   self_bdim, target_bdim: which physical dim (if any) of self/target
//                           is the vmap batch dim; nullopt = not batched.
//   reduction:              Reduction::None / Sum / Mean.
//   loss_fn:                callable invoked as
//                           loss_fn(self_, target_, Reduction::None);
//                           must produce the un-reduced loss.
//
// Returns: (result, 0) — the output's batch dim is always dim 0.
template <typename Func>
static std::tuple<at::Tensor, std::optional<int64_t>>
loss_batch_rule_helper(const at::Tensor& self, std::optional<int64_t> self_bdim, const at::Tensor& target,
                       std::optional<int64_t> target_bdim, int64_t reduction,
                       Func loss_fn) {
  auto self_ = flatten_logical(self, self_bdim);
  auto target_ = flatten_logical(target, target_bdim);
  auto result = loss_fn(self_, target_, Reduction::None);
  if (result.dim() == 1) {
    // One value per batch element already (per-example inputs were 0-D),
    // so no further reduction is needed regardless of `reduction`.
    return std::make_tuple(result, 0);
  } else if (reduction == Reduction::None) {
    // Un-reduced: restore the logical shape of whichever input is
    // batched, with its batch dim moved to the front.
    // (Removed an unused `DimVector end_shape;` local here.)
    const auto batched_elem = self_bdim.has_value() ?
        moveBatchDimToFront(self, self_bdim) : moveBatchDimToFront(target, target_bdim);
    return std::make_tuple(result.reshape(batched_elem.sizes()), 0);
  } else if (reduction == Reduction::Sum) {
    // Reduce over the flattened logical dim; dim 0 (batch) survives.
    return std::make_tuple(result.sum(-1), 0);
  } else if (reduction == Reduction::Mean) {
    return std::make_tuple(result.mean(-1), 0);
  }
  TORCH_INTERNAL_ASSERT(false); // unreachable: all Reduction values handled
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free