Flash Attention

Wenn ein KI-Modell einen langen Text verarbeitet, muss es jedes Wort mit jedem anderen Wort vergleichen. Bei langen Texten wächst die Zahl dieser Vergleiche so stark, dass der Speicher der Grafikkarte nicht ausreicht. Flash Attention löst dieses Problem, indem es die Vergleiche in kleinen Blöcken durchführt, die in den schnellsten Speicher der GPU passen.

Der Attention-Mechanismus ist das zentrale Rechenverfahren in Transformer-Modellen. Für jede Position in einer Eingabesequenz berechnet er, wie stark jede andere Position beachtet werden soll. Das Ergebnis ist eine Matrix mit so vielen Einträgen wie die Sequenzlänge zum Quadrat. Bei einer Sequenz von 4.096 Tokens enthält diese Matrix bereits über 16 Millionen Werte.

Standard-Attention legt diese gesamte Matrix im Hauptspeicher der GPU ab (High Bandwidth Memory, HBM). Der HBM ist groß, aber der Datentransfer zwischen HBM und Recheneinheiten ist der eigentliche Engpass. Flash Attention vermeidet die vollständige Materialisierung dieser Matrix. Stattdessen werden die Berechnungen blockweise im schnellen On-Chip-Speicher (SRAM) durchgeführt, der um ein Vielfaches schnelleren Zugriff bietet.

Das Kernproblem: Speicherzugriffe als Engpass

Moderne GPUs verfügen über eine hohe Rechenleistung, die häufig nicht voll ausgelastet wird. Der Grund: Die Daten müssen aus dem HBM geladen werden, und die Bandbreite dieses Speichers begrenzt den tatsächlichen Durchsatz. Die Recheneinheiten warten auf Daten, anstatt zu rechnen.

Beispiel: Eine NVIDIA A100-GPU führt bis zu 312 Teraflops an Berechnungen pro Sekunde durch. Der HBM liefert Daten mit maximal 2 TB/s. Für Operationen mit wenig Rechenaufwand pro geladenem Datenelement (wie die elementweise Softmax-Berechnung in Attention) wartet die GPU den Großteil der Zeit auf Speichertransfers.

Beispiel: Standard-Attention bei einer Sequenzlänge von 8.192 Tokens erzeugt eine Attention-Matrix mit etwa 67 Millionen Einträgen. Bei 16-Bit-Genauigkeit belegt allein diese Matrix rund 128 MB im GPU-Speicher. Bei Multi-Head-Attention mit 32 Köpfen ergibt sich ein Speicherbedarf von über 4 GB nur für Attention-Matrizen eines einzigen Layers.

Flash Attention adressiert dieses Problem, indem es die gesamte Attention-Berechnung so umstrukturiert, dass sie im SRAM der GPU stattfindet. Der SRAM ist typischerweise 20 MB groß, bietet aber eine Bandbreite von etwa 19 TB/s. Das ist ein Faktor von knapp 10 gegenüber dem HBM.

Fachliche Einordnung: Die Unterscheidung zwischen compute-bound und memory-bound Operationen stammt aus dem Roofline-Modell der Computerarchitektur. Flash Attention verschiebt Attention von einer memory-bound zu einer compute-bound Operation, was die tatsächliche Hardwareauslastung erhöht.

Tiling: Blockweise Berechnung im schnellen Speicher

Flash Attention unterteilt die Eingabematrizen Q (Queries), K (Keys) und V (Values) in Blöcke fester Größe. Die Blockgröße wird so gewählt, dass die zugehörigen Daten vollständig in den SRAM passen. Für jeden Block wird die Attention lokal berechnet: Skalarprodukt zwischen Q-Block und K-Block, Softmax, gewichtete Summe mit V-Block. Die Teilergebnisse werden akkumuliert, ohne dass die vollständige Attention-Matrix jemals materialisiert wird.

Beispiel: Bei einer Sequenzlänge von 4.096 Tokens und einer Blockgröße von 256 entstehen 16 Blöcke. Statt eine 4.096 × 4.096 Matrix zu speichern, arbeitet Flash Attention mit 256 × 256 Teilmatrizen, die jeweils nur 128 KB belegen.

Der entscheidende Schritt liegt in der korrekten Akkumulation der Softmax-Ergebnisse über mehrere Blöcke hinweg. Softmax normalisiert über die gesamte Zeile der Attention-Matrix. Wenn nur ein Block sichtbar ist, fehlt der Gesamtnenner. Flash Attention löst dieses Problem mit dem Online-Softmax-Algorithmus, der den Normalisierungsfaktor inkrementell aktualisiert, sobald ein neuer Block verarbeitet wird.

Beispiel: Der Online-Softmax speichert pro Zeile zwei zusätzliche Werte: das bisherige Maximum der Eingaben und die laufende Summe der Exponentialwerte. Wenn ein neuer Block ein höheres Maximum enthält, werden die bisherigen Teilergebnisse mit einem Korrekturfaktor neu skaliert. Das Endergebnis ist mathematisch identisch mit Standard-Softmax.

Q, K, V Matrizenim HBM gespeichert
Blockweise AufteilungBlöcke passen in SRAM
Q-Block, K-Block ladenHBM → SRAM Transfer
Lokale AttentionSoftmax + V-Gewichtung
Online-Softmax-AkkumulationKorrekturfaktor pro Block
Endergebnis → HBM

Speicherbedarf und Laufzeitverhalten

Standard-Attention benötigt O(N²) zusätzlichen Speicher für die Attention-Matrix, wobei N die Sequenzlänge bezeichnet. Flash Attention reduziert diesen zusätzlichen Speicherbedarf auf O(N). Die eigentliche Berechnung bleibt bei O(N²) Operationen. Flash Attention berechnet exakt dasselbe Ergebnis. Es handelt sich nicht um eine Approximation.

Beispiel: Bei einer Sequenzlänge von 16.384 Tokens belegt die Attention-Matrix in Standard-Attention etwa 512 MB pro Head bei 16-Bit-Genauigkeit. Flash Attention benötigt für dieselbe Berechnung nur wenige Megabyte zusätzlichen Speicher, unabhängig von der Sequenzlänge.

Beispiel: In Messungen von Dao et al. (2022) erreichte Flash Attention auf einer A100-GPU eine Wandzeit-Beschleunigung um den Faktor 2,4 bei Sequenzlänge 1.024 und um den Faktor 7,6 bei Sequenzlänge 4.096 gegenüber einer PyTorch-Standardimplementierung.

Der Geschwindigkeitsgewinn erklärt sich nicht durch weniger Rechenoperationen, sondern durch weniger Speichertransfers. Die Gesamtzahl der Gleitkommaoperationen ist bei Flash Attention sogar leicht höher (durch die erneute Berechnung bei Rückwärtsdurchläufen). Trotzdem ist die Wandzeit geringer, weil die Daten im schnellen Speicher bleiben.

Recomputation beim Rükwärtsdurchlauf

Beim Training von neuronalen Netzen muss der Rükwärtsdurchlauf (Backpropagation) auf die Attention-Matrix zugreifen, um Gradienten zu berechnen. Standard-Attention speichert diese Matrix während des Vorwärtsdurchlaufs. Flash Attention speichert sie nicht. Stattdessen berechnet es die benötigten Werte im Rükwärtsdurchlauf erneut.

Beispiel: Ein Transformer-Modell mit 24 Layern und 16 Attention-Heads speichert bei Standard-Attention 384 Attention-Matrizen (24 × 16) pro Trainingsschritt. Bei Sequenzlänge 2.048 und Batch-Größe 8 ergibt das mehrere Gigabyte allein für Zwischenergebnisse. Flash Attention speichert stattdessen nur die Statistiken des Online-Softmax (Maximum und Summe pro Zeile), was um Größenordnungen weniger Speicher benötigt.

Diese Strategie wird als Recomputation oder Gradient Checkpointing bezeichnet. Sie erhöht die Zahl der Rechenoperationen, senkt aber den Speicherbedarf erheblich. Der Nettovorteil ist positiv, weil die Recomputation im SRAM stattfindet und damit schneller ist als ein zusätzlicher HBM-Zugriff.

Versionen: Flash Attention 2 und Flash Attention 3

Flash Attention 2 (Dao, 2023) brachte mehrere Verbesserungen gegenüber der ersten Version. Die Parallelisierung wurde von der Batch-Dimension auf die Sequenz-Dimension verlagert, was die GPU-Auslastung bei kleinen Batches verbessert. Die Reihenfolge der Schleifen wurde vertauscht: Die äußere Schleife iteriert nun über Q-Blöcke, die innere über K/V-Blöcke. Das reduziert unnötige Lese- und Schreibvorgänge.

Beispiel: Flash Attention 2 erreichte auf einer A100-GPU einen Durchsatz von etwa 230 Teraflops bei 16-Bit-Genauigkeit. Das entspricht rund 73 % der theoretischen Spitzenleistung. Standard-Attention erreichte auf derselben Hardware etwa 25 bis 40 % Auslastung.

Flash Attention 3 (Shah et al., 2024) nutzt zusätzlich Features der Hopper-Architektur (H100-GPUs): asynchrone Datenflüsse zwischen Warp-Gruppen, hardwarebeschleunigte Quantisierung auf FP8 und überlappende Berechnung mit Datentransfer über den Tensor Memory Accelerator (TMA).

Beispiel: Flash Attention 3 auf einer H100-GPU erreichte bis zu 740 Teraflops bei FP16-Genauigkeit. Das sind 75 % der theoretischen Spitzenleistung und etwa 1,5-mal schneller als Flash Attention 2 auf derselben Hardware.

Integration in Frameworks und Praxis

Flash Attention ist in alle gängigen Frameworks für das Training und die Ausführung von Transformer-Modellen integriert. PyTorch bietet seit Version 2.0 eine native Unterstützung über torch.nn.functional.scaled_dot_product_attention, die Flash Attention automatisch als Backend verwendet, wenn die Hardware es unterstützt.

Beispiel: In Hugging Face Transformers kann Flash Attention 2 durch ein einziges Argument aktiviert werden: model = AutoModelForCausalLM.from_pretrained("model_name", attn_implementation="flash_attention_2"). Das ändert keine Modellgewichte und kein Ergebnis, sondern nur die interne Berechnungsmethode.

CUDA-basierte GPUs ab der Ampere-Generation (A100, RTX 3090 und neuer) unterstützen Flash Attention. Für ältere GPUs existieren alternative Implementierungen wie xFormers oder Memory-Efficient Attention, die ähnliche Prinzipien verfolgen, aber nicht denselben Optimierungsgrad erreichen.

Beispiel: Beim Training eines Modells mit 7 Milliarden Parametern auf einer Sequenzlänge von 8.192 Tokens reduziert Flash Attention 2 den GPU-Speicherbedarf der Attention-Berechnung so stark, dass das Training auf einer einzelnen GPU mit 80 GB HBM möglich wird. Ohne Flash Attention wären bei gleicher Sequenzlänge mindestens zwei GPUs oder eine deutlich kürzere Sequenz nötig.

Varianten und Erweiterungen

Neben der Standard-Version existieren mehrere Varianten von Flash Attention, die für spezielle Anwendungsfälle optimiert sind.

Grouped Query Attention

Grouped Query Attention (GQA) reduziert die Zahl der Key/Value-Heads. Statt dass jeder Attention-Head eigene K- und V-Matrizen hat, teilen sich mehrere Heads eine gemeinsame K/V-Gruppe. Flash Attention unterstützt GQA nativ und profitiert von dem geringeren Speicherbedarf durch weniger K/V-Transfers.

Beispiel: Llama 2 70B verwendet GQA mit 8 K/V-Heads für 64 Q-Heads. In Kombination mit Flash Attention sinkt der KV-Cache-Speicherbedarf bei langen Kontexten um den Faktor 8 gegenüber klassischem Multi-Head-Attention.

Sliding Window Attention

Bei Sliding Window Attention beachtet jede Position nur eine feste Nachbarschaft. Flash Attention kann dieses Muster effizient umsetzen, indem es Blöcke außerhalb des Fensters überspringt und nicht lädt.

Beispiel: Mistral 7B verwendet Sliding Window Attention mit einer Fenstergröße von 4.096 Tokens. Flash Attention überspringt alle Blöcke, deren Abstand zur aktuellen Position das Fenster überschreitet. Bei einer Gesamtsequenz von 32.768 Tokens werden so bis zu 87 % der Blöcke nicht berechnet.

Grenzen und Einschränkungen

Flash Attention ist an bestimmte Hardwarevoraussetzungen gebunden. Die Implementierung nutzt CUDA-spezifische Operationen und setzt GPUs ab der Ampere-Generation voraus. Auf CPUs, TPUs oder älteren GPUs ist Flash Attention nicht direkt einsetzbar. Für AMD-GPUs existiert eine Portierung (Flash Attention für ROCm), die jedoch hinter der CUDA-Variante zurückliegt.

Die Blockgröße muss auf die SRAM-Kapazität der jeweiligen GPU abgestimmt werden. Ein falsch konfigurierter Block führt zu Leistungseinbußen oder Fehlern. In der Praxis wird die Blockgröße automatisch gewählt, aber bei exotischen Konfigurationen (sehr große Head-Dimensionen, ungewöhnliche Sequenzlängen) sind manuelle Anpassungen nötig.

Flash Attention reduziert den Speicherbedarf der Attention-Berechnung, nicht den Speicherbedarf des gesamten Modells. Modellgewichte, Aktivierungen anderer Schichten und Optimierer-Zustände bleiben unverändert. Bei sehr großen Modellen dominieren diese Faktoren den Gesamtspeicherbedarf.

Beispiel: Ein Modell mit 70 Milliarden Parametern belegt bei 16-Bit-Genauigkeit rund 140 GB allein für die Gewichte. Flash Attention spart in diesem Fall zwar Speicher bei der Attention-Berechnung, aber der Gesamtspeicherbedarf wird durch die Modellgröße dominiert. Techniken wie Quantisierung oder Modell-Parallelismus adressieren dieses Problem.

Fachliche Einordnung: Flash Attention ist eine reine Implementierungsoptimierung. Sie verändert weder die Modellarchitektur noch das Attention-Muster noch die mathematischen Ergebnisse. Die Innovation liegt in der Ausnutzung der Speicherhierarchie, nicht in einem neuen Algorithmus. Vergleichbare hardware-bewusste Optimierungen existieren in anderen Bereichen des High-Performance-Computing seit Jahrzehnten. Die Leistung von Flash Attention ist dabei an Nvidias GPU-Architektur gekoppelt. Alternative Beschleuniger wie Googles TPUs oder Intels Gaudi verwenden eigene Optimierungsstrategien für Attention.


Karl Kratz · 21.06.2025 (aktualisiert 03.04.2026)

Technologie Künstliche Intelligenz LLM