1.1Kпросмотров
16 марта 2025 г.
📷 ФотоScore: 1.2K
⚡️SageAttention — brand new attention Flash Attention бустит классический attention по скорости, но что если я скажу, что можно еще быстрее, если правильно квантизовать Flash Attention. Sage Attention за счёт умной квантизации + понимании CUDA ускоряют Flash Attention 2 от 2x до 5x раз. Небольшое введение в квантизацию. Допустим, у нас есть матрица A, которую мы хотим перевести из формата FP16 в INT8. Для этого:
1. Сначала находим максимальное абсолютное значение элементов матрицы и вычисляем коэффициент скейлинга (scale factor):
δA = max(|A|) / 127 2. Делим исходную матрицу на этот коэффициент и округляем, получая матрицу низкой точности \hat{A} (например, INT8):
 = round(A / δA) 3. После вычислений в низкой точности (например, произведения матриц AB), получаем итоговый результат, снова умножая на коэффициенты скейлинга:
C ≈ (Â × B̂) × (δA × δB) Здесь приведен простой пример квантизации — на самом деле, факторы скейлинга можно выбирать для каждой строчки отдельно, что улучшить качество за счёт бОльшего числа факторов. В статье SageAttention про это подробно рассказано, рекомендую ознакомиться. Основной проблемой данного подхода является то, что могут возникать выбросы: Например, если в исходной матрице A:
A = [0.1, 0.2, 0.15, 1000] То коэффициент скейлинга будет:
δA = 1000 / 127 ≈ 7.87 И после квантизации получится:
A_quant = round(A / δA) = [0, 0, 0, 127] Поэтому залог успешной квантизации, подумать головой и разобраться с выбросами, что и сделали в SageAttention. Небольшое введение в формулы и погнали:
➤ P = Attention = Softmax(QK^T / \sqrt{d})
➤ V - value -> V* = PV SageAttention:
➤ Переводим вычисление QK^T в INT8 и стабилизируем выбросы с помощью нормализации матрицы K (это стандартная техника для борьбы с выбросами, берите на вооружение!). Дополнительно оптимально подбираем стратегию квантизации, глядя на структуру самих матриц.
➤ Делаем фьюзинг (объединение) операций скейлинга и анскейлинга с соседними слоями.
➤ Вычисление PV делаем в FP16 — так стабильнее, потому что FP8 слишком сильно теряет точность. SageAttention2:
➤ Современные GPU (например, Nvidia H100) поддерживают не только INT8, но и INT4. Поэтому теперь делаем квантизацию матриц Q и K в INT4, а также дополнительно нормализуем матрицу Q, чтобы не искажать итоговый результат attention.
➤ Реализуем глубокие CUDA-оптимизации. Например, оказывается, что масштабирование при квантизации можно параллельно применять сразу к элементам векторов по специальной формуле (каждый 8i + 2k + 1-й элемент). Жёсткая оптимизация на низком уровне GPU 🔥
➤ Вычисления PV делаем в FP8, так как теперь нашли способ стабилизировать результат при помощи двойной буферизации с FP32 SpargeAttention:
➤ FlashAttention работает с матрицами поблочно. Если какой-то блок attention заполнен нулями, то считать его нет смысла — результат не изменится.
➤ Авторы придумали механизм, как быстро находить и пропускать такие блоки на основе похожести матриц Q и K.
➤ Дополнительно, в процессе вычисления attention можно пропускать отдельные блоки, основываясь на пороге.
➤ Поверх этой разреженности используется подход из SageAttention2 (с INT4-квантизацией и CUDA-оптимизациями). Результат 🔥 SageAttention уже быстрее, чем FlashAttention2 примерно в 2.1 раза (на RTX4090), а также на 2.7 раза быстрее xformers. При этом точность моделей практически не страдает.
🔥 SageAttention2 ускоряет attention примерно в 3 раза по сравнению с FlashAttention2 и в 4.5 раза быстрее xformers на RTX4090. 🔥 SpargeAttn ускоряет inference дополнительно, обеспечивая от 2.5 до 5 раз быстрее по сравнению с другими методами attention. Как применить у себя. Устанавливаем библиотеку и дальше делаем вот так import torch.nn.functional as F from sageattention import sageattn
F.scaled_dot_product_attention = sageattn