>
← 返回投肯智能知识库首页

大模型 Context Window 机制深度解析:从理论到代码实现

作者:重庆投肯小刚更新日期:2026年5月阅读时长:30分钟

一、Context Window 是什么?

大语言模型(LLM)的 Context Window(上下文窗口)指的是模型在一次推理过程中能够"看到"的最大 token 数量。这个数字直接决定了模型能处理多长的输入、维持多长的对话记忆、以及在 RAG 中能注入多少上下文。

以主流模型为例:

模型Context Window约等于中文字数
GPT-3.5-turbo16K tokens约 1.2 万字
GPT-4o128K tokens约 10 万字
Claude 3.5 Sonnet200K tokens约 15 万字
Claude 3 Opus200K tokens约 15 万字
Gemini 1.5 Pro1M tokens约 75 万字
Llama 3.1 8B128K tokens约 10 万字
Qwen2 72B128K tokens约 10 万字
核心理解:Context Window 是模型在一次"思考"中能容纳的最大信息量。不是模型"记住"多少,而是模型"同时看见"多少。超出这个范围的文档,模型完全看不见。

二、Context Window 的技术原理

2.1 Transformer 的二次复杂度问题

Transformer 的自注意力机制(Self-Attention)是 Context Window 的根本约束。标准自注意力的计算复杂度是 O(n²),其中 n 是序列长度:

# 标准 Self-Attention 的计算复杂度分析

# 对于长度为 n 的序列,注意力矩阵是 n×n
# 每次前向传播需要计算:
#   Q @ K^T  → n×d @ d×n → n² 乘法
#   Softmax  → n² 归一化
#   @ V      → n×n @ n×d → n²·d 乘法

# 实际数字示例:n=4096 tokens
# 注意力矩阵大小: 4096 × 4096 = 16,777,216 元素
# float16 下占用: 16,777,216 × 2 bytes ≈ 32 MB(仅一组注意力头)

# 128K tokens 的注意力矩阵:
# 131,072 × 131,072 = 17,179,869,184 元素
# float16 下占用: 约 32 GB(仅一组头)

# 如果有 96 个注意力头,内存需求直接爆表

2.2 KV Cache:解决推理重复计算

自回归生成(Autoregressive Generation)有一个关键特性:每次生成新 token 时,只需要关注"所有历史 token + 新 token",而不是从头计算。这就是 KV Cache 的核心思想。

# KV Cache 工作原理的 Python 伪代码

class KVCache:
    """
    KV Cache 的核心逻辑:
    每生成一个新 token,历史 K 和 V 矩阵保留,
    只计算新 token 的 K 和 V,然后拼接。
    
    这样避免了对历史 token 重复计算注意力。
    """
    
    def __init__(self, max_seq_len: int, n_heads: int, head_dim: int, dtype=np.float16):
        self.max_seq_len = max_seq_len
        self.n_heads = n_heads
        self.head_dim = head_dim
        
        # 预分配缓存空间(PyTorch 实现原理)
        # 第一次调用时分配最大长度,之后只写入不重新分配
        self.k_cache = torch.zeros(
            (n_heads, max_seq_len, head_dim),
            dtype=dtype
        )  # shape: [n_heads, max_seq_len, head_dim]
        self.v_cache = torch.zeros(
            (n_heads, max_seq_len, head_dim),
            dtype=dtype
        )
        
        self.current_len = 0  # 当前填充到的位置
    
    def update(self, new_k: torch.Tensor, new_v: torch.Tensor, pos: int):
        """
        将新计算出的 K、V 写入缓存的指定位置
        
        Args:
            new_k: 新 token 的 K 向量,shape [n_heads, 1, head_dim]
            new_v: 新 token 的 V 向量,shape [n_heads, 1, head_dim]
            pos: 写入位置(当前序列长度)
        """
        self.k_cache[:, pos:pos+1, :] = new_k
        self.v_cache[:, pos:pos+1, :] = new_v
        self.current_len += 1
    
    def get_cache(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """获取当前所有历史 K、V 用于注意力计算"""
        return (
            self.k_cache[:, :self.current_len, :],
            self.v_cache[:, :self.current_len, :]
        )


def attention_with_cache(query: torch.Tensor, kv_cache: KVCache) -> torch.Tensor:
    """
    带 KV Cache 的注意力计算
    
    与标准注意力的区别:
    - 标准:Q @ K^T 需要 O(n²) 与全部历史计算
    - Cache:Q 只与缓存中的 K 计算,新 token 的 K 不需要与历史 K 重复计算
    
    但实际上attention计算本身仍然是 O(n²)——
    KV Cache 节省的是"重复计算 Q 相对于历史 K 的部分",
    而不是减少注意力矩阵的规模。
    
    真正节省的是:
    1. 不需要重新计算历史 token 的 K(已缓存)
    2. 不需要重新计算历史 token 的激活值(已缓存)
    
    但 K 和 V 之间的注意力矩阵仍然需要 O(n²) 计算。
    """
    k_cached, v_cached = kv_cache.get_cache()
    # query: [batch, n_heads, 1, head_dim](当前 token)
    # k_cached: [n_heads, seq_len, head_dim]
    
    # 调整维度以便计算
    # query: [n_heads, 1, head_dim](单token查询)
    q = query.squeeze(0)  # [n_heads, head_dim]
    
    # 计算注意力分数
    # q [n_heads, head_dim] @ k_cached.T [n_heads, head_dim, seq_len]
    # 结果: [n_heads, seq_len]
    scores = torch.einsum('hd,hdn->hn', q, k_cached) / (q.shape[-1] ** 0.5)
    attn_weights = F.softmax(scores, dim=-1)
    
    # 加权求和
    # attn_weights [n_heads, seq_len] @ v_cached [n_heads, seq_len, head_dim]
    # 结果: [n_heads, head_dim]
    output = torch.einsum('hn,hnd->hd', attn_weights, v_cached)
    
    return output.unsqueeze(0)

2.3 位置编码与注意力范围

Context Window 的另一个约束来自位置编码(Positional Encoding)。大多数模型使用 Rotary Position Embedding(RoPE),它通过旋转机制将位置信息注入到 Q 和 K 向量中。

# RoPE(旋转位置编码)的核心数学原理

# 标准 Transformer 用绝对位置编码:PE(pos) 直接加到 token embedding 上
# RoPE 用旋转矩阵对 Q 和 K 进行编码,让注意力天然包含相对位置信息

import torch
import math

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    RoPE 旋转函数:将向量后半部分旋转到负方向
    
    数学原理:
    对于一个 d 维向量,前半部分与后半部分配对:
    [x_0, x_1, ..., x_{d/2-1}, x_{d/2}, ..., x_{d-1}]
    
    旋转后:
    [-x_{d/2}, -x_{d/2+1}, ..., -x_{d-1}, x_0, x_1, ..., x_{d/2-1}]
    
    等价于乘以下面的旋转矩阵(d维):
    [  cosθ  -sinθ |  0    0  ]
    [  sinθ   cosθ |  0    0  ]
    [   0      0   | -cosθ  sinθ]
    [   0      0   | -sinθ -cosθ]
    """
    x1 = x[..., :x.shape[-1] // 2]      # 前半部分
    x2 = x[..., x.shape[-1] // 2:]      # 后半部分
    return torch.cat([-x2, x1], dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor,        # query 向量
    k: torch.Tensor,        # key 向量
    freqs_cos: torch.Tensor, # 预计算的 cos 值
    freqs_sin: torch.Tensor  # 预计算的 sin 值
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    对 Q 和 K 向量应用 RoPE 旋转
    
    核心公式:
    RoPE(q, pos) = q * cos(θ_pos) + rotate_half(q) * sin(θ_pos)
    RoPE(k, pos) = k * cos(θ_pos) + rotate_half(k) * sin(θ_pos)
    
    旋转后,Q 和 K 的点积自然包含相对位置信息:
    RoPE(q, m) · RoPE(k, n) 的点积只与 (m-n) 有关
    """
    # 将频率分为余弦和正弦部分
    # 旋转公式:x_rot = x * cos + rotate_half(x) * sin
    q_embed = (q * freqs_cos) + (rotate_half(q) * freqs_sin)
    k_embed = (k * freqs_cos) + (rotate_half(k) * freqs_sin)
    return q_embed, k_embed


def build_rotary_frequencies(
    seq_len: int,
    dim: int,        # head_dim,通常是 128
    base: float = 10000.0,
    device: str = 'cuda'
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    构建 RoPE 频率张量(预计算,避免重复计算)
    
    θ_i = base^{-2i/dim},i = 0, 1, ..., dim/2-1
    
    这个指数衰减的频率设计保证了:
    - 低频维度(i 小,θ 大)→ 长距离依赖
    - 高频维度(i 大,θ 小)→ 短距离依赖
    """
    # 计算每个维度的频率倒数的平方根
    # θ_i = base^{-2i/dim}
    # freqs = [base^{0}, base^{-2/dim}, base^{-4/dim}, ...]
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
    
    # 生成位置索引 [0, 1, 2, ..., seq_len-1]
    t = torch.arange(seq_len, device=device)
    
    # 计算每个位置 × 每个频率的外积
    # t[:, None] * freqs[None, :] → [seq_len, dim/2]
    # 形状:(seq_len, dim/2)
    freqs = torch.outer(t, freqs)
    
    # 转换为极坐标形式:cos(θ) 和 sin(θ)
    # shape: [seq_len, dim/2]
    freqs_cos = freqs.cos()
    freqs_sin = freqs.sin()
    
    return freqs_cos, freqs_sin


# 完整的多头注意力前向传播(带 RoPE + KV Cache)
def multihead_attention_with_rope_and_cache(
    x: torch.Tensor,           # 输入 token embedding [batch, seq_len, d_model]
    cached_k: torch.Tensor,    # 历史 K 缓存 [batch, n_heads, seq_len, head_dim]
    cached_v: torch.Tensor,    # 历史 V 缓存 [batch, n_heads, seq_len, head_dim]
    current_pos: int,           # 当前处理的起始位置
    w_q, w_k, w_v, w_o: torch.Tensor,  # 权重矩阵
    freqs_cos: torch.Tensor,    # RoPE cos 频率
    freqs_sin: torch.Tensor     # RoPE sin 频率
):
    """
    带 RoPE 和 KV Cache 的多头注意力实现
    """
    batch_size, seq_len, d_model = x.shape
    n_heads = w_q.shape[0]  # 从权重形状推断
    head_dim = w_q.shape[-1]
    
    # 投影到 Q、K、V
    q = x @ w_q  # [batch, seq_len, n_heads*head_dim]
    k = x @ w_k
    v = x @ w_v
    
    # 重塑为多头格式
    # [batch, seq_len, n_heads, head_dim] → [batch, n_heads, seq_len, head_dim]
    q = q.view(batch_size, seq_len, n_heads, head_dim).transpose(1, 2)
    k = k.view(batch_size, seq_len, n_heads, head_dim).transpose(1, 2)
    v = v.view(batch_size, seq_len, n_heads, head_dim).transpose(1, 2)
    
    # 提取当前位置的频率(用于 RoPE)
    # freqs_cos/sin shape: [seq_len, head_dim/2]
    pos_cos = freqs_cos[current_pos:current_pos+seq_len]
    pos_sin = freqs_sin[current_pos:current_pos+seq_len]
    
    # 应用 RoPE(只对 Q 和 K,不对 V)
    # 需要将频率 reshape 为匹配 head 的维度
    # pos_cos: [seq_len, head_dim/2] → [1, 1, seq_len, head_dim/2]
    q, k = apply_rotary_pos_emb(
        q, k,
        pos_cos.unsqueeze(0).unsqueeze(0),
        pos_sin.unsqueeze(0).unsqueeze(0)
    )
    
    # 更新 KV Cache
    new_k = k
    new_v = v
    # 将新的 K、V 追加到缓存
    # cached_k[:, :, current_pos:current_pos+seq_len, :] = k
    
    # 计算注意力
    # 如果有历史缓存,需要将当前 K 与历史 K 拼接
    if current_pos > 0:
        # k_all: [batch, n_heads, current_pos+seq_len, head_dim]
        # 需要从 cached_k 中取 [0:current_pos] 部分
        k_all = torch.cat([cached_k[:, :, :current_pos, :], k], dim=2)
        v_all = torch.cat([cached_v[:, :, :current_pos, :], v], dim=2)
    else:
        k_all = k
        v_all = v
    
    # 注意力计算:Q @ K^T / sqrt(d_k)
    scores = torch.matmul(q, k_all.transpose(-2, -1)) / math.sqrt(head_dim)
    attn_weights = F.softmax(scores, dim=-1)
    
    # @ V 并输出
    output = torch.matmul(attn_weights, v_all)
    output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
    output = output @ w_o
    
    return output, new_k, new_v

三、Window Attention:突破 O(n²) 的工程方案

3.1 Sliding Window Attention 的原理

为解决 O(n²) 问题,研究者提出了 Sliding Window Attention(滑动窗口注意力)。每个 token 只关注其周围 W 个 token,而不关注全序列:

# Sliding Window Attention 的核心逻辑

"""
标准 Attention:每个位置可以看到所有历史位置
O(n²) 复杂度

Sliding Window Attention(W=4):
position 0: attend [0]
position 1: attend [0,1]
position 2: attend [0,1,2]
position 3: attend [0,1,2,3]
position 4: attend [1,2,3,4]  ← 窗口滑动
position 5: attend [2,3,4,5]  ← 窗口滑动
...

复杂度降为 O(n·W),W 是窗口大小(常数)
"""

def sliding_window_attention(
    q: torch.Tensor,    # [batch, n_heads, seq_len, head_dim]
    k: torch.Tensor,    # [batch, n_heads, seq_len, head_dim]
    v: torch.Tensor,    # [batch, n_heads, seq_len, head_dim]
    window_size: int = 4
) -> torch.Tensor:
    """
    滑动窗口注意力实现
    
    Args:
        q, k, v: 多头注意力的 Q、K、V 向量
        window_size: 每个位置前后各看多少个 token
    
    Returns:
        加权输出向量
    """
    seq_len = q.shape[2]
    head_dim = q.shape[-1]
    
    # 为每个位置创建掩码,使得窗口外的位置得分为 -inf
    # 掩码形状:[seq_len, seq_len]
    mask = torch.triu(torch.ones(seq_len, seq_len, device=q.device) * float('-inf'), diagonal=-window_size)
    # 设置下三角(window_size 以内)为 0(允许注意)
    mask = torch.tril(mask, diagonal=0) + torch.triu(torch.zeros(seq_len, seq_len, device=q.device), diagonal=window_size)
    
    # 由于 tril 起点不对,重新构造正确掩码
    # 更简洁的做法:
    mask = torch.zeros(seq_len, seq_len, device=q.device)
    for i in range(seq_len):
        start = max(0, i - window_size)
        mask[i, start:i+1] = 0
        mask[i, i+1:] = float('-inf')
    
    # 计算注意力分数
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)
    
    # 应用窗口掩码
    scores = scores + mask.unsqueeze(0).unsqueeze(0)  # broadcast to [batch, n_heads, seq_len, seq_len]
    
    attn_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attn_weights, v)
    
    return output


# 分层稀疏注意力(Longformer 风格)
# 远距离用大窗口,近距离用小窗口
def hierarchical_sparse_attention(
    q, k, v,
    local_window: int = 256,
    global_window: int = 512
):
    """
    分层稀疏注意力:
    1. 每个 token 看局部 local_window 范围(精细)
    2. 每个 token 额外看全局 global_window 范围的汇总信息(粗粒度)
    
    复杂度:O(n·local_window + n·global_window) ≈ O(n)
    """
    seq_len = q.shape[2]
    
    # 局部注意力:O(n·local_window)
    local_out = sliding_window_attention(q, k, v, window_size=local_window)
    
    # 全局注意力:每隔 stride 个 token 汇聚信息
    # 这部分用更小的粒度处理
    stride = global_window // 4
    global_k = k[:, :, ::stride, :]  # 稀疏采样 key
    global_v = v[:, :, ::stride, :]
    
    # 计算与全局 key 的注意力
    global_scores = torch.matmul(q, global_k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    global_attn = F.softmax(global_scores, dim=-1)
    global_out = torch.matmul(global_attn, global_v)
    
    # 融合局部和全局
    return local_out + 0.5 * global_out

3.2 FlashAttention:GPU 友好的实现

FlashAttention 是目前最主流的高效注意力实现,它通过分块计算算子融合将显存占用从 O(n²) 降到 O(n),同时保持数值精度。

# FlashAttention 核心思想:分块 + 在线 softmax

"""
标准 attention 的问题:
1. 需要计算完整的 n×n 注意力矩阵,显存 O(n²)
2. 需要将矩阵分成多个矩阵相乘,HBM 读写开销巨大

FlashAttention 解决方案:
将注意力矩阵按块(block)分块计算,
在累加过程中实时更新 softmax 统计量(max 和 sum),
避免一次性加载完整矩阵。

数学原理:
Softmax(x)_i = exp(x_i - max(x)) / sum_j(exp(x_j - max(x)))

当分块计算时:
已知前 i 块的 max 和 sum,
处理第 i+1 块时:
new_max = max(old_max, block_max)
new_sum = old_sum * exp(old_max - new_max) + block_sum * exp(block_max - new_max)

最终每个元素的 softmax = exp(x_i - global_max) * exp(x_i_local - block_max) / new_sum
"""

# 伪代码展示核心逻辑(实际 GPU 实现请用 Triton 或 CUDA)
def flash_attention_fused(
    q: torch.Tensor,   # [batch, n_heads, seq_len, head_dim]
    k: torch.Tensor,   # [batch, n_heads, seq_len, head_dim]
    v: torch.Tensor,   # [batch, n_heads, seq_len, head_dim]
    block_size: int = 128
):
    batch, n_heads, seq_len, head_dim = q.shape
    n_blocks = (seq_len + block_size - 1) // block_size
    
    # 初始化输出和 softmax 统计量
    output = torch.zeros_like(q)
    lse = torch.zeros((batch, n_heads, seq_len), device=q.device)  # log-sum-exp
    
    for i in range(n_blocks):
        # 加载当前块
        q_block = q[:, :, i*block_size:(i+1)*block_size, :]  # [batch, n_heads, block, head_dim]
        
        # 计算与所有 K 块的注意力(分块累加)
        for j in range(n_blocks):
            k_block = k[:, :, j*block_size:(j+1)*block_size, :]
            v_block = v[:, :, j*block_size:(j+1)*block_size, :]
            
            # 局部 softmax 计算
            # scores: [batch, n_heads, block, block]
            scores = torch.matmul(q_block, k_block.transpose(-2, -1)) / math.sqrt(head_dim)
            
            # 在线 softmax 融合
            if i == 0 and j == 0:
                # 第一个块:直接计算
                block_max = scores.amax(dim=-1, keepdim=True)
                block_sum = torch.exp(scores - block_max).sum(dim=-1, keepdim=True)
                # 更新全局统计量
                cur_max = block_max
                cur_sum = block_sum
            else:
                # 后续块:融合到全局
                block_max = scores.amax(dim=-1, keepdim=True)
                new_max = torch.maximum(cur_max, block_max)
                
                # 重新对齐并累加
                block_sum = torch.exp(scores - new_max).sum(dim=-1, keepdim=True)
                cur_sum = cur_sum * torch.exp(cur_max - new_max) + block_sum
                cur_max = new_max
        
        # 计算最终输出
        # output_block = exp(q_block @ k^T - global_max) / sum / v^T
        # 简化处理(实际实现更复杂)
        attn_block = torch.matmul(
            torch.exp(scores - cur_max.unsqueeze(-1)),  # 已对齐的权重
            v_block
        )
        output[:, :, i*block_size:(i+1)*block_size, :] = attn_block
    
    return output

四、中止策略与 Context Window 结束处理

4.1 常见的停止信号

当生成长文本时,模型需要知道何时停止。主流模型使用以下停止机制:

停止机制说明示例
EOS Token遇到 End-of-Sequence 标记停止Llama 用 <|eot_id|>
Max Tokens到达预设最大生成长度停止API 调用时设置 max_tokens
Stop Sequences遇到指定字符串停止停止词如 "\n\n" 或 "## "
Logit Mask在解码阶段屏蔽某些 token屏蔽掉终止符后的 token

4.2 中止策略的代码实现

class StoppingCriteria:
    """
    自定义停止条件,用于生成循环中判断是否应该停止
    
    支持多种停止条件组合:
    - 遇到特定 token
    - 遇到特定字符串
    - 达到最大 token 数
    """
    
    def __init__(
        self,
        eos_token_id: int,
        max_new_tokens: int,
        stop_sequences: List[str] = None,
        tokenizer=None
    ):
        self.eos_token_id = eos_token_id
        self.max_new_tokens = max_new_tokens
        self.stop_sequences = stop_sequences or []
        self.tokenizer = tokenizer
        self.generated_tokens = []
        self.generated_text = ""
    
    def __call__(self, last_token_id: int, full_sequence: List[int]) -> bool:
        """
        每次生成一个 token 后调用此方法判断是否停止
        
        Args:
            last_token_id: 最新生成的 token ID
            full_sequence: 目前为止的完整 token 序列
        
        Returns:
            True = 停止生成,False = 继续生成
        """
        self.generated_tokens.append(last_token_id)
        
        # 条件1:达到最大 token 数
        if len(self.generated_tokens) >= self.max_new_tokens:
            return True
        
        # 条件2:遇到 EOS token
        if last_token_id == self.eos_token_id:
            return True
        
        # 条件3:检查停止字符串(需要解码后判断)
        if self.tokenizer is not None:
            # 将最新 token 解码(增量解码)
            new_token_text = self.tokenizer.decode([last_token_id])
            self.generated_text += new_token_text
            
            # 检查是否匹配任何停止序列
            for stop_seq in self.stop_sequences:
                if stop_seq in self.generated_text:
                    # 发现停止序列,返回 True
                    # 可以选择是否在序列处截断
                    return True
        
        return False


def streaming_generation_with_stop(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 512,
    stop_sequences: List[str] = ["\n\n## ", "。", "---"]
):
    """
    带停止条件的大模型流式生成示例
    
    支持 Ctrl+C 中断、流式输出、任意停止序列检测
    """
    # 编码输入
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    
    # 初始化停止条件
    stopping = StoppingCriteria(
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=max_new_tokens,
        stop_sequences=stop_sequences,
        tokenizer=tokenizer
    )
    
    # 生成循环
    with torch.no_grad():
        generated = input_ids[0].tolist()
        output_text = ""
        
        while True:
            # 前向传播(使用缓存加速)
            outputs = model(torch.tensor([generated]))
            next_token_logits = outputs.logits[0, -1, :]
            
            # 采样(可替换为 greedy/top_k/top_p)
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            
            generated.append(next_token)
            
            # 检查停止条件
            if stopping(last_token_id=next_token, full_sequence=generated):
                break
            
            # 流式输出(yield)
            new_text = tokenizer.decode([next_token])
            output_text += new_text
            yield new_text
    
    return output_text

五、Context Window 在实际应用中的坑

5.1 位置衰减问题

警告:大多数模型对 Context Window 开头和结尾的内容理解更好(称为"lost in the middle"问题),中间部分容易被忽略。RAG 检索时尽量将最相关的内容放在开头或结尾。

5.2 Context 溢出时的降级策略

class ContextWindowManager:
    """
    当输入超出 Context Window 时的降级策略管理器
    """
    
    def __init__(self, model_name: str, max_tokens: int):
        self.max_tokens = max_tokens
        # 保留 20% 空间给输出
        self.input_max = int(max_tokens * 0.8)
    
    def truncate_or_summarize(self, texts: List[str]) -> str:
        """
        处理超出 Context Window 的文本
        
        策略:
        1. 如果总长度在限制内:直接拼接
        2. 如果超出:从后往前截断(因为最近的内容更重要)
        3. 如果单段就超出:使用分段 + 摘要
        """
        # 估算总 token 数
        total_tokens = sum(self._estimate_tokens(t) for t in texts)
        
        if total_tokens <= self.input_max:
            return "\n".join(texts)
        
        # 从后往前截断(近期内容更重要)
        result = []
        remaining = self.input_max
        
        for text in reversed(texts):
            tokens = self._estimate_tokens(text)
            if tokens <= remaining:
                result.insert(0, text)
                remaining -= tokens
            else:
                # 截断此文本
                truncated = self._truncate_text(text, remaining)
                result.insert(0, truncated)
                break
        
        return "\n".join(result)
    
    def _estimate_tokens(self, text: str) -> int:
        """估算中文字符数(粗略:1中文≈1.5 token)"""
        chinese_chars = sum(1 for c in text if '\u4e00' <= c <= '\u9fff')
        other_chars = len(text) - chinese_chars
        return int(chinese_chars * 1.5 + other_chars * 0.25)
    
    def _truncate_text(self, text: str, max_tokens: int) -> str:
        """按 token 数截断文本"""
        max_chars = int(max_tokens / 1.5)  # 估算
        if len(text) <= max_chars:
            return text
        return text[:max_chars] + "...[内容已截断]"

六、总结

Context Window 是 LLM 最关键的资源之一,理解其背后的技术原理(O(n²) 注意力复杂度、KV Cache、RoPE 位置编码、FlashAttention 分块计算)可以帮助你: