FwdKernel Class — PyTorch Architecture
Architecture documentation for the FwdKernel class in generate_kernels.py from the PyTorch codebase.
Entity Profile
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py lines 50–143
@dataclass(order=True)  # order=True: kernels compare by field order, i.e. sort_index first
class FwdKernel:
    """One forward (cutlassF) memory-efficient-attention kernel instantiation.

    Each instance describes a single template instantiation of the C++
    ``AttentionKernel`` and knows how to render its name, C++ class, and
    implementation stub.  Instances are orderable: the dataclass comparison
    starts with ``sort_index``, so sorting a kernel list puts the
    highest-priority (lowest sort_index) kernel first — dispatch picks the
    first kernel whose constraints match the inputs.

    NOTE(review): ``field(init=False)`` and ``__post_init__`` require the
    ``@dataclass`` decorator; it was missing from this snippet and is restored
    here (the enclosing file is assumed to import ``dataclass`` alongside
    ``field`` — confirm against the file's import block).
    """

    # Selection priority, computed in __post_init__. Declared first so that
    # dataclass ordering compares it before any other field.
    sort_index: tuple[int, ...] = field(init=False, repr=False)
    # Whether all input pointers/strides meet the alignment requirement.
    aligned: bool
    # Element type key into DTYPES (e.g. "f32" / "f16" / "bf16").
    dtype: str
    # [min_sm, max_sm) compute-capability range this build targets.
    sm_range: tuple[int, int]
    # Tile sizes: queries per block, keys per block, and the max head dim
    # the kernel supports (max_k <= k means the output fits in registers).
    q: int
    k: int
    max_k: int
    supports_dropout: bool = True
    supports_bias: bool = True
    # Optional extra C++ condition guarding dispatch to this kernel.
    dispatch_cond: Optional[str] = None

    def __post_init__(self) -> None:
        # Set kernel selection priority: the kernel with the lowest
        # sort_index that matches the inputs will be selected.
        self.sort_index = (
            # First select aligned kernel
            0 if self.aligned else 1,
            # Then keep output in RF (smaller max_k / k sorts first)
            self.max_k,
            self.k,
            # Prefer kernels without dropout/bias if available
            1 if self.supports_dropout else 0,
            1 if self.supports_bias else 0,
        )

    @property
    def _aligned_suffix(self) -> str:
        """Name fragment encoding the alignment requirement."""
        return "aligned" if self.aligned else "notaligned"

    @property
    def name(self) -> str:
        """Unique kernel symbol name, e.g. ``fmha_cutlassF_f32_aligned_64x64_rf_sm80``."""
        # "rf": accumulator/output stays in the register file; "gmem": spills to global memory.
        acc = "rf" if self.max_k <= self.k else "gmem"
        return f"fmha_cutlassF_{self.dtype}_{self._aligned_suffix}_{self.q}x{self.k}_{acc}_sm{self.sm_range[0]}"

    @property
    def cpp_class(self) -> str:
        """Fully-parameterized C++ ``AttentionKernel<...>`` template type."""
        template_args = ", ".join(
            [
                DTYPES[self.dtype],
                f"cutlass::arch::Sm{self.sm_range[0]}",
                "true" if self.aligned else "false",
                str(self.q),
                str(self.k),
                str(self.max_k),
                "true" if self.supports_dropout else "false",
                "true" if self.supports_bias else "false",
            ]
        )
        return f"AttentionKernel<{template_args}>"

    @property
    def impl_group(self) -> str:
        """Key for the generated file that will contain this implementation."""
        # Maps to file which will contain the implementation
        return f"{self.dtype}_{self._aligned_suffix}"

    @property
    def cpp_impl(self) -> str:
        """C++ source for this kernel, rendered from KERNEL_IMPL_TEMPLATE."""
        return KERNEL_IMPL_TEMPLATE.format(
            CPP_CLASS=self.cpp_class,
            NAME=self.name,
            SM=self.sm_range[0],
            SM_MAX=self.sm_range[1],
        )

    @classmethod
    def get_all(cls) -> list["FwdKernel"]:
        """Enumerate every forward kernel we generate, pruned of unused combos."""
        kernels: list[FwdKernel] = []
        # Cartesian product of alignment x dtype x adjacent SM ranges.
        for aligned, dtype, (sm, sm_max) in itertools.product(
            [True, False], DTYPES.keys(), itertools.pairwise(SM)
        ):
            # Remove some kernels we don't use:
            # bf16 needs sm80+, and sm80+ inputs are always aligned.
            if dtype == "bf16" and sm < 80:
                continue
            if not aligned and sm >= 80:
                continue
            for q, k, max_k in [
                (64, 64, 64),
                # We get better perf with 64x128 on A100
                (64 if sm > 75 else 32, 128, 128),
                (32, 128, 2**16),
            ]:
                kernels.append(
                    cls(
                        aligned=aligned,
                        dtype=dtype,
                        sm_range=(sm, sm_max),
                        q=q,
                        k=k,
                        max_k=max_k,
                    )
                )
        return kernels
Source
Analyze Your Own Codebase
Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.
Try Supermodel Free