write_decl_impl() — pytorch Function Reference

Architecture documentation for the write_decl_impl() function in generate_kernels.py from the pytorch codebase.

Dependency Diagram

graph TD
  1ec91f05_2749_ffa8_d3f5_5f37a18418dc["write_decl_impl()"]
  20303473_0948_19b1_a212_748886b73572["main()"]
  20303473_0948_19b1_a212_748886b73572 -->|calls| 1ec91f05_2749_ffa8_d3f5_5f37a18418dc
  style 1ec91f05_2749_ffa8_d3f5_5f37a18418dc fill:#6366f1,stroke:#818cf8,color:#fff

Source Code

aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernels/generate_kernels.py lines 306–378

def write_decl_impl(
    kernels: list[T],
    family_name: str,
    impl_file: str,
    autogen_dir: Path,
    disable_def: Optional[str] = None,
) -> None:
    cpp_file_header = """/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
// This file is auto-generated. See "generate_kernels.py"
"""

    kernels.sort()

    implfile_to_kernels: dict[str, list[T]] = collections.defaultdict(list)
    cat_to_kernels: dict[tuple[str, int, int], list[T]] = collections.defaultdict(list)

    dispatch_all = ""
    declarations = cpp_file_header + "#pragma once\n"
    # declarations += f"#ifndef {disable_def}\n"
    declarations += f"""#include {impl_file}\n"""
    declarations += """using namespace PyTorchMemEffAttention;\n"""

    # Declaration of kernel functions
    for k in kernels:
        implfile_to_kernels[k.impl_group].append(k)
        cat_to_kernels[(k.dtype, k.sm_range[0], k.sm_range[1])].append(k)

    for (cat_dt, cat_sm, cat_sm_max), kernels in cat_to_kernels.items():
        declarations += f"// ======== {cat_dt} / sm{cat_sm} ========\n"
        declarations += "\n".join(
            k.cpp_impl.split("{")[0].rstrip() + ";" for k in kernels
        )
        dispatch_category_fn = f"dispatch_{family_name}_{cat_dt}_sm{cat_sm}"
        declarations += (
            f"\n\ntemplate <typename T> void {dispatch_category_fn}(T cb, int cc) {{\n"
        )
        for k in kernels:
            _call = f"cb({k.cpp_class}(), {k.name});\n"
            if k.dispatch_cond is not None:
                _call = f"if ({k.dispatch_cond}) {_call}"
            declarations += f"    {_call}"
        declarations += "}\n\n"
        dispatch_all += f"""
    if (std::is_same_v<DT, {DTYPES[cat_dt]}> && {cat_sm} <= cc && cc < {cat_sm_max}) {{
        {dispatch_category_fn}(cb, cc);
    }}"""

    declarations += f"""
template <typename DT, typename T>
void dispatch_{family_name}(T cb, int cc = 0) {{
{dispatch_all}
}}
"""
    # declarations += f"#endif // {disable_def}\n"

    # Write declarations to family header
    (autogen_dir / f"{family_name}.h").write_text(declarations)

    for f, f_kernels in implfile_to_kernels.items():
        impl_cu = cpp_file_header
        # impl_cu += f"#ifndef {disable_def}\n"
        impl_cu += f"""#include {impl_file}\n"""
        impl_cu += """using namespace PyTorchMemEffAttention;\n"""
        for k in f_kernels:
            impl_cu += k.cpp_impl
        # impl_cu += f"#endif // {disable_def}\n"
        (autogen_dir / f"{family_name}_{f}.cu").write_text(impl_cu)
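The grouping and prototype-extraction steps in the listing above can be tried in isolation. The following is a minimal sketch, not the PyTorch implementation: FakeKernel and its sample values are hypothetical stand-ins that model only some of the attributes the function reads (dtype, sm_range, name, cpp_impl).

```python
import collections
from dataclasses import dataclass


@dataclass(order=True)
class FakeKernel:
    # Hypothetical stand-in for the real kernel classes; order=True makes
    # instances sortable field by field, which the grouping below relies on.
    dtype: str
    sm_range: tuple[int, int]
    name: str
    cpp_impl: str


def declaration(k: FakeKernel) -> str:
    # Everything before the first "{" is the function signature, so
    # appending ";" turns a definition into a forward declaration.
    return k.cpp_impl.split("{")[0].rstrip() + ";"


def group_by_category(kernels):
    # Bucket kernels by (dtype, min SM, max SM), visiting them in sorted order.
    cat_to_kernels = collections.defaultdict(list)
    for k in sorted(kernels):
        cat_to_kernels[(k.dtype, k.sm_range[0], k.sm_range[1])].append(k)
    return cat_to_kernels


kernels = [
    FakeKernel("f16", (80, 100), "fmha_f16_sm80",
               "void fmha_f16_sm80(int p) {\n  /* body */\n}\n"),
    FakeKernel("f32", (50, 70), "fmha_f32_sm50",
               "void fmha_f32_sm50(int p) {\n  /* body */\n}\n"),
    FakeKernel("f16", (80, 100), "fmha_f16_sm80_b",
               "void fmha_f16_sm80_b(int p) {\n}\n"),
]

groups = group_by_category(kernels)
for (dtype, sm_min, sm_max), members in groups.items():
    print(f"// ======== {dtype} / sm{sm_min} ========")
    print("\n".join(declaration(k) for k in members))
```

One difference worth noting: the sketch uses sorted(), which leaves the input untouched, whereas the listing calls kernels.sort() and therefore reorders the caller's list in place.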

Called By

main()

Frequently Asked Questions

What does write_decl_impl() do?
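write_decl_impl() generates the C++/CUDA sources for one kernel family. It sorts the kernels, groups them by (dtype, SM range), and writes a header named {family_name}.h into autogen_dir. That header contains the kernel declarations, one dispatch_{family_name}_{dtype}_sm{N}() template per group, and a top-level dispatch_{family_name}() template that selects a group from the data type DT and the cc argument, which is compared against each group's SM range. The function then writes one {family_name}_{impl_group}.cu file per impl_group containing the kernel definitions. The disable_def parameter is accepted but has no effect in this version, because the #ifndef and #endif lines that reference it are commented out.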
What calls write_decl_impl()?
write_decl_impl() is called by one function: main().
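To see what the final loop of the source listing leaves on disk, here is a small sketch of the one-file-per-impl-group layout. It is an illustration under stated assumptions rather than the PyTorch code: write_impl_files, the "demoF" family name, and the (impl_group, source) pairs are all hypothetical, and the output goes to a temporary directory.

```python
import collections
import tempfile
from pathlib import Path


def write_impl_files(impls, family_name, autogen_dir):
    """Write one .cu file per impl group and return the file names written."""
    # Collect the C++ definitions that belong to the same impl group.
    by_group = collections.defaultdict(list)
    for impl_group, cpp_source in impls:
        by_group[impl_group].append(cpp_source)

    written = []
    for impl_group, sources in by_group.items():
        # File name pattern mirrors the listing: {family_name}_{group}.cu
        path = autogen_dir / f"{family_name}_{impl_group}.cu"
        path.write_text("// This file is auto-generated.\n" + "".join(sources))
        written.append(path.name)
    return written


# Hypothetical (impl_group, C++ definition) pairs.
impls = [
    ("f16_sm80", "void k0() {}\n"),
    ("f16_sm80", "void k1() {}\n"),
    ("f32_sm50", "void k2() {}\n"),
]

with tempfile.TemporaryDirectory() as tmp:
    names = write_impl_files(impls, "demoF", Path(tmp))
    first = (Path(tmp) / names[0]).read_text()

print(names)
```

Grouping definitions by impl group keeps each generated translation unit small, which is presumably why the original splits the output into several .cu files instead of one.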
