Transformer 语言模型实现
如图是一个 GPT-Transformer 语言模型的简化实现。可以看到这是一个 decoder-only 的 Transformer 模型,输入是一个序列的 token id,经过嵌入层(Embedding Layer)转换为向量表示,然后通过多个 Transformer 解码器层(Transformer Decoder Layer)进行处理,最后通过线性层和 softmax 层输出下一个 token 的概率分布。我们将依次实现这些组件。
嵌入层(Embedding Layer)
嵌入层的作用是将离散的 token id 转换为连续的向量表示。假设词汇表大小为 V,嵌入维度为 D,那么嵌入层可以看作是一个 V x D 的矩阵,其中每一行对应一个 token 的向量表示。给定一个 token id 序列,我们可以通过查找嵌入矩阵的对应行来获得其向量表示。
class Embedding(nn.Module):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
device: torch.device | None = None,
dtype: torch.dtype | None = None
):
super().__init__()
self.weight: Float[Tensor, "num_embeddings embedding_dim"] = nn.Parameter(
torch.empty((num_embeddings, embedding_dim), device=device, dtype=dtype)
)
std = 1
nn.init.trunc_normal_(
self.weight,
mean = 0,
std = std,
a = - 3,
b = 3
)
def forward(self, token_ids: Int[Tensor, "..."]) -> Float[Tensor, "... embedding_dim"]:
return self.weight[token_ids]
PyTorch 中支持 tensor indexing11 使用 data[indices] 时,会索引 data 中最后一个维度,产生根据 indices 形状扩展的新张量,因此我们可以直接使用 token id 作为索引来获取嵌入向量。
Linear 层
线性层是神经网络的基石。我们在 transformer 中只用使用到无偏置的线性层,也就是
因此我们可以实现一个简单的 Linear 层:
class Linear(nn.Module):
def __init__(
self,
in_features: int,
out_features: int,
device: torch.device | None = None,
dtype: torch.dtype | None = None
) -> None:
super().__init__()
self.weight: Float[Tensor, "out_features in_features"] = nn.Parameter(
torch.empty((out_features, in_features), device=device, dtype=dtype)
)
std = 2 / (in_features + out_features)
nn.init.trunc_normal_(
self.weight,
mean = 0,
std = std,
a = - 3 * std,
b = 3 * std
)
def forward(self, x: Float[Tensor, "... in_features"]) -> Float[Tensor, "... out_features"]:
x = einsum(self.weight, x, "out_features in_features, ... in_features -> ... out_features")
return x
这里我们使用了 einsum 来实现矩阵乘法操作。它的好处是不用考虑输入张量的具体维度,我们只需要思考围绕哪些维度进行卷积,再据此写出高可读性的爱因斯坦求和表达式即可。
Norm 层:从 LayerNorm 到 RMSNorm
在 transformer 中,归一化层(Normalization Layer)起着至关重要的作用。在原始的 Transformer 论文中,使用的是 LayerNorm。我们回顾一下 LayerNorm 的计算公式:
其中, 是特征维度的大小, 和 是可学习的参数。在这个过程中,归一化的指标主要涉及均值归一化(带来了平移不变性)和方差归一化(带来了尺度不变性)。后续研究发现,均值归一化对模型性能的提升并不显著,而方差归一化则更为重要。现代 Transformer 结构中,权重通常被初始化为对称分布。在大模型训练过程中,激活向量的均值往往接近于零,因此均值归一化的作用有限。从数学角度看,在高维空间中,减去均值相当于将向量投影到一个低维子空间,而这种投影在高维空间中对距离的影响较小。因此,许多现代 Transformer 变体选择省略均值归一化,仅保留方差归一化,这就是 RMSNorm 的基本思想。
RMSNorm 的计算公式如下:
我们发现,RMSNorm 省略了均值归一化部分,只保留了方差归一化。这使得 RMSNorm 在计算上更为高效,同时在实践中也表现出良好的性能。下面是 RMSNorm 的实现代码:
class RMSNorm(nn.Module):
def __init__(
self,
d_model: int,
eps: float = 1e-5,
device: torch.device | None = None,
dtype: torch.dtype | None = None
):
super().__init__()
self.d_model = d_model
self.eps = eps
self.weight: Float[Tensor, "d_model"] = nn.Parameter(
torch.ones(d_model, device=device, dtype=dtype)
)
def forward(self, x: Float[Tensor, "... d_model"]) -> Float[Tensor, "... d_model"]:
in_dtype = x.dtype
x = x.to(torch.float32)
square_sum = einsum(x, x, "... d_model, ... d_model -> ...")
rms = torch.rsqrt(square_sum / self.d_model + self.eps)
result = einsum(x, rms, self.weight, "... d_model, ..., d_model -> ... d_model")
return result.to(in_dtype)
这里的细节是,我们在计算过程中将输入张量转换为 float32,以确保数值稳定性,最后再将结果转换回原始的数据类型。在现代大模型训练中,模型的参数往往使用混合精度(mixed precision)进行训练,因此这种类型转换是必要的。
Pre-Norm vs Post-Norm
在 Transformer 结构中,归一化层的位置也经历了演变。原始的 Transformer 论文中采用的是 Post-Norm 结构,即在每个子层(self-attention 和 feed-forward)之后进行归一化。假设输入的张量为 ,则 Post-Norm 的计算流程类似于
在现代的 Transformer 变体中,几乎已经成为所有人的共识的是采用 Pre-Norm 结构,即在每个子层之前进行归一化。
通过训练流程图可以发现,Pre-Norm 结构在整体的网络结构中保持了 Identity Mapping 贯穿始终,这使得梯度能够更有效地传播,避免了 Post-Norm 结构中可能出现的梯度消失 / 爆炸问题。
FFN 层:从 ReLU 到 SwiGLU
前馈神经网络层用于在 attention block 之中引入非线性变换。在原始的 Transformer 论文中,使用的是 ReLU 激活函数,这是一种简单且高效的非线性函数,在计算其梯度时也非常方便。然而,随着研究的深入,发现更复杂的激活函数可以提升模型的表达能力和性能。例如,GELU(Gaussian Error Linear Unit)激活函数在 BERT 模型中被广泛采用。GELU 的计算公式如下:
GELU 通过引入平滑的非线性变换,使得模型在处理输入时能够更好地捕捉复杂的模式。它的核心思想是统计性地实现 dropout 效果,从而提升模型的泛化能力。我们知道 dropout 在训练中的重要性,它通过随机地丢弃部分神经元来防止过拟合。一旦选定了 dropout 的神经元,它就变成了一个确定性模型的训练:每个神经元要么被保留,要么被丢弃。GELU 通过平滑地调整神经元的激活值,实现了类似的效果。它的核心思想是:将随机丢弃神经元的过程转化为一个连续的概率分布,直接集成在激活函数中。如果 越小,则该神经元被丢弃的概率越大,反之亦然。我们假设神经元被保留的概率为 ,输出遵循一个标准的正态分布 ,则有:
其中, 是标准正态分布的累积分布函数(CDF)。则期望上的输出为:
对于 的幂函数展开得到了上面的近似公式。
在后续的研究中,发现了更为高效的激活函数变体,其中 SwiGLU 是一种结合了 SiLU(Sigmoid Linear Unit)和 GLU(Gated Linear Unit)思想的激活函数。SwiGLU 的计算公式如下:
其中, 表示逐元素相乘操作,、 和 是线性变换矩阵。SiLU 激活函数的计算公式为:
SiLU 通过引入 sigmoid 函数,实现了平滑的非线性变换,消除了 ReLU 在负值区域的零梯度问题。GLU 结构引入了线性层之外的逐元素乘法操作,本质上是将加性的变换替换为乘性的变换,从而提升了模型的表达能力。对于一个简单的 GLU 层,其计算公式为:
可以得到其梯度为:
我们发现在反向传播过程中,GLU 层的梯度计算涉及到两个部分:一部分是通过 传递的梯度,另一部分是通过 传递的梯度。这种结构有效地解决了激活函数在反向传播中可能出现的梯度消失问题,从而提升了模型的训练效果。下面是 SwiGLU 的实现代码:
class SwiGLU(nn.Module):
def __init__(self, d_model: int, d_ff: int | None = None):
super().__init__()
if not d_ff:
d_ff = int(8 /3 * d_model)
d_ff = math.ceil(d_ff / 64) * 64
self.w1 = Linear(d_model, d_ff)
self.w2 = Linear(d_ff, d_model)
self.w3 = Linear(d_model, d_ff)
def forward(self, x: Float[Tensor, "... d_model"]) -> Float[Tensor, "... d_model"]:
layer1: Float[Tensor, "... d_ff"] = self.w1(x)
silu: Float[Tensor, "... d_ff"] = layer1 * torch.sigmoid(layer1)
gate: Float[Tensor, "... d_ff"] = self.w3(x)
x: Float[Tensor, "... d_ff"] = silu * gate
x: Float[Tensor, "... d_model"] = self.w2(x)
return x
RoPE 位置编码
原始的 Transformer 论文使用了借助 和 函数的绝对位置编码(Absolute Positional Encoding)。这种方法通过为每个位置生成一个固定的向量来引入位置信息。然而,绝对位置编码在处理长序列时可能会遇到一些问题,面对足够长的序列时,模型可能无法有效地捕捉到位置信息。并且,这类位置编码的具体构造思路也相对单薄:只是简单地赋予相近的位置较为接近的编码值,较远的位置差异较大的编码值。一方面这过于宽泛而难以为这样的构造提供一个理论上的解释,另一方面这种利用三角函数的构造对于不同位置上的编码值关系不具有平移不变性:一个位置的编码值与另一个位置的编码值之间的关系不仅取决于它们之间的距离,还取决于它们在序列中的绝对位置。这种过于固定化的假设限制了模型在处理不同长度的序列时的泛化能力。后续的模型引入了可训练的绝对位置编码,但这种方法在处理长序列时仍然存在类似的问题。
RoPE(Rotary Position Embedding)提出了一种相对位置编码的方法。所谓相对位置编码,指的是编码方式关注于不同位置之间的相对关系,而不是每个位置的绝对编码值。具体而言也就是对于不同位置的相同编码、相同距离,我们在计算它们之间的关系时得到的结果是相同的。
其中, 是位置编码函数, 是一个函数,它的输入是两个位置之间的距离 。
RoPE 的构造方式是旋转:注意到旋转矩阵满足 ,这恰好满足了相对位置编码的要求。为了简单起见,我们在考虑模型的高维空间 时考虑其 个二维子空间,我们假设旋转矩阵恰好对应了这 个二维的不变子空间。那么建立在这 个子空间对应的正交基的组合上的旋转矩阵应当满足以下形式:
其中每个 都是一个二维的旋转矩阵
为了满足我们先前提到的关于位置的相对性,我们只需保证 关于位置 的增量是一个常数即可。一般而言我们采用
其中 是一个超参数,通常设置为 10,000 (为了和原始的绝对位置编码中的频率范围保持一致)。实际在代码实现中,RoPE 作为超参数而不是一个可训练的参数被引入到模型中,这允许了我们提前计算出所有位置的 RoPE 编码,并在训练过程中直接使用这些预计算的编码值。一种可读性较高的实现方式是将输入的嵌入向量分成 个二维子空间,然后对每个子空间应用对应的旋转矩阵。下面是 RoPE 的实现代码
class RotaryPositionalEmbedding(nn.Module):
def __init__(
self,
theta: float,
d_k: int,
max_seq_len: int,
device: torch.device | None = None,
dtype: torch.dtype | None = torch.float32
):
super().__init__()
self.theta = theta
self.max_seq_len = max_seq_len
self.d_k = d_k
log_theta = math.log(self.theta)
dim_indices: Float[Tensor, "d_k/2"] = torch.arange(0, self.d_k, step = 2, device = device, dtype = dtype)
indices: Float[Tensor, "max_seq_len"] = torch.arange(self.max_seq_len, device = device, dtype = dtype)
inv: Float[Tensor, "d_k/2"] = torch.exp(- (log_theta / self.d_k) * dim_indices)
thetas: Float[Tensor, "max_seq_len d_k/2"] = einsum(indices, inv, "seq_len, d -> seq_len d")
cos_thetas = torch.cos(thetas)
sin_thetas = torch.sin(thetas)
rotary_list = [cos_thetas, -sin_thetas, sin_thetas, cos_thetas]
rotary_matrix: Float[Tensor, "max_seq_len d_k/2 2 2"] = rearrange(rotary_list, "(dim1 dim2) ... -> ... dim1 dim2", dim1 = 2, dim2 = 2)
self.register_buffer(
"rotary",
rotary_matrix,
persistent = False
)
def forward(
self,
x: Float[Tensor, "... seq_len d_k"],
token_positions: Float[Tensor, "... seq_len"]
) -> Float[Tensor, "... seq_len d_k"]:
R: Float[Tensor, "... seq_len d_k/2 2 2"] = self.rotary[token_positions]
x_reshape = rearrange(x, "... seq_len (block dim) -> ... seq_len block dim", dim = 2)
x_out = einsum(x_reshape, R, "... seq_len block dim2, ... seq_len block dim1 dim2 -> ... seq_len block dim1")
x_out = rearrange(x_out, "... seq_len block dim -> ... seq_len (block dim)", dim = 2)
return x_out
Attention
我们知道标准的 Attention 公式
在处理多头注意力时,可以通过重排张量使得头的维度 与特征维度 分开而与 batch 维度 合并,这样就可以在计算 Attention 时同时处理多个头(而不引入太多额外的计算开销)。
def forward(
self,
x: Float[Tensor, "... seq_len d_model"],
token_positions: Int[Tensor, " ... seq_len"] | None = None
) -> Float[Tensor, "... seq_len d_model"]:
seq_len = x.shape[-2]
if self.max_len:
mask = self.causual_mask[:seq_len, :seq_len]
else:
mask = torch.tril(torch.ones(seq_len, seq_len, dtype = torch.bool))
Q = self.q_proj(x)
K = self.k_proj(x)
V = self.v_proj(x)
Q = rearrange(Q, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k", num_heads = self.num_heads)
K = rearrange(K, "... seq_len (num_heads d_k) -> ... num_heads seq_len d_k", num_heads = self.num_heads)
V = rearrange(V, "... seq_len (num_heads d_v) -> ... num_heads seq_len d_v", num_heads = self.num_heads)
if self.RoPE:
token_positions = rearrange(token_positions, "... seq_len -> ... 1 seq_len")
Q = self.RoPE(Q, token_positions)
K = self.RoPE(K, token_positions)
attention: Float[Tensor, "... num_heads seq_len d_v"] = scaled_dot_product_attention(Q, K, V, mask)
attention: Float[Tensor, "... seq_len d_model"] = rearrange(attention, "... num_heads seq_len d_v -> ... seq_len (num_heads d_v)")
x = self.output_proj(attention)
return x
以上就是一个简化的 Transformer 语言模型的实现。通过逐步构建嵌入层、线性层、归一化层、前馈神经网络层以及注意力机制,我们可以实现一个功能完整的 Transformer 模型。这些组件的设计和实现细节对于理解 Transformer 的工作原理以及优化模型性能至关重要。