使用指南¶
本指南介绍如何在项目中使用 FlashAttention-Plus。
基础用法¶
启用 FlagGems 后端¶
FlashAttention-Plus 使用环境变量来切换后端:
import os
# 启用 FlagGems 后端(默认)
os.environ["FLASH_ATTENTION_USE_FLAGGEMS"] = "TRUE"
# 或禁用以使用原始 CUDA 后端(如果可用)
os.environ["FLASH_ATTENTION_USE_FLAGGEMS"] = "FALSE"
简单示例¶
import torch
from flash_attn import flash_attn_func
# 创建输入张量
batch_size = 2
seq_length = 1024
num_heads = 16
head_dim = 64
# 注意:输入必须是 fp16 或 bf16
q = torch.randn(batch_size, seq_length, num_heads, head_dim,
device='cuda', dtype=torch.float16)
k = torch.randn(batch_size, seq_length, num_heads, head_dim,
device='cuda', dtype=torch.float16)
v = torch.randn(batch_size, seq_length, num_heads, head_dim,
device='cuda', dtype=torch.float16)
# 应用闪存注意力
output = flash_attn_func(q, k, v, causal=True)
高级用法¶
使用 Dropout¶
Dropout 支持
Dropout 接口可用,但在当前版本中可能无法完全正常工作。
自定义 Softmax 缩放¶
# 自定义缩放因子
scale = 1.0 / math.sqrt(head_dim)
output = flash_attn_func(q, k, v, softmax_scale=scale, causal=True)
非因果注意力¶
输入要求¶
数据类型¶
FlashAttention-Plus 要求输入为半精度格式:
torch.float16
(fp16)torch.bfloat16
(bf16)
张量形状¶
输入张量应具有形状:[batch_size, seq_length, num_heads, head_dim]
设备¶
所有张量必须在 CUDA 设备上:
与 Transformers 集成¶
自定义注意力模块¶
import torch.nn as nn
from flash_attn import flash_attn_func
class FlashSelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads, dropout=0.0):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.dropout = dropout
self.qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
def forward(self, x, causal=False):
batch_size, seq_len, _ = x.shape
# 计算 Q, K, V
qkv = self.qkv(x)
qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(dim=2)
# 应用闪存注意力
output = flash_attn_func(q, k, v, dropout_p=self.dropout, causal=causal)
# 重塑并投影输出
output = output.reshape(batch_size, seq_len, self.embed_dim)
output = self.out_proj(output)
return output
与现有模型一起使用¶
# 用闪存注意力替换标准注意力
model = YourTransformerModel()
# 启用 FlagGems 后端
os.environ["FLASH_ATTENTION_USE_FLAGGEMS"] = "TRUE"
# 模型现在将使用 FlashAttention-Plus
output = model(input_ids)
性能提示¶
-
首次运行:由于 Triton 内核编译,首次运行可能较慢。后续运行将使用缓存的内核。
-
批处理:一起处理多个序列以获得更好的 GPU 利用率:
-
内存效率:FlashAttention 设计为内存高效。您可以处理比标准注意力更长的序列。
调试¶
启用调试模式以获取更多信息:
import logging
logging.basicConfig(level=logging.DEBUG)
# 这将显示有关注意力计算的详细信息
output = flash_attn_func(q, k, v, causal=True)