BwdKernel Class — PyTorch Architecture
Architecture documentation for the BwdKernel class in generate_kernels.py from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py lines 147–300
class BwdKernel:
    """One template instantiation of the CUTLASS memory-efficient-attention
    backward kernel.

    Each instance knows how to render its unique kernel name, its fully
    parameterized C++ class, and its C++ implementation stub, and carries a
    sort key so the runtime dispatcher can pick the cheapest matching kernel.

    NOTE(review): in the full file this class is decorated with
    ``@dataclass`` (``sort_index`` uses ``dataclasses.field`` and
    ``__post_init__`` is a dataclass hook); the decorator sits just above
    this excerpt.
    """

    # Filled in by __post_init__; excluded from __init__/repr. Lower tuples
    # sort first, i.e. are preferred by the dispatcher.
    sort_index: tuple[int, ...] = field(init=False, repr=False)
    # (sm, sm_max): CUDA compute-capability range this kernel is built for,
    # taken from consecutive entries of the module-level SM list.
    sm_range: tuple[int, int]
    # Key into the module-level DTYPES mapping (e.g. "f16", "bf16").
    dtype: str
    # Whether inputs are memory-aligned (selects the "aligned" variant).
    aligned: bool
    # Whether dropout support is compiled into the kernel.
    apply_dropout: bool
    # Whether MMA operands are preloaded (set when outputs fit in the
    # register file — see get_all()).
    preload_mmas: bool
    # Tile sizes along the query (i) and key (j) sequence dimensions.
    block_i: int
    block_j: int
    # Largest embedding dimension per head this kernel supports.
    max_k: int
    # Optional extra C++ boolean expression gating runtime dispatch.
    dispatch_cond: Optional[str] = None
    # Specialized variant that assumes sequence lengths are exact multiples
    # of the block sizes, avoiding bounds checks.
    keys_queries_aligned_to_blocksizes: bool = False

    def __post_init__(self) -> None:
        # Set kernel selection priority
        # The lowest value that matches inputs
        # will be selected
        self.sort_index = (
            # First select aligned kernel
            0 if self.aligned else 1,
            # Take a kernel without dropout if possible
            1 if self.apply_dropout else 0,
            # Then take the smallest maxK
            self.max_k,
            # .. and the highest block_i
            -self.block_i,
            # and finally avoid bounds-checks if possible
            0 if self.keys_queries_aligned_to_blocksizes else 1,
        )

    @property
    def _aligned_suffix(self) -> str:
        # Shared by `name` and `impl_group` so the two stay consistent.
        return "aligned" if self.aligned else "notaligned"

    @property
    def name(self) -> str:
        """Unique kernel function name encoding every template parameter."""
        dropout_suffix = "_dropout" if self.apply_dropout else ""
        seqlen_aligned_suffix = (
            "_seqaligned" if self.keys_queries_aligned_to_blocksizes else ""
        )
        return (
            f"fmha_cutlassB_{self.dtype}_{self._aligned_suffix}"
            f"_{self.block_i}x{self.block_j}_k{self.max_k}{dropout_suffix}{seqlen_aligned_suffix}_sm{self.sm_range[0]}"
        )

    @property
    def cpp_class(self) -> str:
        """Fully parameterized C++ type for this kernel instantiation."""
        template_args = ", ".join(
            [
                f"cutlass::arch::Sm{self.sm_range[0]}",
                DTYPES[self.dtype],
                "true" if self.aligned else "false",
                "true" if self.apply_dropout else "false",
                "true" if self.preload_mmas else "false",
                str(self.block_i),
                str(self.block_j),
                str(self.max_k),
            ]
        )
        # The seq-aligned flag is a trailing template parameter that is only
        # spelled out when it deviates from its default.
        if self.keys_queries_aligned_to_blocksizes:
            template_args += ", true"
        return f"AttentionBackwardKernel<{template_args}>"

    @property
    def impl_group(self) -> str:
        # Maps to file which will contain the implementation
        dropout_suffix = "_dropout" if self.apply_dropout else ""
        return f"{self.dtype}_{self._aligned_suffix}_k{self.max_k}{dropout_suffix}"

    @property
    def cpp_impl(self) -> str:
        """Render this kernel's C++ instantiation stub from the module-level
        KERNEL_IMPL_TEMPLATE."""
        return KERNEL_IMPL_TEMPLATE.format(
            CPP_CLASS=self.cpp_class,
            NAME=self.name,
            SM=self.sm_range[0],
            SM_MAX=self.sm_range[1],
        )

    @classmethod
    def get_all(cls) -> list["BwdKernel"]:
        """Enumerate every backward kernel to generate.

        Walks the cross-product of alignment, dtype, SM range, dropout and
        max_k, choosing block sizes per architecture, then appends a few
        hand-picked specializations.
        """
        kernels: list[BwdKernel] = []
        # itertools.pairwise(SM) yields consecutive (sm, sm_max) bounds.
        for aligned, dtype, (sm, sm_max), apply_dropout, max_k in itertools.product(
            [True, False],
            DTYPES.keys(),
            itertools.pairwise(SM),
            [True, False],
            [32, 64, 128, 2**16],
        ):
            # bf16 is only supported on Sm80 and newer.
            if dtype == "bf16" and sm < 80:
                continue
            # Unaligned fallbacks are only generated for pre-Sm80 archs.
            if not aligned and sm >= 80:
                continue
            is_half = dtype in ["bf16", "f16"]

            bi_values = [64]
            # Some architectures have more shmem and can use 128
            # We still need fallback to 64 for GPUs with less shmem
            # (Sm75, Sm86 ...)
            if sm >= 80 or (sm >= 70 and is_half):
                if max_k > 64:
                    bi_values.append(128)
            for bi in bi_values:
                # Outputs fit in the register file when the head dim does
                # not exceed the query-tile size (half precision only).
                output_in_rf = is_half and max_k <= bi
                preload_mmas = is_half and sm >= 80 and output_in_rf
                bj = 128 if (preload_mmas and max_k > 64) else 64
                kernels.append(
                    cls(
                        aligned=aligned,
                        dtype=dtype,
                        sm_range=(sm, sm_max),
                        apply_dropout=apply_dropout,
                        preload_mmas=preload_mmas,
                        block_i=bi,
                        block_j=bj,
                        max_k=max_k,
                    )
                )
            # A few specialized kernels that are faster
            if apply_dropout or max_k > 128 or not is_half or not aligned:
                continue
            if sm not in [70, 80]:
                continue
            # NOTE: bi, bj and preload_mmas below carry the values from the
            # last iteration of the loop above — Python for-loop variables
            # remain in scope after the loop ends.
            kernels.append(
                cls(
                    aligned=aligned,
                    dtype=dtype,
                    sm_range=(sm, sm_max),
                    apply_dropout=apply_dropout,
                    preload_mmas=preload_mmas,
                    block_i=bi,
                    block_j=bj,
                    max_k=max_k,
                    keys_queries_aligned_to_blocksizes=True,
                )
            )
        # Add some specialized kernels for stable diffusion BW (K=80)
        # This is the only kernel that can keep the outputs on RF on
        # Sm86/Sm89, so it's much faster than the 64x64 one
        for dtype in ["f16", "bf16"]:
            kernels.append(
                cls(
                    aligned=True,
                    dtype=dtype,
                    sm_range=(80, SM[SM.index(80) + 1]),
                    apply_dropout=False,
                    preload_mmas=True,
                    block_i=128,
                    block_j=64,
                    max_k=96,
                    # Sm80 has a faster kernel for this case
                    dispatch_cond="cc == 86 || cc == 89",
                )
            )
        return kernels
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free