Home / Class / BwdKernel Class — pytorch Architecture

BwdKernel Class — pytorch Architecture

Architecture documentation for the BwdKernel class in generate_kernels.py from the pytorch codebase.

Entity Profile

Source Code

aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py lines 147–300

class BwdKernel:
    """Describe one backward-attention kernel variant to be generated.

    Each instance carries the template parameters of a single
    ``AttentionBackwardKernel<...>`` instantiation and knows how to render
    its generated name, its C++ class spelling, the implementation group it
    belongs to, and its C++ implementation snippet.

    NOTE(review): the class relies on dataclass machinery (``field``,
    ``__post_init__`` and the keyword construction used in ``get_all``);
    the ``@dataclass`` decorator itself is not visible in this excerpt --
    confirm it is applied just above the class.
    """

    # Selection priority tuple; filled in by __post_init__, never passed in.
    sort_index: tuple[int, ...] = field(init=False, repr=False)
    # Pair of SM values: element 0 feeds the name and the arch template
    # argument, and both elements are handed to KERNEL_IMPL_TEMPLATE.
    sm_range: tuple[int, int]
    # Key into DTYPES.
    dtype: str
    aligned: bool
    apply_dropout: bool
    preload_mmas: bool
    block_i: int
    block_j: int
    max_k: int
    # Only set by get_all for the K=96 specializations; how it is consumed
    # is not visible in this class.
    dispatch_cond: Optional[str] = None
    keys_queries_aligned_to_blocksizes: bool = False

    def __post_init__(self) -> None:
        """Compute ``sort_index``, the kernel selection priority.

        The lowest value that matches the inputs will be selected.
        """
        # Aligned kernels come first.
        alignment_rank = 0 if self.aligned else 1
        # A kernel without dropout is preferred when possible.
        dropout_rank = 1 if self.apply_dropout else 0
        # Last tie-break: avoid bounds-checks if possible.
        bounds_check_rank = 0 if self.keys_queries_aligned_to_blocksizes else 1
        self.sort_index = (
            alignment_rank,
            dropout_rank,
            # Then the smallest maxK ...
            self.max_k,
            # ... and the highest block_i (negated so larger sorts lower).
            -self.block_i,
            bounds_check_rank,
        )

    @property
    def _aligned_suffix(self) -> str:
        """Return the alignment tag used in generated names."""
        if self.aligned:
            return "aligned"
        return "notaligned"

    @property
    def name(self) -> str:
        """Return the generated kernel name for this configuration."""
        dropout_tag = ""
        if self.apply_dropout:
            dropout_tag = "_dropout"
        seq_tag = ""
        if self.keys_queries_aligned_to_blocksizes:
            seq_tag = "_seqaligned"
        head = f"fmha_cutlassB_{self.dtype}_{self._aligned_suffix}"
        tail = f"_{self.block_i}x{self.block_j}_k{self.max_k}{dropout_tag}{seq_tag}_sm{self.sm_range[0]}"
        return head + tail

    @property
    def cpp_class(self) -> str:
        """Return the ``AttentionBackwardKernel<...>`` type spelling."""

        def cpp_bool(flag: bool) -> str:
            # C++ boolean literal for a Python truth value.
            return "true" if flag else "false"

        args = [
            f"cutlass::arch::Sm{self.sm_range[0]}",
            DTYPES[self.dtype],
            cpp_bool(self.aligned),
            cpp_bool(self.apply_dropout),
            cpp_bool(self.preload_mmas),
            str(self.block_i),
            str(self.block_j),
            str(self.max_k),
        ]
        # The extra trailing template argument is only emitted when set.
        if self.keys_queries_aligned_to_blocksizes:
            args.append("true")
        joined = ", ".join(args)
        return f"AttentionBackwardKernel<{joined}>"

    @property
    def impl_group(self) -> str:
        """Return the group name mapping to the file that holds the impl."""
        group = f"{self.dtype}_{self._aligned_suffix}_k{self.max_k}"
        if self.apply_dropout:
            group += "_dropout"
        return group

    @property
    def cpp_impl(self) -> str:
        """Render KERNEL_IMPL_TEMPLATE for this kernel."""
        substitutions = {
            "CPP_CLASS": self.cpp_class,
            "NAME": self.name,
            "SM": self.sm_range[0],
            "SM_MAX": self.sm_range[1],
        }
        return KERNEL_IMPL_TEMPLATE.format(**substitutions)

    @classmethod
    def get_all(cls) -> list["BwdKernel"]:
        """Enumerate every backward kernel configuration to generate.

        Walks the product of alignment, dtype, consecutive SM pairs, dropout
        and max_k, drops the combinations that are not generated, and then
        appends the K=96 specializations.

        Returns:
            The kernels, in generation order.
        """
        found: list[BwdKernel] = []
        half_dtypes = ("bf16", "f16")
        combos = itertools.product(
            [True, False],
            DTYPES.keys(),
            itertools.pairwise(SM),
            [True, False],
            [32, 64, 128, 2**16],
        )
        for aligned, dtype, (sm, sm_max), apply_dropout, max_k in combos:
            # No bf16 kernels below Sm80, no unaligned kernels from Sm80 up.
            if (dtype == "bf16" and sm < 80) or (sm >= 80 and not aligned):
                continue
            is_half = dtype in half_dtypes

            # Some architectures have more shmem and can use 128.
            # We still need a fallback to 64 for GPUs with less shmem
            # (Sm75, Sm86 ...)
            roomy = sm >= 80 or (sm >= 70 and is_half)
            block_i_options = [64, 128] if (roomy and max_k > 64) else [64]

            # A few specialized kernels that are faster; whether they apply
            # does not depend on block_i, so decide once per combination.
            wants_seqaligned = (
                is_half
                and aligned
                and not apply_dropout
                and max_k <= 128
                and sm in [70, 80]
            )

            for bi in block_i_options:
                output_in_rf = is_half and max_k <= bi
                preload_mmas = is_half and sm >= 80 and output_in_rf
                bj = 128 if (preload_mmas and max_k > 64) else 64
                common = dict(
                    aligned=aligned,
                    dtype=dtype,
                    sm_range=(sm, sm_max),
                    apply_dropout=apply_dropout,
                    preload_mmas=preload_mmas,
                    block_i=bi,
                    block_j=bj,
                    max_k=max_k,
                )
                found.append(cls(**common))
                if wants_seqaligned:
                    found.append(
                        cls(**common, keys_queries_aligned_to_blocksizes=True)
                    )

        # Add some specialized kernels for stable diffusion BW (K=80)
        # This is the only kernel that can keep the outputs on RF on
        # Sm86/Sm89, so it's much faster than the 64x64 one
        sd_sm_range = (80, SM[SM.index(80) + 1])
        found.extend(
            cls(
                aligned=True,
                dtype=dtype,
                sm_range=sd_sm_range,
                apply_dropout=False,
                preload_mmas=True,
                block_i=128,
                block_j=64,
                max_k=96,
                # Sm80 has a faster kernel for this case
                dispatch_cond="cc == 86 || cc == 89",
            )
            for dtype in ["f16", "bf16"]
        )
        return found

Analyze Your Own Codebase

Get architecture documentation, dependency graphs, and domain analysis for your codebase in minutes.

Try Supermodel Free