atol_rtol_tensor_batch_rule Function Template — PyTorch Architecture
Architecture documentation for the atol_rtol_tensor_batch_rule function template in BatchRulesLinearAlgebra.cpp from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp lines 439–467
template<typename F>
std::tuple<Tensor, std::optional<int64_t>>
atol_rtol_tensor_batch_rule(
F Func, const Tensor& input, std::optional<int64_t> input_bdim,
const std::optional<Tensor>& atol, const std::optional<int64_t> atol_bdim,
const std::optional<Tensor>& rtol, const std::optional<int64_t> rtol_bdim, bool hermitian, char const *op_name) {
auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
TORCH_CHECK(input_logical_rank >= 2,
op_name, ": The input tensor input must have at least 2 dimensions.");
// atol and rtol's dims must be broadcastable to the number of batch dims of input
// which is input's dim - 2 (input represents a batch of matrices, so 2 is for the matrix dimensions)
const auto input_logical_num_bdims = input_logical_rank - 2;
const int64_t atol_logical_num_bdims = atol.has_value() ? rankWithoutBatchDim(*atol, atol_bdim) : 0;
const int64_t rtol_logical_num_bdims = rtol.has_value() ? rankWithoutBatchDim(*rtol, rtol_bdim) : 0;
const auto max_logical_bdims = std::max({input_logical_num_bdims, atol_logical_num_bdims, rtol_logical_num_bdims});
auto input_ = moveBatchDimToFront(input, input_bdim);
auto atol_ = atol.has_value() ? moveBatchDimToFront(*atol, atol_bdim) : atol;
auto rtol_ = rtol.has_value() ? moveBatchDimToFront(*rtol, rtol_bdim) : rtol;
// pad all inputs to have the same number of (non-vmap) batch dimensions
input_ = maybePadToLogicalRank(input_, input_bdim, max_logical_bdims + 2);
atol_ = atol_.has_value() ? maybePadToLogicalRank(*atol_, atol_bdim, max_logical_bdims) : atol_;
rtol_ = rtol_.has_value() ? maybePadToLogicalRank(*rtol_, rtol_bdim, max_logical_bdims) : rtol_;
return std::make_tuple(Func(input_, atol_, rtol_, hermitian), 0);
}
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free