
Upload README.md with huggingface_hub

#2 by sayakpaul (HF Staff) · opened

Files changed (1):
1. README.md (+29 −55)
README.md

@@ -1,71 +1,45 @@
 ---
+library_name: kernels
 license: apache-2.0
-tags:
-- kernels
 ---
 
-# Metal Flash SDPA
+<!-- This model card has automatically been generated. You
+should probably proofread and complete it, then remove this comment. -->
 
-Optimized SDPA kernels inspired by Flash Attention for Metal.
 
-Some components of these kernels are from [mlx](https://github.com/ml-explore/mlx).
+This is the repository card of {repo_id} that has been pushed on the Hub. It was built to be used with the [`kernels` library](https://github.com/huggingface/kernels). This card was automatically generated.
 
-## Supported Features
 
-- Variable-length sequences without padding
-- Causal masking
-- Grouped Query Attention (GQA) and Multi-Query Attention (MQA)
-- Softcapping support for attention score regularization
-- Data types: `float32`, `float16`, `bfloat16`
-- Head dimensions: `32`, `64`, `72`, `80`, `96`, `128`, `256`
+## How to use
 
-## API Reference
+```python
+# make sure `kernels` is installed: `pip install -U kernels`
+from kernels import get_kernel
 
-### flash_attention_varlen
+kernel_module = get_kernel("kernels-community/metal-flash-sdpa") # <- change the ID if needed
+flash_attention_varlen = kernel_module.flash_attention_varlen
 
-```python
-metal_flash_sdpa.flash_attention_varlen(
-    out: torch.Tensor,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    cu_seqlens_q: torch.Tensor,
-    cu_seqlens_k: torch.Tensor,
-    max_seqlen_q: int,
-    max_seqlen_k: int,
-    do_causal: bool,
-    scale: float,
-    softcapping: float
-) -> None
+flash_attention_varlen(...)
 ```
 
-- **out**: Output tensor `[total_q_tokens, num_heads, head_dim]`, modified in-place.
-- **query/key/value**: Input tensors `[total_tokens, num_heads(_kv), head_dim]`.
-- **cu_seqlens_q/cu_seqlens_k**: Cumulative sequence lengths (`torch.int32`), `[batch_size + 1]`.
-- **max_seqlen_q/max_seqlen_k**: Maximum sequence lengths.
-- **do_causal**: Enable causal masking.
-- **scale**: Attention score scaling factor (e.g., `1/sqrt(head_dim)`).
-- **softcapping**: Softcapping value for score regularization (use `1.0` for no softcapping).
+## Available functions
 
-### flash_attn_varlen_func
+- `flash_attention_varlen`
+- `flash_attn_varlen_func`
+- `ops`
 
-Compatibility wrapper matching the original Flash Attention API:
+## Supported backends
 
-```python
-out = metal_flash_sdpa.flash_attn_varlen_func(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    cu_seqlens_q: torch.Tensor,
-    cu_seqlens_k: torch.Tensor,
-    max_seqlen_q: int,
-    max_seqlen_k: int,
-    dropout_p: float = 0.0,
-    softmax_scale: Optional[float] = None,
-    causal: bool = False,
-    window_size: Tuple[int, int] = (-1, -1),
-    alibi_slopes: Optional[torch.Tensor] = None,
-    deterministic: bool = False,
-    return_attn_probs: bool = False
-)
-```
+- metal
+
+## Benchmarks
+
+[TODO: provide benchmarks if available]
+
+## Source code
+
+[TODO: provide original source code and other relevant citations if available]
+
+## Notes
+
+[TODO: provide additional notes about this kernel if needed]
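
---

A note for readers of the diff: both `flash_attention_varlen` and `flash_attn_varlen_func` take `cu_seqlens_q`/`cu_seqlens_k` in the flash-attention varlen convention, where sequences are packed without padding. A minimal pure-Python sketch of how those arguments are typically derived (the helper name `build_cu_seqlens` and the concrete lengths are mine, not part of the kernel API):

```python
import math

def build_cu_seqlens(seqlens):
    """Cumulative sequence lengths in the varlen convention:
    [0, len_0, len_0 + len_1, ...], i.e. batch_size + 1 entries.
    (In the kernel API these are passed as `torch.int32` tensors.)"""
    cu = [0]
    for n in seqlens:
        cu.append(cu[-1] + n)
    return cu

# Three packed sequences of lengths 3, 5, and 2 -> 10 total tokens,
# so q/k/v would have shape [10, num_heads, head_dim].
seqlens = [3, 5, 2]
cu_seqlens = build_cu_seqlens(seqlens)  # [0, 3, 8, 10]
max_seqlen = max(seqlens)               # 5, the `max_seqlen_*` argument
head_dim = 64
scale = 1.0 / math.sqrt(head_dim)       # the usual attention scaling factor
```

Token `i` of sequence `b` then lives at row `cu_seqlens[b] + i` of the packed tensors, which is how the kernel locates each sequence without padding.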