Memory and Computation in LLM Inference: Transformers vs. RWKV
In this post, we explore key aspects of memory usage and GPU utilization in large language models (LLMs) based on the Transformer architecture and the RWKV model. We cover data flow during inference, computational complexity, space complexity, and compare how Transformers and RWKV handle historical context.
1. Memory Usage and GPU Utilization
1.1 Memory Usage (VRAM)
GPU memory (VRAM) is used to store:
- Model weights and parameters
- Input data (such as images or other input tensors)
- Intermediate computation results (activations, gradients, etc.)
- Output data (inference results)
Memory allocation is typically determined at the start of a task and remains fairly stable until completion. For example:
- Inference Phase: Only the model weights and necessary intermediate buffers are loaded, so memory usage stays relatively constant.
- Training Phase: Additional storage is needed for gradients and optimizer states, resulting in higher memory consumption.
1.2 GPU Utilization
GPU utilization measures how busy the GPU cores are:
- High Utilization: Indicates that the GPU is engaged in compute-intensive tasks (e.g., matrix multiplications or convolution operations).
- Low Utilization: Suggests that the GPU is either waiting for data transfers or has fewer computations to perform.
2. Data Flow in Transformer-Based LLMs
A detailed look at how Transformer models process data in two distinct stages—prompt processing and decoding—with an emphasis on the KV-cache optimization.
2.1 Two-Stage LLM Inference with KV-Cache
2.1.1 Prompt Stage
During the prompt phase, all tokens in the input sequence are known. The model computes the queries (Q), keys (K), and values (V) for each token independently:
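- Independent Computation: Each token $x_i$ is transformed by linear layers with weights $W_Q$, $W_K$, and $W_V$ as follows:
$$q_i = x_i W_Q, \quad k_i = x_i W_K, \quad v_i = x_i W_V$$
Since every $x_i$ in the prompt is known in advance, these computations do not depend on other tokens.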
- Parallel Processing: Operations such as the self-attention matrix multiplication can be computed in parallel without sequential dependencies.
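A minimal pure-Python sketch of this independence, with toy sizes and random weights (not a real model): the whole prompt is projected with one batched product per weight matrix, and each row is checked against projecting that token on its own. The helper names `matmul`, `rand_matrix`, and `project_prompt` are illustrative.

```python
import random

def matmul(a, b):
    """Multiply an (m, k) matrix by a (k, p) matrix; matrices are lists of rows."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

def project_prompt(x, w_q, w_k, w_v):
    """Prompt stage: Q, K, V for all n tokens with one batched product each."""
    return matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)

rng = random.Random(0)
n, d = 5, 4                      # toy sizes: 5 prompt tokens, width 4
x = rand_matrix(n, d, rng)       # row i is token x_i
w_q, w_k, w_v = (rand_matrix(d, d, rng) for _ in range(3))
q, k, v = project_prompt(x, w_q, w_k, w_v)

# Row i of Q depends only on x_i: projecting that token alone gives the same row.
for i in range(n):
    assert matmul([x[i]], w_q)[0] == q[i]
```

On a GPU the batched form is a single large matrix multiplication over all n prompt tokens, which is why this stage parallelizes well.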
2.1.2 Decoding Stage
KV-cache becomes truly effective in the decoding phase:
- Sequential Generation: The model generates tokens one by one because each new token depends on the previous output.
- Step 1: Start with a beginning-of-sequence token (e.g., <BOS>) and use the KV-cache to generate $y_1$.
- Step 2: Use $y_1$ along with the KV-cache to generate $y_2$.
- Step 3: Continue sequentially with each generated token updating the cache.
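The decoding loop can be sketched as a toy single-head attention layer in pure Python. This is an illustration under simplifying assumptions, not any framework's implementation: random weights, no layer stack or MLP, and the attention output fed straight back as the next input in place of sampling and embedding a token (`decode`, `attend`, and `matvec` are hypothetical helpers).

```python
import math
import random

def matvec(x, w):
    """Row vector x of length d times a (d, d) matrix w."""
    return [sum(xi * w[i][j] for i, xi in enumerate(x)) for j in range(len(w[0]))]

def attend(q, keys, values):
    """Scaled dot-product attention of one query over all cached keys/values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, key)) for key in keys]
    peak = max(scores)  # subtract the max for a numerically stable softmax
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return [sum(wt * val[j] for wt, val in zip(weights, values)) / total
            for j in range(len(values[0]))]

def decode(prompt, w_q, w_k, w_v, steps):
    """Toy single-head decoding loop with a KV-cache."""
    k_cache = [matvec(x, w_k) for x in prompt]  # filled once in the prompt stage
    v_cache = [matvec(x, w_v) for x in prompt]
    x, outputs = prompt[-1], []
    for _ in range(steps):
        # Only the newest token needs a fresh query; older K/V rows are reused.
        out = attend(matvec(x, w_q), k_cache, v_cache)
        outputs.append(out)
        x = out  # hypothetical stand-in for "sample a token and embed it"
        k_cache.append(matvec(x, w_k))  # the cache grows by one row per step
        v_cache.append(matvec(x, w_v))
    return outputs, k_cache

rng = random.Random(0)
d = 4
w_q, w_k, w_v = ([[rng.uniform(-1, 1) for _ in range(d)] for _ in range(d)]
                 for _ in range(3))
prompt = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(5)]
outputs, k_cache = decode(prompt, w_q, w_k, w_v, steps=3)
print(len(k_cache))  # 8: 5 prompt rows + 3 generated rows
```

Each step computes one new query and appends one K row and one V row, so earlier tokens are never re-projected; the attention itself still touches every cached row.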
2.2 Self-Attention and KV-Cache in Standard Transformers
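Even with the KV-cache optimization, which reduces space complexity to $O(n)$, the computation complexity remains $O(n^2)$, because every new token must still attend to all cached keys and values. Consider a prompt with 100 tokens:
- KV-Cache Storage: At time step $t$, the key cache stores $t$ vectors and has shape $(t, d)$, and the value cache similarly has shape $(t, d)$. Once the 100-token prompt has been processed, both caches hold 100 rows, and each generated token adds one more.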
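A quick count shows why the cache removes recomputation but not the quadratic work. Assuming one head in one layer, where token $t$ attends over the $t$ rows in the cache, the number of query-key dot products over $n$ tokens is $n(n+1)/2$, while the cache itself grows only linearly. The head dimension 128 and the function names are illustrative choices.

```python
def qk_dot_products(n_tokens):
    """Query-key dot products when token t (1-based) attends to t cached keys."""
    return sum(range(1, n_tokens + 1))  # = n(n + 1) / 2

def cache_values(n_tokens, d):
    """Numbers held in the key and value caches after n tokens."""
    return 2 * n_tokens * d

for n in (100, 200, 400):
    print(n, qk_dot_products(n), cache_values(n, 128))
# 100 5050 25600
# 200 20100 51200
# 400 80200 102400
```

Doubling $n$ roughly quadruples the attention work but only doubles the cache.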
2.2.1 Computational Complexity Analysis
Using the Llama 3.2 3B model as an example:
- Model Configuration:
- Hidden size: 3072
- Number of heads h: 24
- Head dimension d: 128
- Hence, hd = 3072.
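Self-Attention (SA) Computation (counting multiply-accumulate operations for a sequence of n tokens):
- Projection: For each token, compute Q, K, V via linear transformations with (hd, hd) weight matrices. Total cost for 3 projections: $3n(hd)^2$
- Attention Operation: For each head, multiply Q (of shape (n, d)) by $K^\top$ (of shape (d, n)) to get an (n, n) attention matrix, then multiply it by V (of shape (n, d)). Each product costs $n^2 d$, so for h heads this costs: $2n^2 hd$
- Final Projection: An output of shape (n, hd) multiplied by $W_O$ of shape (hd, hd) costs: $n(hd)^2$
Overall, the self-attention cost is:
$$C_{\text{SA}} = 4n(hd)^2 + 2n^2 hd \tag{1}$$
Feed-Forward Network (FFN) Computation:
For two fully connected layers, $hd \to 4hd \to hd$:
$$C_{\text{FFN}} = 8n(hd)^2 \tag{2}$$
Llama's SwiGLU FFN uses three matrices with intermediate size 8192, and $3 \cdot 3072 \cdot 8192 = 8 \cdot 3072^2$, so the same count applies.
The self-attention cost dominates when:
$$C_{\text{SA}} > C_{\text{FFN}} \iff 2n^2 hd > 4n(hd)^2 \iff n > 2hd$$
For Llama, since the hidden size hd = 3072, the self-attention computation overtakes the FFN cost when n > 6144. Furthermore, the total computation in a Transformer layer is:
$$C_{\text{layer}} = C_{\text{SA}} + C_{\text{FFN}} = 12n(hd)^2 + 2n^2 hd \tag{3}$$
The quadratic term dominates when:
$$2n^2 hd > 12n(hd)^2 \iff n > 6hd$$
For Llama, this means n > 6 × 3072 = 18432; hence, only when the sequence length approaches roughly 18,000 tokens does the quadratic complexity become fully apparent.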
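A small Python sketch of this cost model makes the crossover points easy to check. It assumes the multiply-accumulate counts used here: projections $4n(hd)^2$, attention $2n^2 hd$, and a 4x FFN at $8n(hd)^2$. The function names are illustrative.

```python
def sa_cost(n, hd):
    """Four (hd, hd) projections plus QK^T and AV over n tokens, in MACs."""
    return 4 * n * hd**2 + 2 * n**2 * hd

def ffn_cost(n, hd):
    """Two dense layers hd -> 4hd -> hd, in MACs."""
    return 8 * n * hd**2

def quadratic_share(n, hd):
    """Fraction of one layer's total cost that comes from the n^2 term."""
    return 2 * n**2 * hd / (sa_cost(n, hd) + ffn_cost(n, hd))

HD = 24 * 128  # Llama 3.2 3B: 24 heads x head dimension 128 = 3072

for n in (1024, 2 * HD, 6 * HD, 65536):
    print(f"n={n}: SA/FFN = {sa_cost(n, HD) / ffn_cost(n, HD):.2f}, "
          f"quadratic share = {quadratic_share(n, HD):.0%}")
# n=1024: SA/FFN = 0.58, quadratic share = 5%
# n=6144: SA/FFN = 1.00, quadratic share = 25%
# n=18432: SA/FFN = 2.00, quadratic share = 50%
# n=65536: SA/FFN = 5.83, quadratic share = 78%
```

Because both thresholds are multiples of hd, switching to FLOPs (doubling every term) leaves them unchanged.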
2.2.2 Space Complexity
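Transformers typically use a KV-cache to reduce space complexity to O(n). For each layer, the KV-cache holds a key matrix and a value matrix of shape (n, d), where d here denotes the hidden size (3072 for Llama). For instance, for the Llama 3.2 3B model with 3000 tokens:
- Weight Memory: Depends on the total parameter count P and the quantization method.
- For FP16: $M_W = 2P$ bytes
- For 4-bit quantization (e.g., q4f16): $M_W \approx 0.5P$ bytes (ignoring quantization scales)
For $P \approx 3 \times 10^9$ parameters, the 4-bit weight memory is approximately 1.5 GB.
- KV-Cache Memory: $M_{KV} = 2 \cdot n \cdot d \cdot L \cdot b$, where L is the number of layers and b is the number of bytes per stored value.
For n = 3000, d = 3072, L = 28, and b = 4, this is about 2.06 × 10^9 bytes, or 1.92 GiB; an FP16 cache (b = 2) takes half of that. This estimate treats keys and values as full hidden-size vectors; Llama 3.2 3B actually uses grouped-query attention with 8 key/value heads, which cuts the real cache to one third.
- Total RAM Usage: Ignoring transient activations, the total memory is roughly 1.5 GB + 1.92 GB ≈ 3.4 GB.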
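The arithmetic can be reproduced with a short helper, assuming 3 × 10^9 parameters at 4 bits, 28 layers, a width of 3072, and 3000 tokens. Using 4 bytes per cached value is what reproduces the 1.92 figure; the FP16 variant halves it, and the last variant assumes the grouped-query layout of Llama 3.2 3B (8 key/value heads of dimension 128), which the simple estimate ignores. Function names are illustrative.

```python
GIB = 2**30

def weight_bytes(n_params, bits_per_weight):
    """Raw weight storage; ignores quantization scales and zero points."""
    return n_params * bits_per_weight / 8

def kv_cache_bytes(n_tokens, width, n_layers, bytes_per_value):
    """One key and one value vector of `width` numbers per token per layer."""
    return 2 * n_tokens * width * n_layers * bytes_per_value

weights = weight_bytes(3e9, 4)                 # 4-bit weights
kv_fp32 = kv_cache_bytes(3000, 3072, 28, 4)    # 4-byte values match ~1.92
kv_fp16 = kv_cache_bytes(3000, 3072, 28, 2)
kv_gqa = kv_cache_bytes(3000, 8 * 128, 28, 2)  # 8 KV heads x 128 dims, FP16

print(f"weights: {weights / 1e9:.2f} GB")
print(f"KV cache: {kv_fp32 / GIB:.2f} GiB (FP32), {kv_fp16 / GIB:.2f} GiB (FP16), "
      f"{kv_gqa / GIB:.2f} GiB (FP16 + GQA)")
# weights: 1.50 GB
# KV cache: 1.92 GiB (FP32), 0.96 GiB (FP16), 0.32 GiB (FP16 + GQA)
```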
3. The RWKV Architecture
RWKV aims to eliminate the quadratic complexity of self-attention by decomposing the architecture into two parts:
- Time Mixing: Functions similarly to the self-attention layer.
- Channel Mixing: Plays the role of the FFN layer.
3.1 Data Flow in RWKV
- Input Processing: Input tokens are converted into embeddings and processed through time mixing and channel mixing layers, with the state updated at each time step.
- State Propagation: The state (which compresses historical context) is passed forward with each new token.
- Context Storage: The context is stored in a fixed-size state independent of the sequence length.
- Decoding: Beginning with a <BOS> token, new tokens are generated sequentially based on the final state.
Note: Detailed diagrams and further elaboration are available in the referenced documentation.
3.2 Time Mixing and Channel Mixing Details
Time Mixing:
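For each token $x_t$ (with shape (1, d)):
- The initial state is of shape (1, d).
- Weight matrices $W_r$, $W_k$, $W_v$, and the output matrix $W_o$ are of shape (d, d).
Following the RWKV-4 formulation (token-shift interpolation with $x_{t-1}$ omitted for brevity), the time mixing computation can be written recursively, with $a_0 = b_0 = 0$, a learned per-channel decay $w$, and a bonus $u$ for the current token:
$$r_t = x_t W_r, \quad k_t = x_t W_k, \quad v_t = x_t W_v$$
$$\mathrm{wkv}_t = \frac{a_{t-1} + e^{u + k_t} \odot v_t}{b_{t-1} + e^{u + k_t}}, \qquad o_t = \left(\sigma(r_t) \odot \mathrm{wkv}_t\right) W_o$$
$$a_t = e^{-w} \odot a_{t-1} + e^{k_t} \odot v_t, \qquad b_t = e^{-w} \odot b_{t-1} + e^{k_t}$$
Channel Mixing:
Similarly, channel mixing follows (again omitting token shift):
$$r'_t = x_t W'_r, \quad k'_t = x_t W'_k, \quad o'_t = \sigma(r'_t) \odot \left(\max(k'_t, 0)^2 \, W'_v\right)$$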
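A recurrence like this is easiest to trust when it is checked against the equivalent sum over all past tokens. The sketch below follows the RWKV-4 WKV formulation for a single channel. Real implementations run it elementwise over all d channels and add a running-max trick for numerical stability; both are omitted here. `wkv_recurrent` and `wkv_explicit` are illustrative names.

```python
import math

def wkv_recurrent(ks, vs, w, u):
    """RWKV-4 WKV for one channel, carrying only the state (a, b) between steps."""
    a, b = 0.0, 0.0
    out = []
    for k, v in zip(ks, vs):
        bonus = math.exp(u + k)           # extra weight for the current token
        out.append((a + bonus * v) / (b + bonus))
        decay = math.exp(-w)
        a = decay * a + math.exp(k) * v   # decayed sum of past e^k * v
        b = decay * b + math.exp(k)       # decayed sum of past e^k
    return out

def wkv_explicit(ks, vs, w, u):
    """Same quantity as an explicit sum over all past tokens (O(t) work per step)."""
    out = []
    for t in range(len(ks)):
        num = math.exp(u + ks[t]) * vs[t]
        den = math.exp(u + ks[t])
        for i in range(t):
            weight = math.exp(-(t - 1 - i) * w + ks[i])
            num += weight * vs[i]
            den += weight
        out.append(num / den)
    return out

ks = [0.1, -0.3, 0.5, 0.2, -0.1]
vs = [1.0, 2.0, -1.0, 0.5, 0.0]
rec = wkv_recurrent(ks, vs, w=0.7, u=0.3)
ref = wkv_explicit(ks, vs, w=0.7, u=0.3)
print(all(math.isclose(a, b) for a, b in zip(rec, ref)))  # True
```

The recurrent form does constant work per token, whereas the explicit form rescans the whole history.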
3.3 RWKV Inference Process
3.3.1 Processing the Prompt
- Sequential Processing: The RWKV model processes each prompt token recursively, updating its state at every time step.
- Final State Storage: Once the prompt is fully processed, the final state encapsulates the entire context.
3.3.2 Generating New Tokens
- Adding an End-of-Sequence Token: An explicit <EOS> token is appended to trigger the generation process.
- Recursive Generation: The model uses the last state and <EOS> as input to generate a new token. The state is updated with each generated token to maintain context continuity.
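The RWKV model effectively compresses historical information into a single fixed-size state whose size does not depend on the number of tokens. Its time complexity is linear in the sequence length, O(nd^2) once the (d, d) weight projections are counted, while the space complexity is O(d), independent of n.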
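The prompt-then-generate flow can be sketched with a toy recurrent step in the RWKV-4 WKV style. The simplifications are all hypothetical: no weight projections (each input vector serves directly as both key and value), random decay and bonus parameters, and outputs fed back as the next input instead of sampled tokens. The point is only that a fixed-size state is carried forward.

```python
import math
import random

def time_mix_step(state, x, w, u):
    """One toy WKV step over d channels; returns (output, new_state).

    state = (a, b): two length-d lists, the only memory of past tokens.
    Hypothetical simplification: x is used directly as both k and v.
    """
    a, b = state
    out, new_a, new_b = [], [], []
    for j, xj in enumerate(x):
        bonus = math.exp(u[j] + xj)
        out.append((a[j] + bonus * xj) / (b[j] + bonus))
        decay = math.exp(-w[j])
        new_a.append(decay * a[j] + math.exp(xj) * xj)
        new_b.append(decay * b[j] + math.exp(xj))
    return out, (new_a, new_b)

def run(tokens, state, w, u):
    """Feed tokens one at a time, carrying the state forward."""
    out = None
    for x in tokens:
        out, state = time_mix_step(state, x, w, u)
    return out, state

rng = random.Random(0)
d = 4
w = [rng.uniform(0.1, 1.0) for _ in range(d)]   # per-channel decay
u = [rng.uniform(-0.5, 0.5) for _ in range(d)]  # per-channel bonus
prompt = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(6)]
empty = ([0.0] * d, [0.0] * d)

# Prompt stage: process tokens sequentially; the final state summarizes them all.
out, state = run(prompt, empty, w, u)

# Generation: feed each output back in (stand-in for sampling + embedding).
generated = []
for _ in range(3):
    out, state = time_mix_step(state, out, w, u)
    generated.append(out)

print(len(state[0]), len(state[1]))  # 4 4: same size after 6 + 3 tokens
```

Because the state is the only thing carried between steps, splitting the prompt across two `run` calls (passing the intermediate state along) yields exactly the same final state as a single call.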
3.4 Memory Usage in RWKV
Taking the 3B RWKV q4_0 model as an example:
- Hidden Size: 7680
- Layer Count: 24
- Precision: FP16 (2 bytes per value)
- Per-Layer State: 2 state vectors (for time mixing and channel mixing)
Memory calculations:
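- Weight Memory: With 4-bit weights, $M_W \approx 0.5P$ bytes, i.e. about 1.5 GB for $P \approx 3 \times 10^9$ parameters (ignoring quantization scales).
- State Memory: $M_{\text{state}} = L \times 2 \times d \times 2\ \text{bytes} = 24 \times 2 \times 7680 \times 2 = 737{,}280$ bytes, about 0.7 MB regardless of context length.
- Activation Memory: Transient per-token buffers, which are not counted toward the steady-state total.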
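The state figure is small enough to check by hand. This sketch assumes the configuration listed above (24 layers, hidden size 7680, 2 vectors per layer, 2 bytes per value) and 3 × 10^9 parameters at 4 bits for the weights. The helper names are illustrative.

```python
def rwkv_state_bytes(n_layers, hidden, vectors_per_layer=2, bytes_per_value=2):
    """Fixed recurrent state: a few length-`hidden` vectors per layer."""
    return n_layers * vectors_per_layer * hidden * bytes_per_value

def weight_bytes(n_params, bits_per_weight=4):
    """Raw 4-bit weight storage; ignores per-block quantization scales."""
    return n_params * bits_per_weight / 8

state = rwkv_state_bytes(24, 7680)
print(state, state / 2**20)     # 737280 0.703125
print(weight_bytes(3e9) / 1e9)  # 1.5
```

Under these assumptions the state is more than three orders of magnitude smaller than the weights.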
3.5 A Pitfall in the MLC Framework’s RWKV5 Implementation
In the current implementation within the MLC framework, the RWKV5 series maintains a huge buffer rather than a single state. This buffer holds N states, where N is determined by the parameter `max_history_size`. Even if one state is small, a large N (e.g., 100) multiplies it N times and can lead to extremely high memory usage.
4. Comparison of RAM Usage: Transformers vs. RWKV
Memory consumption primarily comes from:
- Weight Memory (Model Parameters): Determined by the total parameter count and quantization method.
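- For q4 fp16 quantization: $M_W \approx 0.5P$ bytes, where P is the parameter count.
- KV-Cache (for Transformers) vs. State Memory (for RWKV):
- Transformers: $M_{KV} = 2 \cdot n \cdot d \cdot L \cdot b$, which grows linearly with n.
- RWKV: $M_{\text{state}} = 2 \cdot d \cdot L \cdot b$, independent of n.
where n is the number of input tokens, d is the hidden dimension, L is the number of layers, and b is the number of bytes per stored value.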
- Activation Memory: This is temporary during inference and is generally not counted toward the final memory usage.
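To see how the two terms diverge as the context grows, this sketch tabulates both formulas for a few prompt lengths. It assumes FP16 storage, full-width keys and values for the Llama 3.2 3B shape (d = 3072, L = 28), and the RWKV shape from Section 3.4 (d = 7680, L = 24).

```python
def kv_cache_bytes(n_tokens, hidden, n_layers, bytes_per_value=2):
    """Transformer: one key and one value vector per token per layer."""
    return 2 * n_tokens * hidden * n_layers * bytes_per_value

def rwkv_state_bytes(hidden, n_layers, bytes_per_value=2):
    """RWKV: two state vectors per layer, independent of n."""
    return 2 * hidden * n_layers * bytes_per_value

MIB = 2**20
for n in (1_000, 3_000, 10_000, 32_000):
    kv = kv_cache_bytes(n, 3072, 28)  # Llama 3.2 3B shape, full-width K/V
    st = rwkv_state_bytes(7680, 24)   # RWKV 3B shape from Section 3.4
    print(f"n={n:>6}: KV-cache {kv / MIB:9.1f} MiB | RWKV state {st / MIB:.2f} MiB")
```

Doubling n doubles the KV-cache, while the RWKV state stays at about 0.7 MiB.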
- Author: Sylvia
- Link: https://vibesylvia.top/article/1a3beda9-55cb-80d1-9505-dd6f190bdd73
- Notice: This article is licensed under CC BY-NC-SA 4.0. Please credit the source when reposting.