get_all() — pytorch Function Reference
Architecture documentation for get_all(), a class method in generate_kernels.py from the pytorch codebase that builds the list of FwdKernel configurations for the memory-efficient attention kernels.
Dependency Diagram
The dependency analysis records three call edges. get_all() is called by main() and by a second function that is also named get_all() (a separate node in the graph), and get_all() in turn calls that second get_all().
Source Code
aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py lines 117–143
def get_all(cls) -> list["FwdKernel"]:
    kernels: list[FwdKernel] = []
    for aligned, dtype, (sm, sm_max) in itertools.product(
        [True, False], DTYPES.keys(), itertools.pairwise(SM)
    ):
        # Remove some kernels we don't use
        if dtype == "bf16" and sm < 80:
            continue
        if not aligned and sm >= 80:
            continue
        for q, k, max_k in [
            (64, 64, 64),
            # We get better perf with 64x128 on A100
            (64 if sm > 75 else 32, 128, 128),
            (32, 128, 2**16),
        ]:
            kernels.append(
                cls(
                    aligned=aligned,
                    dtype=dtype,
                    sm_range=(sm, sm_max),
                    q=q,
                    k=k,
                    max_k=max_k,
                )
            )
    return kernels
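To make the enumeration concrete, here is a self-contained sketch that mirrors the loop above. The DTYPES and SM values and the FwdKernel dataclass are assumptions for illustration only; the real constants and class are defined elsewhere in generate_kernels.py and may differ, so the kernel count it produces holds only for these assumed values.

```python
import itertools
from dataclasses import dataclass

# Assumed values for illustration; the real DTYPES and SM constants are
# defined elsewhere in generate_kernels.py and may differ.
DTYPES = {"f32": "float", "f16": "cutlass::half_t", "bf16": "cutlass::bfloat16_t"}
SM = [50, 70, 75, 80, 100]


@dataclass(frozen=True)
class FwdKernel:
    """Hypothetical stand-in holding only the fields get_all() sets."""

    aligned: bool
    dtype: str
    sm_range: tuple[int, int]
    q: int
    k: int
    max_k: int

    @classmethod
    def get_all(cls) -> list["FwdKernel"]:
        kernels: list[FwdKernel] = []
        # zip(SM, SM[1:]) yields the same consecutive pairs as itertools.pairwise(SM)
        for aligned, dtype, (sm, sm_max) in itertools.product(
            [True, False], DTYPES.keys(), zip(SM, SM[1:])
        ):
            if dtype == "bf16" and sm < 80:  # bf16 kernels only for sm80 and newer
                continue
            if not aligned and sm >= 80:  # sm80 and newer only get aligned kernels
                continue
            # Three tile configurations (q, k, max_k) per surviving combination
            for q, k, max_k in [
                (64, 64, 64),
                (64 if sm > 75 else 32, 128, 128),
                (32, 128, 2**16),
            ]:
                kernels.append(
                    cls(
                        aligned=aligned,
                        dtype=dtype,
                        sm_range=(sm, sm_max),
                        q=q,
                        k=k,
                        max_k=max_k,
                    )
                )
        return kernels


kernels = FwdKernel.get_all()
print(len(kernels))  # 45 with the assumed SM and DTYPES above
```

With these assumed values, 15 of the 24 (alignment, dtype, SM range) combinations survive the two filters, and each yields three tile configurations, giving 45 kernels.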
Frequently Asked Questions
What does get_all() do?
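get_all() is a class method in the pytorch codebase that builds the full list of FwdKernel configurations to generate. It iterates over every combination of alignment (aligned or unaligned), data type from DTYPES, and consecutive pair of SM versions from SM (each pair becomes an sm_range). It skips bf16 kernels for SM versions below 80 and unaligned kernels for SM 80 and above. For each remaining combination it emits three tile configurations (q, k, max_k): (64, 64, 64); (64, 128, 128), with q reduced to 32 for SM 75 and below; and (32, 128, 65536).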
What does get_all() call?
According to the dependency analysis, get_all() calls one function: another function that is also named get_all().
What calls get_all()?
get_all() is called by two functions: main() and another function that is also named get_all().