
loss_batch_rule_helper Class — pytorch Architecture

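Architecture documentation for loss_batch_rule_helper, a static function template in BatchRulesLoss.cpp from the PyTorch codebase. It is a shared helper for the functorch batching rules of loss functions: it flattens each operand with flatten_logical, evaluates the supplied loss_fn with Reduction::None, and then applies the requested reduction itself, always returning the result with its batch dimension at position 0.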


Source Code

aten/src/ATen/functorch/BatchRulesLoss.cpp lines 27–48

```cpp
template <typename Func>
static std::tuple<at::Tensor, std::optional<int64_t>>
loss_batch_rule_helper(const at::Tensor& self, std::optional<int64_t> self_bdim, const at::Tensor& target,
          std::optional<int64_t> target_bdim, int64_t reduction,
          Func loss_fn) {
  auto self_ = flatten_logical(self, self_bdim);
  auto target_ = flatten_logical(target, target_bdim);
  auto result = loss_fn(self_, target_, Reduction::None);
  if (result.dim() == 1) {
    return std::make_tuple(result, 0);
  } else if (reduction == Reduction::None) {
    DimVector end_shape;
    const auto batched_elem = self_bdim.has_value() ?
        moveBatchDimToFront(self, self_bdim) : moveBatchDimToFront(target, target_bdim);
    return std::make_tuple(result.reshape(batched_elem.sizes()), 0);
  } else if (reduction == Reduction::Sum) {
    return std::make_tuple(result.sum(-1), 0);
  } else if (reduction == Reduction::Mean) {
    return std::make_tuple(result.mean(-1), 0);
  }
  TORCH_INTERNAL_ASSERT(false);
}
```
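To make the shape bookkeeping concrete, here is a minimal, shape-only sketch in standard C++17 (no ATen, no tensor data). It is not the PyTorch implementation: `Shape`, `Reduction`, `numel`, `move_bdim_to_front`, `flatten_logical_shape`, `broadcast_shapes`, and `loss_rule_shape` are hypothetical stand-ins. It assumes that `flatten_logical` (defined outside lines 27–48) moves an existing batch dimension to the front and collapses the remaining dimensions into one, flattening an unbatched tensor to 1-D, and that `loss_fn` with `Reduction::None` returns an elementwise loss whose shape is the broadcast of its two inputs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <optional>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical stand-ins: a tensor is modeled only by its sizes.
using Shape = std::vector<int64_t>;
enum class Reduction { None, Mean, Sum };

// Number of elements; an empty shape (0-dim tensor) has one element.
int64_t numel(const Shape& s) {
  return std::accumulate(s.begin(), s.end(), int64_t{1}, std::multiplies<int64_t>());
}

// Shape-only model of moving the batch dim (if any) to position 0.
Shape move_bdim_to_front(const Shape& s, std::optional<int64_t> bdim) {
  if (!bdim.has_value()) return s;
  const auto b = static_cast<size_t>(*bdim);
  Shape out{s.at(b)};
  for (size_t i = 0; i < s.size(); ++i)
    if (i != b) out.push_back(s[i]);
  return out;
}

// ASSUMED behavior of flatten_logical: batched -> [B, N] (or [B] for scalar
// examples); unbatched -> [N], where a 0-dim tensor flattens to [1].
Shape flatten_logical_shape(const Shape& s, std::optional<int64_t> bdim) {
  if (bdim.has_value()) {
    const Shape front = move_bdim_to_front(s, bdim);
    if (front.size() > 1) {
      const Shape rest(front.begin() + 1, front.end());
      return {front[0], numel(rest)};
    }
    return front;
  }
  return {numel(s)};
}

// Right-aligned broadcasting, used to model the unreduced elementwise loss shape.
Shape broadcast_shapes(const Shape& a, const Shape& b) {
  const size_t n = std::max(a.size(), b.size());
  const size_t pa = n - a.size(), pb = n - b.size();  // implicit leading size-1 dims
  Shape out(n);
  for (size_t i = 0; i < n; ++i) {
    const int64_t da = i < pa ? 1 : a[i - pa];
    const int64_t db = i < pb ? 1 : b[i - pb];
    if (da != db && da != 1 && db != 1) throw std::invalid_argument("shapes do not broadcast");
    out[i] = (da == 1) ? db : da;
  }
  return out;
}

// Mirrors the branch structure of loss_batch_rule_helper, returning
// (output shape, output batch dim); the batch dim is always 0, as in the excerpt.
std::pair<Shape, int64_t> loss_rule_shape(const Shape& self, std::optional<int64_t> self_bdim,
                                          const Shape& target, std::optional<int64_t> target_bdim,
                                          Reduction reduction) {
  // Sketch assumption: a batch rule is only reached when some input is batched.
  assert(self_bdim.has_value() || target_bdim.has_value());
  const Shape result = broadcast_shapes(flatten_logical_shape(self, self_bdim),
                                        flatten_logical_shape(target, target_bdim));
  if (result.size() == 1) {
    return {result, 0};  // one loss value per example: nothing left to reduce
  }
  switch (reduction) {
    case Reduction::None: {
      // Reshape [B, N] back to the batched operand's shape with its batch dim in front.
      const Shape batched_elem = self_bdim.has_value() ? move_bdim_to_front(self, self_bdim)
                                                       : move_bdim_to_front(target, target_bdim);
      if (numel(batched_elem) != numel(result)) throw std::invalid_argument("reshape size mismatch");
      return {batched_elem, 0};
    }
    case Reduction::Sum:
    case Reduction::Mean:
      return {Shape{result[0]}, 0};  // sum(-1) / mean(-1) drop the flattened dim
  }
  throw std::logic_error("unknown reduction");
}
```

For example, two 3×4 inputs stacked along dim 1 (`self` shape `{3, 2, 4}`, `self_bdim = 1`) against an unbatched `{3, 4}` target come back as `{2, 3, 4}` with batch dim 0 under `Reduction::None`, and as `{2}` under `Sum` or `Mean`. Because both operands are flattened to at most two dimensions, the reductions only ever act on the last dimension, which is why the excerpt can use `sum(-1)` and `mean(-1)` regardless of the per-example rank.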
