A
Awesome DL
@awesome_dl841 подп.
1.1Kпросмотров
16 марта 2025 г.
📷 ФотоScore: 1.2K
⚡️SageAttention — brand new attention Flash Attention бустит классический attention по скорости, но что если я скажу, что можно еще быстрее, если правильно квантизовать Flash Attention. Sage Attention за счёт умной квантизации + понимании CUDA ускоряют Flash Attention 2 от 2x до 5x раз. Небольшое введение в квантизацию. Допустим, у нас есть матрица A, которую мы хотим перевести из формата FP16 в INT8. Для этого: 1. Сначала находим максимальное абсолютное значение элементов матрицы и вычисляем коэффициент скейлинга (scale factor): δA = max(|A|) / 127 2. Делим исходную матрицу на этот коэффициент и округляем, получая матрицу низкой точности \hat{A} (например, INT8): Â = round(A / δA) 3. После вычислений в низкой точности (например, произведения матриц AB), получаем итоговый результат, снова умножая на коэффициенты скейлинга: C ≈ (Â × B̂) × (δA × δB) Здесь приведен простой пример квантизации — на самом деле, факторы скейлинга можно выбирать для каждой строчки отдельно, что улучшить качество за счёт бОльшего числа факторов. В статье SageAttention про это подробно рассказано, рекомендую ознакомиться. Основной проблемой данного подхода является то, что могут возникать выбросы: Например, если в исходной матрице A: A = [0.1, 0.2, 0.15, 1000] То коэффициент скейлинга будет: δA = 1000 / 127 ≈ 7.87 И после квантизации получится: A_quant = round(A / δA) = [0, 0, 0, 127] Поэтому залог успешной квантизации, подумать головой и разобраться с выбросами, что и сделали в SageAttention. Небольшое введение в формулы и погнали: ➤ P = Attention = Softmax(QK^T / \sqrt{d}) ➤ V - value -> V* = PV SageAttention: ➤ Переводим вычисление QK^T в INT8 и стабилизируем выбросы с помощью нормализации матрицы K (это стандартная техника для борьбы с выбросами, берите на вооружение!). Дополнительно оптимально подбираем стратегию квантизации, глядя на структуру самих матриц. ➤ Делаем фьюзинг (объединение) операций скейлинга и анскейлинга с соседними слоями. ➤ Вычисление PV делаем в FP16 — так стабильнее, потому что FP8 слишком сильно теряет точность. SageAttention2: ➤ Современные GPU (например, Nvidia H100) поддерживают не только INT8, но и INT4. Поэтому теперь делаем квантизацию матриц Q и K в INT4, а также дополнительно нормализуем матрицу Q, чтобы не искажать итоговый результат attention. ➤ Реализуем глубокие CUDA-оптимизации. Например, оказывается, что масштабирование при квантизации можно параллельно применять сразу к элементам векторов по специальной формуле (каждый 8i + 2k + 1-й элемент). Жёсткая оптимизация на низком уровне GPU 🔥 ➤ Вычисления PV делаем в FP8, так как теперь нашли способ стабилизировать результат при помощи двойной буферизации с FP32 SpargeAttention: ➤ FlashAttention работает с матрицами поблочно. Если какой-то блок attention заполнен нулями, то считать его нет смысла — результат не изменится. ➤ Авторы придумали механизм, как быстро находить и пропускать такие блоки на основе похожести матриц Q и K. ➤ Дополнительно, в процессе вычисления attention можно пропускать отдельные блоки, основываясь на пороге. ➤ Поверх этой разреженности используется подход из SageAttention2 (с INT4-квантизацией и CUDA-оптимизациями). Результат 🔥 SageAttention уже быстрее, чем FlashAttention2 примерно в 2.1 раза (на RTX4090), а также на 2.7 раза быстрее xformers. При этом точность моделей практически не страдает. 🔥 SageAttention2 ускоряет attention примерно в 3 раза по сравнению с FlashAttention2 и в 4.5 раза быстрее xformers на RTX4090. 🔥 SpargeAttn ускоряет inference дополнительно, обеспечивая от 2.5 до 5 раз быстрее по сравнению с другими методами attention. Как применить у себя. Устанавливаем библиотеку и дальше делаем вот так import torch.nn.functional as F from sageattention import sageattn F.scaled_dot_product_attention = sageattn
1.1K
просмотров
3808
символов
Нет
эмодзи
Да
медиа

Другие посты @awesome_dl

Все посты канала →
⚡️SageAttention — brand new attention Flash Attention бустит — @awesome_dl | PostSniper