slow_conv3d_backward_weight_frame Function - pytorch Architecture
Architecture documentation for the slow_conv3d_backward_weight_frame function template in ConvolutionMM3d.cpp from the PyTorch codebase.
Source Code
aten/src/ATen/native/ConvolutionMM3d.cpp lines 490–516
template <typename scalar_t>
void slow_conv3d_backward_weight_frame(
    TensorAccessor<scalar_t, 2> grad_weight,
    TensorAccessor<const scalar_t, 4> grad_output,
    TensorAccessor<const scalar_t, 2> finput,
    int64_t groups) {
  // Compute grad_weight += grad_output.reshape({grad_output.shape(0), -1}) * finput.T
  // Note gemm expects fortran order, so all 3 matrices are transposed.
  // Swapping argument order cancels this, since C == AB <=> T(C) == T(B)T(A)
  const int64_t m = grad_weight.size(1);
  const int64_t n = grad_weight.size(0) / groups;
  const int64_t k = grad_output.size(1) * grad_output.size(2) * grad_output.size(3);

  const int64_t lda = k;
  const int64_t ldb = k;
  const int64_t ldc = m;

  at::native::cpublas::gemm_batched_with_stride(
      TransposeType::Transpose,
      TransposeType::NoTranspose,
      groups, m, n, k,
      static_cast<scalar_t>(1),
      finput.data(), lda, finput.stride(0) * m,
      grad_output.data(), ldb, grad_output.stride(0) * n,
      static_cast<scalar_t>(1),
      grad_weight.data(), ldc, grad_weight.stride(0) * n);
}
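The comments in the listing rely on the identity C == AB <=> T(C) == T(B)T(A) to drive a column-major (Fortran-order) GEMM with row-major buffers. The sketch below is one way to check that reasoning for a single group. It is not ATen code: gemm_colmajor, accumulate_rowmajor, and max_abs_difference are hypothetical helpers written for illustration, the element type is fixed to float, and the buffers are assumed to be contiguous.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference GEMM in column-major (Fortran) order, for
// illustration only; it is not ATen's cpublas implementation.
// Computes C = alpha * op(A) * op(B) + beta * C, where C is m x n.
void gemm_colmajor(bool trans_a, bool trans_b,
                   int64_t m, int64_t n, int64_t k,
                   float alpha, const float* a, int64_t lda,
                   const float* b, int64_t ldb,
                   float beta, float* c, int64_t ldc) {
  for (int64_t j = 0; j < n; ++j) {
    for (int64_t i = 0; i < m; ++i) {
      float acc = 0.0f;
      for (int64_t p = 0; p < k; ++p) {
        // Column-major element (row, col) lives at ptr[row + col * ld].
        const float av = trans_a ? a[p + i * lda] : a[i + p * lda];
        const float bv = trans_b ? b[j + p * ldb] : b[p + j * ldb];
        acc += av * bv;
      }
      c[i + j * ldc] = alpha * acc + beta * c[i + j * ldc];
    }
  }
}

// Direct row-major evaluation of grad_weight += grad_output_flat * finput^T
// for one group: grad_weight is n x m, grad_output_flat is n x k,
// finput is m x k.
void accumulate_rowmajor(int64_t m, int64_t n, int64_t k,
                         const float* finput, const float* grad_output,
                         float* grad_weight) {
  for (int64_t o = 0; o < n; ++o) {
    for (int64_t i = 0; i < m; ++i) {
      float acc = 0.0f;
      for (int64_t p = 0; p < k; ++p) {
        acc += grad_output[o * k + p] * finput[i * k + p];
      }
      grad_weight[o * m + i] += acc;
    }
  }
}

// Runs both evaluations on the same synthetic data and returns the largest
// absolute difference between the two results.
float max_abs_difference(int64_t m, int64_t n, int64_t k) {
  std::vector<float> finput(m * k), grad_output(n * k);
  for (int64_t i = 0; i < m * k; ++i) {
    finput[i] = 0.25f * static_cast<float>(i % 7) - 0.5f;
  }
  for (int64_t i = 0; i < n * k; ++i) {
    grad_output[i] = 0.5f * static_cast<float>(i % 5) - 1.0f;
  }
  // Start from a non-zero gradient so the accumulation (beta == 1) is exercised.
  std::vector<float> expected(n * m, 1.0f), actual(n * m, 1.0f);

  accumulate_rowmajor(m, n, k, finput.data(), grad_output.data(), expected.data());

  // Same arguments as the listing for one group: A = finput (transposed,
  // lda = k), B = grad_output (not transposed, ldb = k), C = grad_weight
  // (ldc = m). The column-major m x n result is the row-major n x m matrix.
  gemm_colmajor(true, false, m, n, k, 1.0f,
                finput.data(), k, grad_output.data(), k,
                1.0f, actual.data(), m);

  float worst = 0.0f;
  for (std::size_t i = 0; i < expected.size(); ++i) {
    worst = std::max(worst, std::fabs(expected[i] - actual[i]));
  }
  return worst;
}
```

Calling max_abs_difference(6, 4, 5) from a small driver should report no meaningful difference between the two results. The sketch covers one group only; in the listing, the batched call passes a per-group stride for each operand alongside the same m, n, k, and leading dimensions.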