Sobes.tech
Back to questions
Senior
14

Что такое FlashAttention и какой выигрыш по памяти он даёт?

Companies where asked
Skyeng
AstonAston

Answer from AI

sobes.tech AI

Что хотят услышать интервьюеры:

FlashAttention — это способ считать attention без хранения всей матрицы attention в памяти. Основная идея — делать вычисления блоками и сразу сливать промежуточные результаты, чтобы резко снизить потребление HBM/VRAM. В типичном случае это уменьшает память с квадратичной по длине последовательности до почти линейной по рабочему буферу.

Определение:

FlashAttention — это алгоритм для точного вычисления scaled dot-product attention, оптимизированный под GPU. Он не материализует полную матрицу QK^T и не держит все intermediate activations в виде больших тензоров, а разбивает вычисление на блоки и использует online softmax, чтобы считать результат потоково. За счёт этого снижается пиковое использование памяти и часто растёт скорость из-за лучшей локальности доступа к данным.

Пример использования:

В трансформере с длинным контекстом обычный attention быстро упирается в память, потому что нужно хранить матрицу размера L x L. FlashAttention позволяет обучать или дообучать модель на более длинных последовательностях при том же объёме GPU-памяти.

# Концептуальный пример: включение memory-efficient attention в Transformer
# Конкретный API зависит от фреймворка и версии библиотеки.

import torch
import torch.nn.functional as F

q = torch.randn(2, 8, 1024, 64, device="cuda")
k = torch.randn(2, 8, 1024, 64, device="cuda")
v = torch.randn(2, 8, 1024, 64, device="cuda")

# Если backend/фреймворк поддерживает FlashAttention,
# attention будет считаться через оптимизированную реализацию.
out = F.scaled_dot_product_attention(q, k, v, is_causal=False)

Пояснение кода:

Код показывает типичный сценарий: есть запросы q, ключи k и значения v для батча, голов и длины контекста. При обычной реализации attention сначала строится большая матрица сходств между q и k, а затем применяется softmax и умножение на v. В оптимизированной реализации это делается блоками: GPU обрабатывает кусок последовательности, сразу нормализует результаты и не хранит всю матрицу целиком. На практике это даёт заметное снижение пикового расхода памяти и часто позволяет увеличить длину контекста или batch size.

Ключевые моменты:

  • FlashAttention считает attention точно, а не приближённо.
  • Главный выигрыш — меньшее пиковое использование памяти за счёт отказа от хранения полной L x L матрицы attention.
  • Вычисление идёт блоками с online softmax, что улучшает локальность данных на GPU.
  • По памяти выигрыш особенно заметен на длинных последовательностях и больших batch size.
  • На практике это часто даёт и ускорение, потому что уменьшается число обращений к медленной памяти.
  • Это не магия «без O(L²) вычислений»: по сложности вычислений attention остаётся квадратичным, но память и эффективность доступа к данным улучшаются существенно.