# Top-level metadata for the kernel build: extension name and the
# hardware backends this package targets.
[general]
name = "sage_attention"
backends = ["cuda"]
# Torch extension glue: the C++ binding sources registered with PyTorch.
[torch]
src = [
    "torch-ext/torch_binding.cpp",
    "torch-ext/torch_binding.h",
]
# Quantized-attention kernels built only for compute capability 8.9
# (int8 QK with fp8 SV variants, per the source file names).
[kernel._qattn_sm89]
backend = "cuda"
cuda-capabilities = ["8.9"]
cuda-flags = [
    "-O3",
    "-std=c++17",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "--use_fast_math",
    "--threads=1",
    "-Xptxas=-v",
    # Suppress nvcc diagnostic 174 ("expression has no effect").
    "-diag-suppress=174",
]
cuda-minver = "12.6"
cxx-flags = [
    "-g",
    "-O3",
    "-fopenmp",
    "-lgomp",
    "-std=c++17",
    "-DENABLE_BF16",
]
depends = ["torch"]
include = ["."]
src = [
    "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu",
    "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu",
    "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu",
    "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu",
    "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu",
    "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu",
    "sage_attention/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu",
    "sage_attention/qattn/attn_cuda_sm89.h",
    "sage_attention/qattn/qk_int_sv_f8_cuda_sm89.cuh",
    "sage_attention/qattn/attn_utils.cuh",
]
# Quantized-attention kernels built only for compute capability 9.0a
# (int8 QK with fp8 SV, per the source file names).
[kernel._qattn_sm90]
backend = "cuda"
cuda-capabilities = ["9.0a"]
cuda-flags = [
    "-O3",
    "-std=c++17",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "--use_fast_math",
    "--threads=1",
    "-Xptxas=-v",
    # Suppress nvcc diagnostic 174 ("expression has no effect").
    "-diag-suppress=174",
]
cuda-minver = "12.6"
cxx-flags = [
    "-g",
    "-O3",
    "-fopenmp",
    "-lgomp",
    "-std=c++17",
    "-DENABLE_BF16",
]
depends = ["torch"]
include = ["."]
src = [
    "sage_attention/qattn/qk_int_sv_f8_cuda_sm90.cu",
    "sage_attention/qattn/attn_cuda_sm90.h",
    "sage_attention/qattn/attn_utils.cuh",
]
# Shared headers used across the architecture-specific kernel targets,
# built for all supported capabilities. NOTE(review): unlike the sibling
# targets, this one has no `include = ["."]` key and its `src` lists only
# headers (.h/.cuh) — confirm this is intentional.
[kernel._qattn]
backend = "cuda"
cuda-capabilities = [
    "8.0",
    "8.9",
    "9.0a",
]
cuda-flags = [
    "-O3",
    "-std=c++17",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "--use_fast_math",
    "--threads=1",
    "-Xptxas=-v",
    # Suppress nvcc diagnostic 174 ("expression has no effect").
    "-diag-suppress=174",
]
cuda-minver = "12.6"
cxx-flags = [
    "-g",
    "-O3",
    "-fopenmp",
    "-lgomp",
    "-std=c++17",
    "-DENABLE_BF16",
]
depends = ["torch"]
src = [
    "sage_attention/cp_async.cuh",
    "sage_attention/dispatch_utils.h",
    "sage_attention/math.cuh",
    "sage_attention/mma.cuh",
    "sage_attention/numeric_conversion.cuh",
    "sage_attention/permuted_smem.cuh",
    "sage_attention/reduction_utils.cuh",
    "sage_attention/wgmma.cuh",
    "sage_attention/utils.cuh",
    "sage_attention/cuda_tensormap_shim.cuh",
]
# Quantized-attention kernels built only for compute capability 8.0
# (int8 QK with f16 SV, per the source file names).
[kernel._qattn_sm80]
backend = "cuda"
cuda-capabilities = ["8.0"]
cuda-flags = [
    "-O3",
    "-std=c++17",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "--use_fast_math",
    "--threads=1",
    "-Xptxas=-v",
    # Suppress nvcc diagnostic 174 ("expression has no effect").
    "-diag-suppress=174",
]
cuda-minver = "12.6"
cxx-flags = [
    "-g",
    "-O3",
    "-fopenmp",
    "-lgomp",
    "-std=c++17",
    "-DENABLE_BF16",
]
depends = ["torch"]
include = ["."]
src = [
    "sage_attention/qattn/qk_int_sv_f16_cuda_sm80.cu",
    "sage_attention/qattn/attn_cuda_sm80.h",
    "sage_attention/qattn/attn_utils.cuh",
]
# Fused kernels, built for all supported compute capabilities.
[kernel._fused]
backend = "cuda"
cuda-capabilities = [
    "8.0",
    "8.9",
    "9.0a",
]
cuda-flags = [
    "-O3",
    "-std=c++17",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "--use_fast_math",
    "--threads=1",
    "-Xptxas=-v",
    # Suppress nvcc diagnostic 174 ("expression has no effect").
    "-diag-suppress=174",
]
cuda-minver = "12.6"
cxx-flags = [
    "-g",
    "-O3",
    "-fopenmp",
    "-lgomp",
    "-std=c++17",
    "-DENABLE_BF16",
]
depends = ["torch"]
include = ["."]
src = [
    "sage_attention/fused/fused.cu",
    "sage_attention/fused/fused.h",
]