
atol_rtol_tensor_batch_rule Class — pytorch Architecture

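Architecture documentation for atol_rtol_tensor_batch_rule, a function template (not a class) in BatchRulesLinearAlgebra.cpp from the PyTorch codebase. It is a functorch (vmap) batch rule for linear algebra operators that take a batch of matrices plus optional atol and rtol tolerance tensors. It moves each argument's vmap batch dimension to the front, pads the arguments to a common number of batch dimensions, and then calls the wrapped operator Func.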


Source Code

aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp lines 439–467

template<typename F>
std::tuple<Tensor, std::optional<int64_t>>
atol_rtol_tensor_batch_rule(
    F Func, const Tensor& input, std::optional<int64_t> input_bdim,
    const std::optional<Tensor>& atol, const std::optional<int64_t> atol_bdim,
    const std::optional<Tensor>& rtol, const std::optional<int64_t> rtol_bdim, bool hermitian, char const *op_name) {
  auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);

  TORCH_CHECK(input_logical_rank >= 2,
            op_name, ": The input tensor input must have at least 2 dimensions.");

  // atol and rtol's dims must be broadcastable to the number of batch dims of input
  // which is input's dim - 2 (input represents a batch of matrices, so 2 is for the matrix dimensions)
  const auto input_logical_num_bdims = input_logical_rank - 2;
  const int64_t atol_logical_num_bdims = atol.has_value() ? rankWithoutBatchDim(*atol, atol_bdim) : 0;
  const int64_t rtol_logical_num_bdims = rtol.has_value() ? rankWithoutBatchDim(*rtol, rtol_bdim) : 0;
  const auto max_logical_bdims = std::max({input_logical_num_bdims, atol_logical_num_bdims, rtol_logical_num_bdims});

  auto input_ = moveBatchDimToFront(input, input_bdim);
  auto atol_ = atol.has_value() ? moveBatchDimToFront(*atol, atol_bdim) : atol;
  auto rtol_ = rtol.has_value() ? moveBatchDimToFront(*rtol, rtol_bdim) : rtol;

  // pad all inputs to have the same number of (non-vmap) batch dimensions
  input_ = maybePadToLogicalRank(input_, input_bdim, max_logical_bdims + 2);
  atol_ = atol_.has_value() ? maybePadToLogicalRank(*atol_, atol_bdim, max_logical_bdims) : atol_;
  rtol_ = rtol_.has_value() ? maybePadToLogicalRank(*rtol_, rtol_bdim, max_logical_bdims) : rtol_;

  return std::make_tuple(Func(input_, atol_, rtol_, hermitian), 0);
}
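Why the padding step matters: after moveBatchDimToFront, each batched argument carries its vmap dimension at position 0. However, the wrapped operator broadcasts atol and rtol against the input's batch dimensions aligned from the right. If the arguments have different numbers of logical batch dimensions, the leading vmap dimensions would not line up. The sketch below models only the shape bookkeeping, with each tensor represented as a list of sizes. It assumes that maybePadToLogicalRank inserts size-1 dimensions directly after the front batch dimension and leaves tensors without a batch dimension unchanged. All names in it (Shape, rank_without_batch_dim, move_batch_dim_to_front, maybe_pad_to_logical_rank, broadcast_shapes, Padded, pad_like_batch_rule) are hypothetical and are not part of ATen.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical shape-only model: a "tensor" is just its list of sizes.
// None of these names are part of ATen.
using Shape = std::vector<int64_t>;

// Logical rank: number of dims excluding the vmap batch dim, if present.
int64_t rank_without_batch_dim(const Shape& s, std::optional<int64_t> bdim) {
  return static_cast<int64_t>(s.size()) - (bdim.has_value() ? 1 : 0);
}

// Move the vmap batch dim (if any) to position 0; other dims keep their order.
Shape move_batch_dim_to_front(Shape s, std::optional<int64_t> bdim) {
  if (!bdim.has_value()) return s;
  const int64_t size = s[*bdim];
  s.erase(s.begin() + *bdim);
  s.insert(s.begin(), size);
  return s;
}

// Assumed semantics of maybePadToLogicalRank: insert size-1 dims right after the
// front batch dim until the logical rank reaches `logical_rank`. Tensors without
// a batch dim are returned unchanged.
Shape maybe_pad_to_logical_rank(Shape s, std::optional<int64_t> bdim, int64_t logical_rank) {
  if (!bdim.has_value()) return s;
  for (int64_t r = rank_without_batch_dim(s, bdim); r < logical_rank; ++r) {
    s.insert(s.begin() + 1, 1);
  }
  return s;
}

// Right-aligned (NumPy-style) broadcasting of two shapes; nullopt if incompatible.
std::optional<Shape> broadcast_shapes(const Shape& a, const Shape& b) {
  const std::size_t n = std::max(a.size(), b.size());
  Shape out(n, 1);
  for (std::size_t i = 0; i < n; ++i) {
    // Walk both shapes from the right; missing dims count as size 1.
    const int64_t da = i < a.size() ? a[a.size() - 1 - i] : 1;
    const int64_t db = i < b.size() ? b[b.size() - 1 - i] : 1;
    if (da != db && da != 1 && db != 1) return std::nullopt;
    out[n - 1 - i] = (da == 1) ? db : da;
  }
  return out;
}

struct Padded {
  Shape input;
  std::optional<Shape> atol;
  std::optional<Shape> rtol;
};

// Mirrors the shape bookkeeping of atol_rtol_tensor_batch_rule without calling Func.
Padded pad_like_batch_rule(const Shape& input, std::optional<int64_t> input_bdim,
                           const std::optional<Shape>& atol, std::optional<int64_t> atol_bdim,
                           const std::optional<Shape>& rtol, std::optional<int64_t> rtol_bdim) {
  const int64_t input_rank = rank_without_batch_dim(input, input_bdim);
  assert(input_rank >= 2 && "input must have at least 2 dimensions");  // stand-in for TORCH_CHECK
  const int64_t input_bdims = input_rank - 2;  // the last two dims are the matrix dims
  const int64_t atol_bdims = atol ? rank_without_batch_dim(*atol, atol_bdim) : 0;
  const int64_t rtol_bdims = rtol ? rank_without_batch_dim(*rtol, rtol_bdim) : 0;
  const int64_t max_bdims = std::max({input_bdims, atol_bdims, rtol_bdims});

  Padded out;
  out.input = maybe_pad_to_logical_rank(move_batch_dim_to_front(input, input_bdim), input_bdim, max_bdims + 2);
  if (atol) out.atol = maybe_pad_to_logical_rank(move_batch_dim_to_front(*atol, atol_bdim), atol_bdim, max_bdims);
  if (rtol) out.rtol = maybe_pad_to_logical_rank(move_batch_dim_to_front(*rtol, rtol_bdim), rtol_bdim, max_bdims);
  return out;
}
```

Take an input of physical shape [5, 3, 4] with its vmap dimension at 0 (a single 3x4 matrix per vmap slice) and an atol of shape [2, 5] with its vmap dimension at 1. In this model, the input becomes [5, 1, 3, 4] and atol becomes [5, 2]. The input's batch dimensions [5, 1] then broadcast with atol to [5, 2], and the vmap dimension stays at position 0. Without the padding, [5] against [5, 2] fails to broadcast. Because every batched argument ends up with its vmap dimension leading a batch region at least as long as any other argument's, this is presumably why the function can report the result's batch dimension as 0.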
