>
大语言模型(LLM)的 Context Window(上下文窗口)指的是模型在一次推理过程中能够"看到"的最大 token 数量。这个数字直接决定了模型能处理多长的输入、维持多长的对话记忆、以及在 RAG 中能注入多少上下文。
以主流模型为例:
| 模型 | Context Window | 约等于中文字数 |
|---|---|---|
| GPT-3.5-turbo | 16K tokens | 约 1.2 万字 |
| GPT-4o | 128K tokens | 约 10 万字 |
| Claude 3.5 Sonnet | 200K tokens | 约 15 万字 |
| Claude 3 Opus | 200K tokens | 约 15 万字 |
| Gemini 1.5 Pro | 1M tokens | 约 75 万字 |
| Llama 3.1 8B | 128K tokens | 约 10 万字 |
| Qwen2 72B | 128K tokens | 约 10 万字 |
Transformer 的自注意力机制(Self-Attention)是 Context Window 的根本约束。标准自注意力的计算复杂度是 O(n²),其中 n 是序列长度:
# 标准 Self-Attention 的计算复杂度分析
# 对于长度为 n 的序列,注意力矩阵是 n×n
# 每次前向传播需要计算:
# Q @ K^T → n×d @ d×n → n² 乘法
# Softmax → n² 归一化
# @ V → n×n @ n×d → n²·d 乘法
# 实际数字示例:n=4096 tokens
# 注意力矩阵大小: 4096 × 4096 = 16,777,216 元素
# float16 下占用: 16,777,216 × 2 bytes ≈ 32 MB(仅一组注意力头)
# 128K tokens 的注意力矩阵:
# 131,072 × 131,072 = 17,179,869,184 元素
# float16 下占用: 约 32 GB(仅一组头)
# 如果有 96 个注意力头,内存需求直接爆表
自回归生成(Autoregressive Generation)有一个关键特性:每次生成新 token 时,只需要关注"所有历史 token + 新 token",而不是从头计算。这就是 KV Cache 的核心思想。
# KV Cache 工作原理的 Python 伪代码
class KVCache:
"""
KV Cache 的核心逻辑:
每生成一个新 token,历史 K 和 V 矩阵保留,
只计算新 token 的 K 和 V,然后拼接。
这样避免了对历史 token 重复计算注意力。
"""
def __init__(self, max_seq_len: int, n_heads: int, head_dim: int, dtype=np.float16):
self.max_seq_len = max_seq_len
self.n_heads = n_heads
self.head_dim = head_dim
# 预分配缓存空间(PyTorch 实现原理)
# 第一次调用时分配最大长度,之后只写入不重新分配
self.k_cache = torch.zeros(
(n_heads, max_seq_len, head_dim),
dtype=dtype
) # shape: [n_heads, max_seq_len, head_dim]
self.v_cache = torch.zeros(
(n_heads, max_seq_len, head_dim),
dtype=dtype
)
self.current_len = 0 # 当前填充到的位置
def update(self, new_k: torch.Tensor, new_v: torch.Tensor, pos: int):
"""
将新计算出的 K、V 写入缓存的指定位置
Args:
new_k: 新 token 的 K 向量,shape [n_heads, 1, head_dim]
new_v: 新 token 的 V 向量,shape [n_heads, 1, head_dim]
pos: 写入位置(当前序列长度)
"""
self.k_cache[:, pos:pos+1, :] = new_k
self.v_cache[:, pos:pos+1, :] = new_v
self.current_len += 1
def get_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
"""获取当前所有历史 K、V 用于注意力计算"""
return (
self.k_cache[:, :self.current_len, :],
self.v_cache[:, :self.current_len, :]
)
def attention_with_cache(query: torch.Tensor, kv_cache: KVCache) -> torch.Tensor:
"""
带 KV Cache 的注意力计算
与标准注意力的区别:
- 标准:Q @ K^T 需要 O(n²) 与全部历史计算
- Cache:Q 只与缓存中的 K 计算,新 token 的 K 不需要与历史 K 重复计算
但实际上attention计算本身仍然是 O(n²)——
KV Cache 节省的是"重复计算 Q 相对于历史 K 的部分",
而不是减少注意力矩阵的规模。
真正节省的是:
1. 不需要重新计算历史 token 的 K(已缓存)
2. 不需要重新计算历史 token 的激活值(已缓存)
但 K 和 V 之间的注意力矩阵仍然需要 O(n²) 计算。
"""
k_cached, v_cached = kv_cache.get_cache()
# query: [batch, n_heads, 1, head_dim](当前 token)
# k_cached: [n_heads, seq_len, head_dim]
# 调整维度以便计算
# query: [n_heads, 1, head_dim](单token查询)
q = query.squeeze(0) # [n_heads, head_dim]
# 计算注意力分数
# q [n_heads, head_dim] @ k_cached.T [n_heads, head_dim, seq_len]
# 结果: [n_heads, seq_len]
scores = torch.einsum('hd,hdn->hn', q, k_cached) / (q.shape[-1] ** 0.5)
attn_weights = F.softmax(scores, dim=-1)
# 加权求和
# attn_weights [n_heads, seq_len] @ v_cached [n_heads, seq_len, head_dim]
# 结果: [n_heads, head_dim]
output = torch.einsum('hn,hnd->hd', attn_weights, v_cached)
return output.unsqueeze(0)
Context Window 的另一个约束来自位置编码(Positional Encoding)。大多数模型使用 Rotary Position Embedding(RoPE),它通过旋转机制将位置信息注入到 Q 和 K 向量中。
# RoPE(旋转位置编码)的核心数学原理
# 标准 Transformer 用绝对位置编码:PE(pos) 直接加到 token embedding 上
# RoPE 用旋转矩阵对 Q 和 K 进行编码,让注意力天然包含相对位置信息
import torch
import math
def rotate_half(x: torch.Tensor) -> torch.Tensor:
"""
RoPE 旋转函数:将向量后半部分旋转到负方向
数学原理:
对于一个 d 维向量,前半部分与后半部分配对:
[x_0, x_1, ..., x_{d/2-1}, x_{d/2}, ..., x_{d-1}]
旋转后:
[-x_{d/2}, -x_{d/2+1}, ..., -x_{d-1}, x_0, x_1, ..., x_{d/2-1}]
等价于乘以下面的旋转矩阵(d维):
[ cosθ -sinθ | 0 0 ]
[ sinθ cosθ | 0 0 ]
[ 0 0 | -cosθ sinθ]
[ 0 0 | -sinθ -cosθ]
"""
x1 = x[..., :x.shape[-1] // 2] # 前半部分
x2 = x[..., x.shape[-1] // 2:] # 后半部分
return torch.cat([-x2, x1], dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor, # query 向量
k: torch.Tensor, # key 向量
freqs_cos: torch.Tensor, # 预计算的 cos 值
freqs_sin: torch.Tensor # 预计算的 sin 值
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
对 Q 和 K 向量应用 RoPE 旋转
核心公式:
RoPE(q, pos) = q * cos(θ_pos) + rotate_half(q) * sin(θ_pos)
RoPE(k, pos) = k * cos(θ_pos) + rotate_half(k) * sin(θ_pos)
旋转后,Q 和 K 的点积自然包含相对位置信息:
RoPE(q, m) · RoPE(k, n) 的点积只与 (m-n) 有关
"""
# 将频率分为余弦和正弦部分
# 旋转公式:x_rot = x * cos + rotate_half(x) * sin
q_embed = (q * freqs_cos) + (rotate_half(q) * freqs_sin)
k_embed = (k * freqs_cos) + (rotate_half(k) * freqs_sin)
return q_embed, k_embed
def build_rotary_frequencies(
seq_len: int,
dim: int, # head_dim,通常是 128
base: float = 10000.0,
device: str = 'cuda'
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
构建 RoPE 频率张量(预计算,避免重复计算)
θ_i = base^{-2i/dim},i = 0, 1, ..., dim/2-1
这个指数衰减的频率设计保证了:
- 低频维度(i 小,θ 大)→ 长距离依赖
- 高频维度(i 大,θ 小)→ 短距离依赖
"""
# 计算每个维度的频率倒数的平方根
# θ_i = base^{-2i/dim}
# freqs = [base^{0}, base^{-2/dim}, base^{-4/dim}, ...]
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
# 生成位置索引 [0, 1, 2, ..., seq_len-1]
t = torch.arange(seq_len, device=device)
# 计算每个位置 × 每个频率的外积
# t[:, None] * freqs[None, :] → [seq_len, dim/2]
# 形状:(seq_len, dim/2)
freqs = torch.outer(t, freqs)
# 转换为极坐标形式:cos(θ) 和 sin(θ)
# shape: [seq_len, dim/2]
freqs_cos = freqs.cos()
freqs_sin = freqs.sin()
return freqs_cos, freqs_sin
# 完整的多头注意力前向传播(带 RoPE + KV Cache)
def multihead_attention_with_rope_and_cache(
x: torch.Tensor, # 输入 token embedding [batch, seq_len, d_model]
cached_k: torch.Tensor, # 历史 K 缓存 [batch, n_heads, seq_len, head_dim]
cached_v: torch.Tensor, # 历史 V 缓存 [batch, n_heads, seq_len, head_dim]
current_pos: int, # 当前处理的起始位置
w_q, w_k, w_v, w_o: torch.Tensor, # 权重矩阵
freqs_cos: torch.Tensor, # RoPE cos 频率
freqs_sin: torch.Tensor # RoPE sin 频率
):
"""
带 RoPE 和 KV Cache 的多头注意力实现
"""
batch_size, seq_len, d_model = x.shape
n_heads = w_q.shape[0] # 从权重形状推断
head_dim = w_q.shape[-1]
# 投影到 Q、K、V
q = x @ w_q # [batch, seq_len, n_heads*head_dim]
k = x @ w_k
v = x @ w_v
# 重塑为多头格式
# [batch, seq_len, n_heads, head_dim] → [batch, n_heads, seq_len, head_dim]
q = q.view(batch_size, seq_len, n_heads, head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, n_heads, head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, n_heads, head_dim).transpose(1, 2)
# 提取当前位置的频率(用于 RoPE)
# freqs_cos/sin shape: [seq_len, head_dim/2]
pos_cos = freqs_cos[current_pos:current_pos+seq_len]
pos_sin = freqs_sin[current_pos:current_pos+seq_len]
# 应用 RoPE(只对 Q 和 K,不对 V)
# 需要将频率 reshape 为匹配 head 的维度
# pos_cos: [seq_len, head_dim/2] → [1, 1, seq_len, head_dim/2]
q, k = apply_rotary_pos_emb(
q, k,
pos_cos.unsqueeze(0).unsqueeze(0),
pos_sin.unsqueeze(0).unsqueeze(0)
)
# 更新 KV Cache
new_k = k
new_v = v
# 将新的 K、V 追加到缓存
# cached_k[:, :, current_pos:current_pos+seq_len, :] = k
# 计算注意力
# 如果有历史缓存,需要将当前 K 与历史 K 拼接
if current_pos > 0:
# k_all: [batch, n_heads, current_pos+seq_len, head_dim]
# 需要从 cached_k 中取 [0:current_pos] 部分
k_all = torch.cat([cached_k[:, :, :current_pos, :], k], dim=2)
v_all = torch.cat([cached_v[:, :, :current_pos, :], v], dim=2)
else:
k_all = k
v_all = v
# 注意力计算:Q @ K^T / sqrt(d_k)
scores = torch.matmul(q, k_all.transpose(-2, -1)) / math.sqrt(head_dim)
attn_weights = F.softmax(scores, dim=-1)
# @ V 并输出
output = torch.matmul(attn_weights, v_all)
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
output = output @ w_o
return output, new_k, new_v
为解决 O(n²) 问题,研究者提出了 Sliding Window Attention(滑动窗口注意力)。每个 token 只关注其周围 W 个 token,而不关注全序列:
# Sliding Window Attention 的核心逻辑
"""
标准 Attention:每个位置可以看到所有历史位置
O(n²) 复杂度
Sliding Window Attention(W=4):
position 0: attend [0]
position 1: attend [0,1]
position 2: attend [0,1,2]
position 3: attend [0,1,2,3]
position 4: attend [1,2,3,4] ← 窗口滑动
position 5: attend [2,3,4,5] ← 窗口滑动
...
复杂度降为 O(n·W),W 是窗口大小(常数)
"""
def sliding_window_attention(
q: torch.Tensor, # [batch, n_heads, seq_len, head_dim]
k: torch.Tensor, # [batch, n_heads, seq_len, head_dim]
v: torch.Tensor, # [batch, n_heads, seq_len, head_dim]
window_size: int = 4
) -> torch.Tensor:
"""
滑动窗口注意力实现
Args:
q, k, v: 多头注意力的 Q、K、V 向量
window_size: 每个位置前后各看多少个 token
Returns:
加权输出向量
"""
seq_len = q.shape[2]
head_dim = q.shape[-1]
# 为每个位置创建掩码,使得窗口外的位置得分为 -inf
# 掩码形状:[seq_len, seq_len]
mask = torch.triu(torch.ones(seq_len, seq_len, device=q.device) * float('-inf'), diagonal=-window_size)
# 设置下三角(window_size 以内)为 0(允许注意)
mask = torch.tril(mask, diagonal=0) + torch.triu(torch.zeros(seq_len, seq_len, device=q.device), diagonal=window_size)
# 由于 tril 起点不对,重新构造正确掩码
# 更简洁的做法:
mask = torch.zeros(seq_len, seq_len, device=q.device)
for i in range(seq_len):
start = max(0, i - window_size)
mask[i, start:i+1] = 0
mask[i, i+1:] = float('-inf')
# 计算注意力分数
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
# 应用窗口掩码
scores = scores + mask.unsqueeze(0).unsqueeze(0) # broadcast to [batch, n_heads, seq_len, seq_len]
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, v)
return output
# 分层稀疏注意力(Longformer 风格)
# 远距离用大窗口,近距离用小窗口
def hierarchical_sparse_attention(
q, k, v,
local_window: int = 256,
global_window: int = 512
):
"""
分层稀疏注意力:
1. 每个 token 看局部 local_window 范围(精细)
2. 每个 token 额外看全局 global_window 范围的汇总信息(粗粒度)
复杂度:O(n·local_window + n·global_window) ≈ O(n)
"""
seq_len = q.shape[2]
# 局部注意力:O(n·local_window)
local_out = sliding_window_attention(q, k, v, window_size=local_window)
# 全局注意力:每隔 stride 个 token 汇聚信息
# 这部分用更小的粒度处理
stride = global_window // 4
global_k = k[:, :, ::stride, :] # 稀疏采样 key
global_v = v[:, :, ::stride, :]
# 计算与全局 key 的注意力
global_scores = torch.matmul(q, global_k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
global_attn = F.softmax(global_scores, dim=-1)
global_out = torch.matmul(global_attn, global_v)
# 融合局部和全局
return local_out + 0.5 * global_out
FlashAttention 是目前最主流的高效注意力实现,它通过分块计算和算子融合将显存占用从 O(n²) 降到 O(n),同时保持数值精度。
# FlashAttention 核心思想:分块 + 在线 softmax
"""
标准 attention 的问题:
1. 需要计算完整的 n×n 注意力矩阵,显存 O(n²)
2. 需要将矩阵分成多个矩阵相乘,HBM 读写开销巨大
FlashAttention 解决方案:
将注意力矩阵按块(block)分块计算,
在累加过程中实时更新 softmax 统计量(max 和 sum),
避免一次性加载完整矩阵。
数学原理:
Softmax(x)_i = exp(x_i - max(x)) / sum_j(exp(x_j - max(x)))
当分块计算时:
已知前 i 块的 max 和 sum,
处理第 i+1 块时:
new_max = max(old_max, block_max)
new_sum = old_sum * exp(old_max - new_max) + block_sum * exp(block_max - new_max)
最终每个元素的 softmax = exp(x_i - global_max) * exp(x_i_local - block_max) / new_sum
"""
# 伪代码展示核心逻辑(实际 GPU 实现请用 Triton 或 CUDA)
def flash_attention_fused(
q: torch.Tensor, # [batch, n_heads, seq_len, head_dim]
k: torch.Tensor, # [batch, n_heads, seq_len, head_dim]
v: torch.Tensor, # [batch, n_heads, seq_len, head_dim]
block_size: int = 128
):
batch, n_heads, seq_len, head_dim = q.shape
n_blocks = (seq_len + block_size - 1) // block_size
# 初始化输出和 softmax 统计量
output = torch.zeros_like(q)
lse = torch.zeros((batch, n_heads, seq_len), device=q.device) # log-sum-exp
for i in range(n_blocks):
# 加载当前块
q_block = q[:, :, i*block_size:(i+1)*block_size, :] # [batch, n_heads, block, head_dim]
# 计算与所有 K 块的注意力(分块累加)
for j in range(n_blocks):
k_block = k[:, :, j*block_size:(j+1)*block_size, :]
v_block = v[:, :, j*block_size:(j+1)*block_size, :]
# 局部 softmax 计算
# scores: [batch, n_heads, block, block]
scores = torch.matmul(q_block, k_block.transpose(-2, -1)) / math.sqrt(head_dim)
# 在线 softmax 融合
if i == 0 and j == 0:
# 第一个块:直接计算
block_max = scores.amax(dim=-1, keepdim=True)
block_sum = torch.exp(scores - block_max).sum(dim=-1, keepdim=True)
# 更新全局统计量
cur_max = block_max
cur_sum = block_sum
else:
# 后续块:融合到全局
block_max = scores.amax(dim=-1, keepdim=True)
new_max = torch.maximum(cur_max, block_max)
# 重新对齐并累加
block_sum = torch.exp(scores - new_max).sum(dim=-1, keepdim=True)
cur_sum = cur_sum * torch.exp(cur_max - new_max) + block_sum
cur_max = new_max
# 计算最终输出
# output_block = exp(q_block @ k^T - global_max) / sum / v^T
# 简化处理(实际实现更复杂)
attn_block = torch.matmul(
torch.exp(scores - cur_max.unsqueeze(-1)), # 已对齐的权重
v_block
)
output[:, :, i*block_size:(i+1)*block_size, :] = attn_block
return output
当生成长文本时,模型需要知道何时停止。主流模型使用以下停止机制:
| 停止机制 | 说明 | 示例 |
|---|---|---|
| EOS Token | 遇到 End-of-Sequence 标记停止 | Llama 用 <|eot_id|> |
| Max Tokens | 到达预设最大生成长度停止 | API 调用时设置 max_tokens |
| Stop Sequences | 遇到指定字符串停止 | 停止词如 "\n\n" 或 "## " |
| Logit Mask | 在解码阶段屏蔽某些 token | 屏蔽掉终止符后的 token |
class StoppingCriteria:
"""
自定义停止条件,用于生成循环中判断是否应该停止
支持多种停止条件组合:
- 遇到特定 token
- 遇到特定字符串
- 达到最大 token 数
"""
def __init__(
self,
eos_token_id: int,
max_new_tokens: int,
stop_sequences: List[str] = None,
tokenizer=None
):
self.eos_token_id = eos_token_id
self.max_new_tokens = max_new_tokens
self.stop_sequences = stop_sequences or []
self.tokenizer = tokenizer
self.generated_tokens = []
self.generated_text = ""
def __call__(self, last_token_id: int, full_sequence: List[int]) -> bool:
"""
每次生成一个 token 后调用此方法判断是否停止
Args:
last_token_id: 最新生成的 token ID
full_sequence: 目前为止的完整 token 序列
Returns:
True = 停止生成,False = 继续生成
"""
self.generated_tokens.append(last_token_id)
# 条件1:达到最大 token 数
if len(self.generated_tokens) >= self.max_new_tokens:
return True
# 条件2:遇到 EOS token
if last_token_id == self.eos_token_id:
return True
# 条件3:检查停止字符串(需要解码后判断)
if self.tokenizer is not None:
# 将最新 token 解码(增量解码)
new_token_text = self.tokenizer.decode([last_token_id])
self.generated_text += new_token_text
# 检查是否匹配任何停止序列
for stop_seq in self.stop_sequences:
if stop_seq in self.generated_text:
# 发现停止序列,返回 True
# 可以选择是否在序列处截断
return True
return False
def streaming_generation_with_stop(
model,
tokenizer,
prompt: str,
max_new_tokens: int = 512,
stop_sequences: List[str] = ["\n\n## ", "。", "---"]
):
"""
带停止条件的大模型流式生成示例
支持 Ctrl+C 中断、流式输出、任意停止序列检测
"""
# 编码输入
input_ids = tokenizer.encode(prompt, return_tensors="pt")
# 初始化停止条件
stopping = StoppingCriteria(
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=max_new_tokens,
stop_sequences=stop_sequences,
tokenizer=tokenizer
)
# 生成循环
with torch.no_grad():
generated = input_ids[0].tolist()
output_text = ""
while True:
# 前向传播(使用缓存加速)
outputs = model(torch.tensor([generated]))
next_token_logits = outputs.logits[0, -1, :]
# 采样(可替换为 greedy/top_k/top_p)
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).item()
generated.append(next_token)
# 检查停止条件
if stopping(last_token_id=next_token, full_sequence=generated):
break
# 流式输出(yield)
new_text = tokenizer.decode([next_token])
output_text += new_text
yield new_text
return output_text
class ContextWindowManager:
"""
当输入超出 Context Window 时的降级策略管理器
"""
def __init__(self, model_name: str, max_tokens: int):
self.max_tokens = max_tokens
# 保留 20% 空间给输出
self.input_max = int(max_tokens * 0.8)
def truncate_or_summarize(self, texts: List[str]) -> str:
"""
处理超出 Context Window 的文本
策略:
1. 如果总长度在限制内:直接拼接
2. 如果超出:从后往前截断(因为最近的内容更重要)
3. 如果单段就超出:使用分段 + 摘要
"""
# 估算总 token 数
total_tokens = sum(self._estimate_tokens(t) for t in texts)
if total_tokens <= self.input_max:
return "\n".join(texts)
# 从后往前截断(近期内容更重要)
result = []
remaining = self.input_max
for text in reversed(texts):
tokens = self._estimate_tokens(text)
if tokens <= remaining:
result.insert(0, text)
remaining -= tokens
else:
# 截断此文本
truncated = self._truncate_text(text, remaining)
result.insert(0, truncated)
break
return "\n".join(result)
def _estimate_tokens(self, text: str) -> int:
"""估算中文字符数(粗略:1中文≈1.5 token)"""
chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
other_chars = len(text) - chinese_chars
return int(chinese_chars * 1.5 + other_chars * 0.25)
def _truncate_text(self, text: str, max_tokens: int) -> str:
"""按 token 数截断文本"""
max_chars = int(max_tokens / 1.5) # 估算
if len(text) <= max_chars:
return text
return text[:max_chars] + "...[内容已截断]"
Context Window 是 LLM 最关键的资源之一,理解其背后的技术原理(O(n²) 注意力复杂度、KV Cache、RoPE 位置编码、FlashAttention 分块计算)可以帮助你: