Ziyu Li's Homepage

Back

CS336 Assignment 1#

Part1中我们实现了BPE分词器的算法以及Tokenizer类, 在Part2中我们将会手搓整个Transformer模型, 然后在Part3中我们会把前面的部分结合来真正的让这个模型开始训练

label_plot

Transformer LM 简介#

Token嵌入层 (Token Embedding)#

输入Transformer的内容是token id数字张量(显然不能是char), 形状为(batch_size, sequence_length), 比如说(2,2)

inputs = np.array([[2,0],[1,2]]) 
python

然后有一个可训练的矩阵E称之为嵌入矩阵, 这个矩阵回答这样一个问题: 对于每个输入的token_id(一维数字), 怎么把他送到高维空间?

该矩阵形如:

E = np.array([
    [0.0, 0.0, 0.0],  # id = 0 (PAD/UNK)
    [0.1, 0.2, 0.3],  # id = 1 Somewhere
    [0.0, 0.5, 0.5],  # id = 2 over
    [0.9, 0.1, 0.0],  # id = 3 the
    [0.4, 0.4, 0.2],  # id = 4 rainbow
    [0.7, 0.3, 0.6],  # id = 5 way
    [0.2, 0.2, 0.9],  # id = 6 up
    [0.6, 0.1, 0.4],  # id = 7 high
])  # shape (V, d_model) -> here V=8, d_model=3
python

含义为: 将输入的token_id当中的0送到R3R^3上, 值为[0.1,0.2,0.3], 至于这些token_id是哪里来的, 就是从之前训练好的Vocab训练而来的, 完整的流程如下

这是我们输入的自然语言:

Somewhere over the rainbow way up high.
plaintext

我们有训练好的词汇表Vocab:

{"Somewhere":1, "over",2, "the":3, "rainbow":4, "way":5, "up":6, "high":7}
plaintext

首先自然语言被分割成

tokens = ["Somewhere","over","the","rainbow","way","up","high"]
plaintext

然后通过词汇表被编码成

ids = [1,2,3,4,5,6,7]
plaintext

接着对于ids里面的每一个id, 直接在E里面找到对应的行就可以了, 这就完成了Token Embeddings这个步骤

Pre-Norm Transformer块#

经过Embedding处理后的尺寸为(batch_size, sequence_length, d_model)的数据被送入Pre-Norm Transformer块当中, 处理完毕后的尺寸仍为(batch_size, sequence_length, d_model), 块内的组件是自注意力机制和前馈层

输出的归一化#

在经过若干个Transformer块之后, 还要进行归一化, 再送入一个线性层做处理, 最后通过Softmax来输出概率logits来决定下一个词输出什么

Einstein算子标注#

显然在做矩阵乘法/张量乘法的时候, 我们是在某一个维度进行求和, 而那个维度在计算完成之后会消失, Einstein算子标注就是让我们显性的写出那个被求和/将消失的维度, 其他的维度就不管了

我觉得这种张亮乘法其实类似于高维定积分, 当通过累次积分来计算重积分的时候, 显然要写好这一次是对什么变量进行积分, 这次积分完毕之后, 这个变量就不再存在

Vf(x,y,z)dxdydz=V1dxdyV2f(x,y,z)dz\iiint_{V} f(x,y,z)\,dx\,dy\,dz = \iint_{V_1}\,dx\,dy \int_{V_2}f(x,y,z)\,dz

后面那次积分的dz已经说明了这次的积分(视作特殊的求和)是对z这个变量的, 所以不会产生混淆

来看一个Einstein标注的张量运算

import torch
from einops import rearrange, einsum
## Basic implementation
Y = D @ A.T
# Hard to tell the input and output shapes and what they mean.
# What shapes can D and A have, and do any of these have unexpected behavior?
## Einsum is self-documenting and robust
# D A -> Y
Y = einsum(D, A, "batch sequence d_in, d_out d_in -> batch sequence d_out")
## Or, a batched version where D can have any leading dimensions but A is constrained.
Y = einsum(D, A, "... d_in, d_out d_in -> ... d_out")
python

这里我们对d_in这个维度求和, 求和后这个维度就消失掉了, 所以只需要在两个运算量里面都亮明这个维度, 算子就会自动求和

再看一个例子

images = torch.randn(64, 128, 128, 3) # (batch, height, width, channel)
dim_by = torch.linspace(start=0.0, end=1.0, steps=10)
## Reshape and multiply
dim_value = rearrange(dim_by, "dim_value -> 1 dim_value 1 1 1") # 拓展维度, 注意1是可以任意添加的维度
images_rearr = rearrange(images, "b height width channel -> b 1 height width channel") # 同上
dimmed_images = images_rearr * dim_value
## Or in one go:
dimmed_images = einsum(
images, dim_by,
"batch height width channel, dim_value -> batch dim_value height width channel"
)
python

注意一下广播的时候首先维度的数量要匹配, 其次每个维度的大小要么相等要么有一个是1, 而在Einstein标注下, 不足的维度会被自动填充并广播

高维张量的乘法是不便想象的, 我觉得首先要明白需要的输出尺寸是多少, 然后再用Einstein标注去写好输入的尺寸, 不变的用...替代

线性变换标注#

本项目中统一使用列向量做线性变换的notation, 即若对向量x进行线性变换W则:

Y=WxY = Wx

x默认为列向量

线性层和嵌入模块#

参数初始化#

对于每一个权重矩阵, 需要把它声明为nn.Parameter参数并且传入shape, 然后对这个参数做初始化, 比如在线性层当中想构造一个out_features * in_features的权重矩阵并且正态初始化

weight = nn.Parameter(torch.empty(out_features, in_features))
nn.init.trunc_normal_(weight,mean = mean, std = std, a = -3*std, b = 3*std)
python

对于不同的层, 文档给了我们不同的初始化要求

LinearLayer:N(μ=0,σ2=2din+dout),clipto[3σ,3σ]Linear \,\, Layer: N(\mu = 0, \sigma^2 = \frac{2}{d_{in} + d_{out}}), \, \, clip \,\, to \,\, [-3\sigma,3\sigma] EmbeddingLayer:N(μ=0,σ2=1),clipto[3,3]Embedding \,\, Layer: N(\mu = 0, \sigma^2 = 1), \, \, clip \,\, to \,\, [-3,3] RMSNormLayer:EnnRMSNorm \,\, Layer: \mathbb{E}_{n*n}

所有的参数初始化都用torch.nn.init.trunc_normal_来实现

线性层#

很容易实现, 就是声明并初始化一个权重矩阵并且作用在输入上面就可以了

很常规, 解释一下最后这个einsum的notation, 先忽略这个..., 就看"ji,i->j"

(w11.....w1inwout1......woutin)(x1xin)\begin{pmatrix} w_{11} ..... w_{1in} \\ \\ \\ \\ w_{out1} ...... w_{out in} \end{pmatrix} \begin{pmatrix} x_1 \\ \\ \\ \\ x_{in} \end{pmatrix}

从逐行求和的角度看, 结果的第j行满足

yj=iW[j,i]x[i]y_j = \sum_{i} W[j,i] * x[i]

所以是对i这个维度求和, 求和的维度会消失, 所以得到输出维度是j

另一种角度看, 消失的总是内在匹配的维度, 所以是这里的i, 即第一个输入的列, 第二个输入的行

完善一下adapters.py里面的测试, 直接调用自己实现的这个Linear类就好了, 注意对于那些权重矩阵, 记得更新他们的内容, 后面所有的测试都是这个格式, 依葫芦画瓢就好了

运行测试uv run pytest -k test_linear, 结果如下:

(base) zyli@lab:~/Stanford_CS336/assignment1-basics$ uv run pytest -k test_linear
==================================================================================================== test session starts =====================================================================================================
platform linux -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/zyli/Stanford_CS336/assignment1-basics
configfile: pyproject.toml
plugins: jaxtyping-0.3.2
collected 48 items / 47 deselected / 1 selected                                                                                                                                                                              

tests/test_model.py::test_linear PASSED

============================================================================================== 1 passed, 47 deselected in 0.19s ==============================================================================================
bash

嵌入模块#

和我们之前讲的一样, 这里就是把输入的token_idEmbedding矩阵当中去寻找相应的d_model维度的向量, 也就是一个升维的过程, 比如说输入为5, 那就去找Embedding矩阵的第五行的行向量就好了

注意self.weight[token_ids]这种写法, 实际上nn.Parametertensor的子类, 是可以通过索引访问的, 举个例子

# weight: (num_embeddings=5, embedding_dim=3)
weight = nn.Parameter(torch.tensor([
    [0.1, 0.2, 0.3],  # id 0
    [0.4, 0.5, 0.6],  # id 1
    [0.7, 0.8, 0.9],  # id 2
    [1.0, 1.1, 1.2],  # id 3
    [1.3, 1.4, 1.5],  # id 4
], dtype=torch.float32))


token_ids = torch.tensor([[2,0],
                          [1,2]], dtype=torch.long)  # shape (2,2)

# 索引取 embedding:结果形状 (2, 2, 3)
embs = weight[token_ids]
python

输出类似于

tensor([
  [[0.7, 0.8, 0.9],  # weight[2]
   [0.1, 0.2, 0.3]], # weight[0]
  [[0.4, 0.5, 0.6],  # weight[1]
   [0.7, 0.8, 0.9]]  # weight[2]
])
python

不过也不需要了解, 反正知道支持下标访问就行了, 而且这个下标还可以是一个tensor

运行测试uv run pytest -k test_embedding, 结果如下:

(base) zyli@lab:~/Stanford_CS336/assignment1-basics$ uv run pytest -k test_embedding
==================================================================================================== test session starts =====================================================================================================
platform linux -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/zyli/Stanford_CS336/assignment1-basics
configfile: pyproject.toml
plugins: jaxtyping-0.3.2
collected 48 items / 47 deselected / 1 selected                                                                                                                                                                              

tests/test_model.py::test_embedding PASSED

============================================================================================== 1 passed, 47 deselected in 0.21s ==============================================================================================
bash

Pre-Norm Transformer块#

这里主要需要实现三个模块, RMSNorm归一化, RoPE旋转编码, Feed-Forward前馈神经网络

RMSNorm---Root Mean Square Layer Normalization 均方根归一化#

对于向量aRdmodela \in \mathbb{R}^{d_{model}}, RMSNorm会对每个分量a_i进行如下变化:

RMSNorm(ai)=aiRMS(a)giRMSNorm(a_i) = \frac{a_i}{RMS(a)} g_i

其中RMS(a)RMS(a)是一个平方根式求和

RMS(a)=1dmodeli=1dmodelai2+ϵRMS(a) =\sqrt{\frac{1}{d_{model}} \sum_{i=1}^{d_{model}} a_i^2 + \epsilon}

很好理解, 这个分母真的就是均值->平方求和->开根号(忽略那个ϵ\epsilon)

gig_i是一个可学习的向量, 注意每个aia_i有一个gig_i, 总共有dmodeld_{model}aia_i, 所以有dmodeld_{model}gig_i

解释一下这个RMS_a的算法当中的einsum部分, 实际上这里是要算内积分, 所以对于两个一模一样的东西, 直接按照最后一个维度求和就好了

注意这里这个RMS_a.unsqueeze(-1)是不可以省略的, 因为RMS_a的尺寸是..., 而x的尺寸是...d, 所以必须给RMS_a的尺寸做成...1才能进行除法的广播, 实际上这个RMS_a是个标量, 广播的意思就是每个aia_i都要去除以这个标量, 这就必须要求最后一个维度是对齐的, 不然广播会出错

运行测试uv run pytest -k test_rmsnorm, 结果如下:

(base) zyli@lab:~/Stanford_CS336/assignment1-basics$ uv run pytest -k test_rmsnorm
==================================================================================================== test session starts =====================================================================================================
platform linux -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/zyli/Stanford_CS336/assignment1-basics
configfile: pyproject.toml
plugins: jaxtyping-0.3.2
collected 48 items / 47 deselected / 1 selected                                                                                                                                                                              

tests/test_model.py::test_rmsnorm PASSED

============================================================================================== 1 passed, 47 deselected in 0.18s ==============================================================================================
bash

前馈神经网络#

虽然说是神经网络, 但是其实就是设计一个激活函数, 大概用的有以下几种:

SiLU/Swish

SiLU(x)=xσ(x)=x1+exSiLU(x) = x*\sigma(x) = \frac{x}{1+e^{-x}}

Gated Linear Units/GLU

GLU(x,W1,W2)=σ(W1x)W2xGLU(x,W_1,W_2) = \sigma(W_1x) \odot W_2x

该算子为Hadamard积, 即逐元素相乘

SwiGLU

SwiGLU(x,W1,W2,W3)=W2(SiLU(W1x)W3x)SwiGLU(x,W_1,W_2,W_3) = W_2(SiLU(W_1x) \odot W_3x)

考虑一下SwiGLU的尺寸问题, 输入的x(...,d_model), 经过SwiGLU的变换之后仍然应该是这个尺寸, 所以W_1W_3的尺寸应当是(d_ff, d_model), W_2应该是(d_model, d_ff)

借用torch.sigmoid实现这个SwiGLU的代码:

还是讲一下这里的维度notation, 首先对于Hadamard算子应该是比较好理解的, 前后尺寸都是...,意味着不对任何维度求和, 那当然就是逐元素相乘

对于W_3x而言, 这里写的看起来像是xW_3, 不过关键还是要对齐那个可以匹配的维度, 由于x(...,d_model), W_3(d_ff, d_model), 显然写的时候要确保最后一维的尺寸一致(即d), 然后对这个维度求和(即不出现在结果里面就好了)

在工程上或许这么写不会有什么问题, 毕竟这块以后都会封装, 也没有谁真的会去看, 不过如果是数学作业/paper上还是保持notation的顺序比较好, 不然读者的脸色可能不会太好看

运行测试uv run pytest -k test_swiglu, 结果如下:

(base) zyli@lab:~/Stanford_CS336/assignment1-basics$ uv run pytest -k test_swiglu
Uninstalled 1 package in 0.92ms
Installed 1 package in 16ms
================================================================================================= test session starts ==================================================================================================
platform linux -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/zyli/Stanford_CS336/assignment1-basics
configfile: pyproject.toml
plugins: jaxtyping-0.3.2
collected 48 items / 47 deselected / 1 selected                                                                                                                                                                        

tests/test_model.py::test_swiglu PASSED

=========================================================================================== 1 passed, 47 deselected in 0.18s ===========================================================================================
bash

相对位置编码#

相对位置编码和多头注意力机制应该是最难的, 先不考虑代码的问题, 我们先来理解一下这玩意到底是在干什么

首先RoPE接受一个参数θ\theta, 这会决定我们每次旋转的角度

两个index:ik, i代表的是输入xix^i的索引, 也就是这是输入的第几条输出, k代表的是现在是在对这个输出的第几组维度进行变换, 对于每一组(i, k), 变换的角度θi,k=iθ2k2d\theta_{i,k} = \frac{i}{\theta^{\frac{2k-2}{d}}}, 其中d指的是d_model, 为固定的模型参数

维度是被两两分组的, 比如说d_model是8, 那就被切分成4组, k(1,2,3,4)k \in (1,2,3,4)

每一组的旋转矩阵为:

Rki=(cosθi,k,sinθi,ksinθi,k,cosθi,k)R_{k}^{i} = \begin{pmatrix} \cos{\theta_{i,k}}, -\sin{\theta_{i,k}} \\ \sin{\theta_{i,k}}, \cos{\theta_{i,k}} \end{pmatrix}

现在对q1=Wqx1=[1,0,1,0]q^{1} = W_{q}x^{1} = [1,0,1,0]来做变换, q1q^1的i=1, k被分为两组, 1和2, 那么可以计算出以下两个θ\theta

θ1,1=1θ0=1θ1,2=1θ0.5\theta_{1,1} = \frac{1}{\theta^{0}} = 1 \\ \theta_{1,2} = \frac{1}{\theta^{0.5}}

进而得到两个旋转矩阵R11,R21R_{1}^{1}, R_{2}^{1}, 那么组一是[1,0][1,0], 组二是[0,1]组二是[0,1], 分别对其左乘R11,R21R_{1}^{1}, R_{2}^{1}即可

写成比较大的矩阵形式就是:

R(i)=diag(R1(i),  R2(i),  ,  Rd/2(i))=[R1(i)0R2(i)0Rd/2(i)],R^{(i)}=\mathrm{diag}\big(R^{(i)}_1,\;R^{(i)}_2,\;\dots,\;R^{(i)}_{d/2}\big) = \begin{bmatrix} R^{(i)}_1 & & & 0\\[2pt] & R^{(i)}_2 & & \\[2pt] & & \ddots & \\[2pt] 0 & & & R^{(i)}_{d/2} \end{bmatrix}, q(i)=[q1(i)q2(i)qd/2(i)],qk(i)R2,qk(i)=[q2k1(i)q2k(i)](k=1,,d/2).q^{(i)}= \begin{bmatrix} q^{(i)}_1\\[4pt] q^{(i)}_2\\[4pt] \vdots\\[4pt] q^{(i)}_{d/2} \end{bmatrix}, \qquad q^{(i)}_k \in \mathbb{R}^2,\quad q^{(i)}_k= \begin{bmatrix} q^{(i)}_{2k-1}\\[4pt] q^{(i)}_{2k} \end{bmatrix} \quad (k=1,\dots,d/2).

那么整个的RoPE就是RiqiR^{i} q^{i}, 一句话总结就是先把qi/kiq^{i}/k^{i}分组, 然后每组构造一个2*2的旋转矩阵, 对每个组左乘这个旋转矩阵之后凑起来就得到了RoPE的结果

实验文档中提出了用self.register_buffer来注册这些三角函数值, 因为他们只依赖于输入角度theta, 输入的”编号”i以及离散取值于[0,d2][0,\frac{d}{2}]k, 所以可以在初始化的时候就把他们确定下来

注意θi,k\theta_{i,k}, 他的指标应该是ik的笛卡尔积, 这样才能保证每个ik都能被遍历到, 所以应当通过以下代码来建立这个θi,k\theta_{i,k}

positions = torch.arange(sequence_length).float() # i的向量

freqs = theta ** (torch.arange(0, self.d_k , 2).float() / self.d_k) # theta^{2k-2/d}的向量

angles = torch.outer(positions, freqs) # i和K的笛卡尔积
python

然后注册成为类buffer

self.register_buffer("cos_cached", torch.cos(angles), persistent = False)
self.register_buffer("sin_cached", torch.sin(angles), persistent = False)
python

如此一来的话就可以通过token_positions来访问得到cos(θi,k)\cos(\theta_{i,k})sin(θi,k)\sin(\theta_{i,k})了, 例如:

cos_pos = self.cos_cached[token_positions] # 直接拿到这个x所需的所有cos值, 注意token_positions是个tensor
sin_pos = self.sin_cached[token_positions]
python

回想一下在之前的变换矩阵示意图里面, 这个变换最后并不改变输入x的shape, 但是要能把这个分块d/2 * d/2的变换矩阵作用上去, 是需要对x进行一些reshape

x_reshaped = x.view(batch_size,seq_len,d_k//2,2) # 分成d_k//2个组, 每组两个元素
x1,x2 = x_reshaped[...,0],x_reshaped[...,1] # 对于每个组, 把第一个元素和第二个元素分开
python

然后进行变换并且stack回去即可

x1_rotated = x1 * cos_pos - x2 * sin_pos # 注意这里的x_1是: "一个输入token(一个i)"的所有的d_k//2个组的x_1, 不只是一个组的x_1
x2_rotated = x2 * cos_pos + x1 * sin_pos

# 如果说输入的x = [1,2,3,4,5,6], 那么这里的x_1是形如[1,3,5]的tensor, 所以这个变换是向量化的
x_rotated = torch.stack([x1_rotated, x2_rotated], dim=-1)
x_rotated = x_rotated.view(batch_size, seq_len, d_k)
python

整体代码如下:

完善测试:

运行测试uv run pytest -k test_rope, 结果如下:

(base) zyli@lab:~/Stanford_CS336/assignment1-basics$ uv run pytest -k test_rope
========================================================================================== test session starts ==========================================================================================
platform linux -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/zyli/Stanford_CS336/assignment1-basics
configfile: pyproject.toml
plugins: jaxtyping-0.3.2
collected 48 items / 47 deselected / 1 selected                                                                                                                                                         

tests/test_model.py::test_rope PASSED

=================================================================================== 1 passed, 47 deselected in 0.17s ====================================================================================
bash

Softmax和点积注意力机制#

首先实现Softmax函数, 对向量vv做归一化

Softmax(v)i=exp(vi)j=1nexp(vj)Softmax(v)_i = \frac{exp(v_i)}{\sum_{j=1}^{n} exp(v_j)}

注意求指数有可能会让数据变得很大从而上溢出, 所以一个比较好的办法是让这个向量的每个分量减去这个向量的最大分量, 因为对一个很小的数求指数是不会溢出的

完善测试:

注意这个dim是必须的参数, 因为需要知道对什么维度进行归一化

运行测试uv run pytest -k test_softmax_matches_pytorch, 结果如下:

(base) zyli@lab:~/Stanford_CS336/assignment1-basics$ uv run pytest -k test_softmax_matches_pytorch
========================================================================================== test session starts ==========================================================================================
platform linux -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/zyli/Stanford_CS336/assignment1-basics
configfile: pyproject.toml
plugins: jaxtyping-0.3.2
collected 48 items / 47 deselected / 1 selected                                                                                                                                                         

tests/test_nn_utils.py::test_softmax_matches_pytorch PASSED

=================================================================================== 1 passed, 47 deselected in 0.17s ====================================================================================
bash

接下来计算点积注意力

Attention(Q,K,V)=Softmax(QTKdk)VAttention(Q,K,V) = Softmax(\frac{Q^TK}{\sqrt{d_k}})V

先考虑一下尺寸问题, 简单来说, 假设QRn×dkQ \in \mathbb{R}^{n \times d_k}, KRm×dkK \in \mathbb{R}^{m \times d_k}, VRm×dvV \in \mathbb{R}^{m \times d_v}, 那么输出尺寸应该是Rn×dv \mathbb{R}^{n \times d_v}, 但是QTKQ^TK如果按照一般的矩阵乘法是不能相乘的, 这里其实对这两个矩阵的每个行向量做内积

QTKij=k=1dkQikKkjQ^TK_{ij} = \sum_{k=1}^{d_k} Q_{ik} K_{kj}

注意求和, 每行有dkd_k个元素, 实验文档要求我们用einsum来实现这个求和, 如果是用pytorch的话, 这里要手动修改成QKTQK^T的乘积形式

但我觉得这种记号并不好, 既然用了矩阵乘法的记号, 那应该要确保按照记号是可相乘的, 不然当发现维度不匹配的时候会很confused

scores = torch.einsum("b...qd,b...kd->b...qk",self.Q,self.K)/torch.sqrt(torch.tensor(self.d_k,dtype=torch.float32))
# kd 指的是Q当中的n * d_k
# qd 指的是K当中的m * d_k
# 匹配的是最后一维, einsum会自动把他们转置成可以相乘的形式
python

然后用mask做掩码变换, 这里mask的维度是n×mn \times m, 暂时不需要管他怎么实现的, 当作类里有的成员变量就好, 既然他是个布尔矩阵, 那只要用torch.where去把mask和计算出的分数做一个类似与运算就好了

if self.mask is not None:
    scores = torch.where(self.mask,scores,float('-inf'))
python

接下来做Softmax变换, 对最后一维做归一化

attention_weights = torch.softmax(scores, dim=-1)
python

最后乘以VV即可, 注意前面的结果是可以”直接”乘以VV的, 因为维度已经匹配了, 而像之前那样维度不匹配的情况, einsum会自动把他们转置成可以相乘的形式

output = torch.einsum("b...qk,b...kv->b...qv",attention_weights,self.V)

return output
python

总体实现如下

完善测试:

这里有两个测试, 分别运行uv run pytest -k test_scaled_dot_product_attentionuv run pytest -k test_4d_scaled_dot_product_attention, 结果如下

多头注意力机制#

之前我们只关心怎么通过Q,K,VQ,K,V来计算出注意力, 但是没有探究这个Q,K,VQ,K,V是怎么通过输入x得到的, 实际上我们有以下的流程图

也就是说, 先通过三个不同的线性层把输入x变换到Q,K,VQ,K,V三个矩阵, 然后通过多头注意力机制计算出注意力, 最后通过一个线性层把多个头的结果拼接起来, 再通过一个线性层得到最终的输出

如果要进行旋转位置编码, 注意每个头上都要应用

if self.use_rope:
    # 创建 token_positions [seq_len]
    token_positions = torch.arange(seq_len, device=x.device)
    # token_positions = self.token_positions

    # 为每个头应用 RoPE,形状 [batch, num_heads, seq, d_k]
    for head in range(self.num_heads):
        Q[:, head, :, :] = self.rope(Q[:, head, :, :], token_positions.unsqueeze(0))
        K[:, head, :, :] = self.rope(K[:, head, :, :], token_positions.unsqueeze(0))
python

还需要计算掩码矩阵, 首先考虑尺寸, 实际上mask的尺寸和QK^T的尺寸是一样的, 这里有:

Q: (batch, num_heads, seq_len, d_k)
K: (batch, num_heads, seq_len, d_k)
plaintext

所以mask的尺寸应该是(seq_len, seq_len)

再看mask的元素, 行表示query, 列表示key, 第i个query只能看到前i个’key’, 所以这是个下三角矩阵

allow_mask[i][j] = True 表示 query_i 可以 attend to key_j

        key_0  key_1  key_2  key_3
        ─────  ─────  ─────  ─────
query_0 │  ✓     ✗     ✗     ✗    ← 只能看自己及之前
query_1 │  ✓     ✓     ✗     ✗    ← 只能看自己及之前
query_2 │  ✓     ✓     ✓     ✗    ← 只能看自己及之前
query_3 │  ✓     ✓     ✓     ✓    ← 只能看自己及之前
plaintext
# 添加因果掩码
causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
# scaled_dot_product_attention 期望 mask True = 允许,False = 屏蔽
# 但我们的 causal_mask True = 屏蔽,所以需要取反
allow_mask = ~causal_mask
allow_mask = allow_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, self.num_heads, -1, -1)
python

然后把得到的K,Q,VK,Q,V传入之前实现的点积注意力机制, 注意这里的K,Q,VK,Q,V都是有head这个维度的, 所以返回的结果也是分head的, 要再通过rearrange把他们拼回去

# 使用 scaled_dot_product_attention 类进行计算
# 注意:参考实现使用 (K, Q, V, mask) 的顺序
attn = scaled_dot_product_attention(K, Q, V, allow_mask)
attended_values = attn()  # [batch, num_heads, seq, d_k]

# Rearrange back to [batch, seq, num_heads * d_k]
attended_values = rearrange(attended_values, "b h s d -> b s (h d)", h=self.num_heads)

output = self.output_proj(attended_values)  # [batch, seq, d_model]

return output
python

整个实现如下:

完善测试接口:

运行测试uv run pytest -k test_multihead_self_attention, 结果如下:

(base) zyli@lab:~/Stanford_CS336/assignment1-basics$ uv run pytest -k test_multihead_self_attention
========================================================================================== test session starts ==========================================================================================
platform linux -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/zyli/Stanford_CS336/assignment1-basics
configfile: pyproject.toml
plugins: jaxtyping-0.3.2
collected 48 items / 46 deselected / 2 selected                                                                                                                                                         

tests/test_model.py::test_multihead_self_attention PASSED
tests/test_model.py::test_multihead_self_attention_with_rope PASSED

=================================================================================== 2 passed, 46 deselected in 0.43s ====================================================================================
bash

组装Transformer#

回忆一下架构图里面, 但我们有了Transformer Block Embedding Layer RMSNorm Linear Softmax之后, 我们就可以来拼装完整的Transformer

单个的Transformer Block#

照着图拼就行了, 不过要注意类似ResNet的结构

完善测试接口:

运行测试uv run pytest -k test_transformer_block, 结果如下:

警告可以忽略, 这应该是课程组在写注释的时候写了LaTeX风格的字符导致解析出问题

完整的Transformer#

和上面一样, 中间的若干个Transformer Block Layer用nn.ModuleList`来实现

完善测试接口:

运行测试uv run pytest -k test_transformer_lm, 结果如下:

(base) zyli@lab:~/Stanford_CS336/assignment1-basics$ uv run pytest -k test_transformer_lm
========================================================================================== test session starts ==========================================================================================
platform linux -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/zyli/Stanford_CS336/assignment1-basics
configfile: pyproject.toml
plugins: jaxtyping-0.3.2
collected 48 items / 46 deselected / 2 selected                                                                                                                                                         

tests/test_model.py::test_transformer_lm PASSED
tests/test_model.py::test_transformer_lm_truncated_input PASSED

=================================================================================== 2 passed, 46 deselected in 0.46s ====================================================================================
bash

计算次数估算#

模型训练的时候, 绝大部分运算都是矩阵乘法里面的数值乘法运算, 注意以下事实: 如果ARm×nA \in \mathbb{R}^{m \times n}, BRn×pB \in \mathbb{R}^{n \times p}, 那么A×BA \times B的计算次数是2×m×n×p2 \times m \times n \times p

参数数量计算和内存用量#

考虑第一个问题, 对于以下配置的GPT2XLGPT-2 XL, 有多少个可训练的参数? 如果每个参数都是单精度浮点数, 那么读取这个模型需要多少内存?

vocab_size = 50257
context_length = 1024
num_layers = 48
d_model = 1600
nun_heads = 25
d_ff = 6400
plaintext

对于Embedding Layer, 有vocabsize×dmodelvocabsize \times d_{model}个参数

50257 * 1600 = 80,411,200
plaintext

对于一个Transformer Block, 有两个RMSNorm, 每个RMSNormdmodeld_{model}个参数, 有一个multihead_self_attention, 其中有四个矩阵Q,K,V,outputQ,K,V,output, 参数量总共为4dmodeldmodel4*d_{model}*d_{model}, 对于positionwise_feedforward, 有两个矩阵, 总参数量为2dffdmodel2*d_{ff}*d_{model}

2 * 1600 + 4 * 1600 * 1600 + 2 * 1600 * 6400 = 3200 + 10240000 + 20480000 = 30,723,200
plaintext

除此之外, 输出前还有一个RMSNorm,一个Linear, 总参数量为dmodel+dmodelvocabsized_{model} + d_{model}*vocabsize

1600 + 1600 * 50257 = 80,412,800
plaintext

故总共有

80411200 + 30723200 * 48 + 80412800 = 1,635,537,600
plaintext

大概1.6B的参数量

1,635,537,600 * 4 /1024 /1024 /1024 = 6.1GB
plaintext

需要内存约6.1GB

FLOPS次数估计#

如果这个模型前向传播一次, 需要多少次FLOPS?

在一个transformer_block中, 一次multihead_self_attention的计算次数由以下部分组成:

Q投影: 2 * seq_len * d_model * d_model
K投影: 2 * seq_len * d_model * d_model
V投影: 2 * seq_len * d_model * d_model
输出投影: 2 * seq_len * d_model * d_model
QK^T: 2 * seq_len * seq_len * d_model
Attn V: 2 * seq_len * seq_len * d_model
共计8 * seq_len * d_model² + 4 * seq_len² * d_model = 20,971,520,000 + 67,108,864,000 = 88,080,384,000
plaintext

一次positionwise_feedforward的计算次数由以下部分组成:

w_1x: 2 * seq_len * d_model * d_ff
W_3x: 2 * seq_len * d_model * d_ff
W2 @ (SiLU(W1x) ⊙ W3x): 2 * seq_len * d_model * d_ff
共计6 * seq_len * d_model * d_ff = 62,914,560,000
plaintext

输出前有一个Output_Embedding的线性层, 计算次数为

2 * seq_len * d_model * vocab_size = 164,682,137,600
plaintext

代入配置, 得到

48 * (88,080,384,000 + 62,914,560,000) + 164,682,137,600 = 7,412,439,449,600
plaintext

大约7.4TeraFLOPs

Stanford CS336 Assignment 1(Part 2) - Transformer的实现
https://astro-pure.js.org/blog/cs336_assignment1_part2
Author Ziyu(Albert) Li 李子煜
Published at January 28, 2026
Comment seems to stuck. Try to refresh?✨