Ziyu Li's Homepage

Back

从一份TinyStories的故事集到Transformer模型, 就像是从原始人一夜当中走进了现代

CS336 Assignment 1#

这是我做过的第一个(也许也会是最后一个)几乎没什么Skeleton Code的Lab, 这门课的目标是让学习者彻底搞懂大模型的原理,并且从”Scratch”来从头构建大模型。

任务拆分开,大概有这么几点:

首先是工程架构部分:

  • bpe.py tokneizer.py: Byte-pair encoding(BPE) tokenizer, 在字节级别上实现一个分词器。
  • transformer.py: Transformer language model (LM), 实现Transformer的各模块并且组合成可实例化的Transformer类
  • transformer.py: The cross-entropy loss function and the AdamW optimizer, 实现AdamW优化器和损失函数
  • train_transformer.py: The training loop, with support for serializing and loading model and optimizer state, 实现训练循环

其次是跑训练和测试:

  • train_bpe_tinystories.py: Train a BPE tokenizer on the TinyStories dataset. 在Tinystories数据集上训练这个BPE分词器
  • tokenizer_experiments.py: Encode and Decode, 基于训练得到的词汇表和字节合并来对语料进行编码和解码
  • tokenizer_experiments.py: Run your trained tokenizer on the dataset to convert it into a sequence of integer IDs. 在数据集上应用分词器, 把文字数据集转化成整数序列
  • train_transformer.py: Train a Transformer LM on the TinyStories dataset. 利用分词器的结果训练Transformer模型

作业要求#

  • torch.nn.parameter
  • torch.nn
  • torch.optim.Optimizer(作为基类)

这个作业主要还是让学习者实现算法和工程架构,至于底层那些并行计算之类的,交给PyTorch的开发者去做吧

评测框架#

没有给出直接一键运行的测试,虽然测试的输入和输出ASSERT是实现好的,但是要自己去接这个测试接口,测试逻辑在./assignment1-basics/tests里面,测试接口在./assignment1-basics/tests/adapters.py里面

比如在transformer.py我们实现了Linear Layer,对应的测试接口在:

transformer.Linear是我们实现的Linear类,实现这个接口之后,运行

uv run pytest -k test_linear
bash

就可以运行预设的测试了

下载数据#

有四个数据集,两个Train两个Valid,分别对应TinyStories和OpenWebText_Result数据集

mkdir -p data
cd data

wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt
wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-valid.txt

wget https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_train.txt.gz
gunzip owt_train.txt.gz
wget https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_valid.txt.gz
gunzip owt_valid.txt.gz

cd ..
bash

如果有网络问题,可以从https://hf-mirror.com/下载,替换下载命令里面的地址就行了

环境配置#

本lab用uv做包管理器,我之前没用过完全不熟练,我是直接照着他文档里面的命令安装的,先安装uv

pip install uv
bash

当运行某个py文件的时候,用uv命令运行,如果有依赖缺失会自动下载

uv run <python_file_path>
bash

注意给uv也像pip一样配置一个国内源,不然安装一些大包可能要安装到明天早上去

# 推荐使用清华源
echo 'export UV_DEFAULT_INDEX="https://pypi.tuna.tsinghua.edu.cn/simple"'>> ~/.bashrc

# 或者用阿里源
# echo 'export UV_DEFAULT_INDEX="https://mirrors.aliyun.com/pypi/simple/"' >> ~/.bashrc

# 让配置立即生效
source ~/.bashrc
# 转载自https://zhuanlan.zhihu.com/p/1930714592423703026
bash

BPE算法#

文档中给出一个例子(stylized example),考虑以下的语料

low low low low low
lower lower widest widest widest
newest newest newest newest newest newest
plaintext

我们有一个词语集合Vocabulary(代码里一般称为Vocab),可以理解为是一个词汇表,这个词汇表一开始内容很少,然后在训练的时候慢慢增长

假如我们通过whitespace来对语料进行分割,我们就可以得到frequency table:

{low: 5, lower: 2, widest: 3, newest: 6}
plaintext

但我们要换一种更方便的数据结构来表示,比如用dict[tuple[bytes], int],那么这个表被表示成:

{(l,o,w):5,(l,o,w,e,r):2, ...}
plaintext

接下来我们统计这个frequency table里面byte(char)对的两两组合,这就是个计数的工程,得到的结果是:

{lo: 7, ow: 7, we: 8, er: 2, wi: 3, id: 3, de: 3, es: 9, st: 9, ne: 6, ew: 6}
plaintext

找到那些出现次数最多的组,这个例子里面是

{es:9, st:9}
plaintext

挑选字典序更大的那个,这里是st,把所有的’s’ ‘t’ 组合成st,所以这个frequency table就变成了

 {(l,o,w): 5, (l,o,w,e,r): 2, (w,i,d,e,st): 3, (n,e,w,e,st): 6}
plaintext

然后再把’st’添加到vocab里面去

再重复上面这个那个计数的步骤,这一次’e’ ‘st’是出现最频繁的,所以对他进行合并

如果我们重复合并到没有可以继续的了,我们的merges应该是:

['s t', 'e st', 'o w', 'l ow', 'w est', 'n e',
'ne west', 'w i', 'wi d', 'wid est', 'low e', 'lowe r']
plaintext

其中的每一项代表我们在那一次合并当中把两个什么东西(byte)合并了

当然实际上并不一定要合并到最后,我们可以指定合并的次数,比如例子中指定为6次,那么merges应该是:

['s t', 'e st', 'o w', 'l ow', 'w est', 'n e']
plaintext

这种情况下我们的vocab将会变成

[<|endoftext|>, [...256 BYTE CHARS], st, est, ow, low, west, ne]
plaintext

其中前面两个是init vocab的时候就已经有的,然后每一次merge的时候会往vocab里面加一项

如此一来的话,词语newest就会被分词成为’ne’ ‘west’, 换言之这个merge的过程就是把分词这件事情从细变粗的过程,最开始每一个词语都是一个一个字母分的,现在有一部分字母被聚合了

并行预分词 Parallelizing pre-tokenization#

chunk和对每个chunk的处理#

首先要对原有的文本进行预分词,以TinyStories数据集为例:

u don't have to be scared of the loud dog, I'll protect you". The mole felt so safe with the little girl. She was very kind and the mole soon came to trust her. He leaned against her and she kept him safe. The mole had found his best friend.
<|endoftext|>
Once upon a time, in a warm and sunny place, there was a big pit. A little boy named Tom liked to play near the pit. One day, Tom lost his red ball. He was very sad.
<|endoftext|>

They went back to the living room and cleaned up their toys. They decided to build something together. They made a big house with a garden and a fence. They put their cars and dolls inside. They were happy and proud of their work.
Mommy and Daddy came to see their house. They praised them and gave them a treat. It was a lemon cake. It was sour, but they liked it. They learned that sharing is caring, and that family is sweet.
<|endoftext|>


Lucy and the little girl played together happily. In the end, they both learnt an important lesson: be peaceful, kind, and understanding when faced with a conflict. And that is why Lucy and the little girl became great friends.
<|endoftext|>

At the end of the day, Tom and Max were tired. They had played all day and had lots of fun. They said goodbye to each other and went to their homes. Before going to sleep, they both did another easy stretch. Tom knew that tomorrow would be another happy morning.
<|endoftext|>
plaintext

忽略数据的内容,注意到故事和故事之间是用<|endoftext|>进行分割的,所以我们把两个<|endoftext|>之间的内容成为一个chunk,首先要把chunk提取出来并且得到每一个chunk的frequency_dict

课程组给了一段现成的代码,在./assignment1-basics/cs336_basics/pretokenization_example.py里面,这个函数的作用是读取原始的Corpus,然后把chunk的边界,大概约等于上面说的<|endoftext|>的位置返回

def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes,
) -> list[int]:
python

注意这个分割字符(在这个例子里面是<|endoftext|>),是不唯一的,当然有可能某个corpus的是用<|Hello|>来区分段落的,只要数据类型是bytes就行了

既然需要并行,我们首先要写处理单个chunk的函数:构造如下

要注意文本处理的架构(层次)是: 整个语料->单个chunk->chunk中的一个segment->segment当中的每个文本

如果split_special_token = [b"<|endoftext|>"],文本为:

Story 1 content here.<|endoftext|>Story 2 content here.<|endoftext|>
plaintext

那么经过分割后的text_segments

[
    "Story 1 content here.",  # 文本段
    "<|endoftext|>",          # 分隔符(被保留)
    "Story 2 content here.",  # 文本段
    "<|endoftext|>"           # 分隔符(被保留)
]
plaintext

然后再遍历这个text_segments,得到的结果格式类似于(不一定准确,只是形式上类似,具体还要看这个PAT对每个segment的分割规则):

{
    (b'S', b't', b'o', b'r', b'y'): 2,  # "Story" 出现2次
    (b' ',): 6,                          # 单个空格出现6次
    (b'1',): 1,                          # "1" 出现1次  
    (b'2',): 1,                          # "2" 出现1次
    (b'c', b'o', b'n', b't', b'e', b'n', b't'): 2,  # "content" 出现2次
    (b'h', b'e', b'r', b'e'): 2,         # "here" 出现2次
    (b'.',): 2,                          # "." 出现2次
}
plaintext

这就完成了对每个chunk的内部切割,接下来需要一个并行代码把每个chunk映射到这个函数上

同时并行处理多个chunk#

我们应当用多进程同时处理多个chunk,每个chunk会独立的返回这个chunk内容的frequency_dict,然后把他们concat到一起去就得到了整个语料的频率字典

注意这个concat其实不是真的concat, 因为不同的chunk里面有可能出现相同的词语,比如在chunk1和chunk2的frequency_dict里面也许都有

{(b'S', b't', b'o', b'r', b'y'): 2}
plaintext

那显然concat之后得到的应该是:

{(b'S', b't', b'o', b'r', b'y'): 4}
plaintext

所以需要对并行传回来的结果进行遍历,然后合并同类项

# bpe.py
def merge_frequencies(frequency_dict): # Calculate the frequencies from each chunks and sum them together

    total_frequencies = defaultdict(int) # 初始化结果

    for every_frequency_dict in frequency_dict: 
    # 遍历每个chunk的frequency_dict
        for pretoken_bytes, count in every_frequency_dict.items():

            total_frequencies[pretoken_bytes] += count
            # 保证不同chunk之间的同样的pretoken_bytes的计数不重不漏

    return total_frequencies
python

初始化Vocab和Merges#

从前面那个例子可以看到, 训练过程本质上就是更新Vocab和Merges这两个结果的过程, 所以先对他们进行初始化

# bpe.py
def initialize_vocab_and_merges(special_tokens):
    vocab = {}

    for i in range(256):
        vocab[i] = bytes([i]) # 初始的一些默认bytes

    for special_token in special_tokens:
        vocab[len(vocab)] = special_token.encode("utf-8") # 我们自己附加的special_tokens

    return vocab,[]
python

随着训练进行, Vocab和Merges会越来越大

初始化pair计数#

我们现在只得到了每个词语的计数, 但是如例子里面, 我们要对每个词语遍历拆出pair, 去对pair进行计数, 比如说输入是

{(b'S', b't', b'o', b'r', b'y'): 4}
plaintext

那输出大概是

{'St':4, 'to':4, 'or':4, 'ry':4}
plaintext

注意这里已经不分chunk了, 所以每次得到的结果就是全局的更新量了

找到最频繁出现的pair#

假如有类似于

{'St':4, 'to':4, 'or':4, 'ry':4}
plaintext

pair的频率字典, 那么需要一个函数获取里面出现次数最多的那个pair

Merge操作#

接下来这个是我觉得BPE里面最难的实现, 首先回顾一下我们现在有了什么:

frequency_dict: dict(词语:频率)

pair_frequencies: dict(字符对:频率)

pair_to_tokens: dict(字符,set(词语)), 这个来自get_initial_pair_frequencies函数, 标注哪些词语含有这个字符

best_pair: tuple(字符1, 字符2)

还有一个比较麻烦的地方在于这一步有性能要求, 如果太暴力的话可能会导致测试过不去, 一个简单的想法是遍历frequency_dict, 然后找到所有含有best_pair的词语对他们进行修改, 但是这样是无法通过测试的

想一个取巧的办法, 既然我们有pair_to_tokens, 我们可以先找出含有best_pair的词语然后对他们进行遍历, 这就不需要遍历所有的词语了

还要修改pair_frequenciespair_to_tokens, 因为老的pair已经没有了并且由于产生了新的词语,会产生新的pair, 接着上面的函数代码:

训练主函数#

现在我们来写训练循环, 架构非常简单, 分为以下几个步骤

数据准备:

  • 加载训练语料库文件
  • 按分隔符(如 <|endoftext|>)将文件分割成多个数据块

预处理:

  • 对每个数据块进行预分词,统计所有预分词的频率
  • 初始化词频表 frequency_dict

字节对统计:

  • 计算所有相邻字节对的出现频率
  • 建立字节对频率表和映射关系

while 继续训练条件: 步骤1: 找到最优合并对 best_pair = 找到频率最高的字节对(pair_frequencies)

步骤2: 执行合并 merge_pair(frequency_dict, pair_frequencies, pair_to_tokens, best_pair)

步骤3: 更新模型状态 更新合并记录(merges列表) 更新词汇表(vocab集合)

BPE训练流程
├── 初始化阶段
│   ├── 加载并分块文件
│   ├── 初始化词频表  
│   └── 初始化字节对频率
├── 训练迭代 (循环开始)
│   ├── 选择最佳字节对
│   ├── 合并操作
│   │   ├── 更新词频统计
│   │   ├── 重建受影响的预分词
│   │   └── 更新字节对映射
│   ├── 记录合并操作
│   └── 扩展词汇表
└── 训练结束
plaintext

评测接口#

非常容易, 只要修改adapters.py的框架代码就行了:

# adapters.py

from cs336_basics import bpe

def run_train_bpe(
    input_path: str | os.PathLike,
    vocab_size: int,
    special_tokens: list[str],
    **kwargs,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:

    input_path_str = str(input_path)
    vocab,merges = bpe.train_bpe(input_path=input_path_str,vocab_size=vocab_size,special_tokens=special_tokens,Debug=False)
    
    return vocab, merges
python

这里不需要手动指定路径, 测试自己存了一些demo文件去进行训练并和ref进行比对来判断正误, 启动测试的代码是:

root@autodl-container-8d994fbd73-e5baa69e:~/autodl-tmp/Stanford_CS336/assignment1-basics# uv run pytest tests/test_train_bpe.py
bash

稍微注意一下启动时候的路径, 测试通过如下:

(base) root@autodl-container-8d994fbd73-e5baa69e:~/autodl-tmp/Stanford_CS336/assignment1-basics# uv run pytest tests/test_train_bpe.py 
============================================================================ test session starts =============================================================================
platform linux -- Python 3.12.3, pytest-8.4.1, pluggy-1.6.0
rootdir: /root/autodl-tmp/Stanford_CS336/assignment1-basics
configfile: pyproject.toml
plugins: jaxtyping-0.3.2
collected 3 items                                                                                                                                                                                                                                   

tests/test_train_bpe.py::test_train_bpe_speed PASSED
tests/test_train_bpe.py::test_train_bpe PASSED
tests/test_train_bpe.py::test_train_bpe_special_tokens PASSED

================================================================================================================= 3 passed in 3.04s =================================================================================================================
bash

在TinyStories和OpenWebText数据集上训练BPE分词器#

这里只需要自己写一个脚本调用之前实现的训练过程就行了, 唯一要注意的就是MergesVocab持久化时候的格式问题

Vocab的输出格式是:

vocab = {0: b'hello', 1: b'world', ...}
plaintext

需要转化为

vocab_unicode = {'hello': 0, 'world': 1, ...}
plaintext

实际上就是把原始格式转化成人类可读的Unicode格式就可以了

训练接口函数

另一个数据集只需要修改一下路径就好了, 不再赘述

最后得到的merges.txtvocab.json类似于

Ġ t
h e
Ġ a
Ġ s
Ġ w
n d
Ġt he
e d
Ġ b
Ġt o
plaintext
{
  "Ā": 0,
  "ā": 1,
  "Ă": 2,
  "ă": 3,
  "Ą": 4,
  "ą": 5,
  "Ć": 6,
  "ć": 7,
  "Ĉ": 8,
  "ĉ": 9,
  "Ċ": 10,
}
plaintext

分词器: 编码与解码#

VocabMerges被训练完成后, 我们可以用他们然后对给定的语料进行解码/编码

编码#

本质上就是把语料先预分词, 然后应用已有的Merges, 最后再编码成整数序列, 举例如下:

Vocab = {0: b' ', 1: b'a', 2:b'c', 3: b'e', 4: b'h', 5: b't', 6: b'th', 7: b' c', 8: b' a', 9: b'the', 10: b'at'}
Merges = [(b't', b'h'), (b' ', b'c'), (b' ', 'a'), (b'th', b'e'), (b' a', b't')]
plaintext

对’the cat ate’进行编码的步骤如下:

首先预分词用空格来进行token划分,得到

['the', ' cat', ' ate']
plaintext

对于第一个token:‘the’, 他的表示方式是:

[b't', b'h', b'e']
plaintext

应用两次Merges中的合并就变成:

[b'the']
plaintext

然后转化成整数序列就是[9], 其余两个token也一样, 最后得到的整数序列是:

[9, 7, 1, 5, 10, 3]
plaintext

解码#

解码过程相对简单, 对于输入的整数序列, 只需要一个个的查找在Vocab中对应的词并且concat起来就好了

Tokenizer类的实现(遵循实验文档的接口架构)#

# tokenizer.py
class tokenizer():

    def __init__(self,vocab:dict[int,bytes],merges:list[tuple[bytes,bytes]],special_tokens:list[str]=None):
        self.vocab = vocab
        self.merges = merges
        self.special_tokens = special_tokens

        self.vocab_reverse = {token_bytes:token_id for token_id,token_bytes in self.vocab.items()} # 反向映射方便查找
        self.PAT = re.compile(r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
python

先实现类方法apply_bpe_merges, 传入一个token, 查找Merges看是否有可以应用在这个token上的合并操作, 若有就应用后返回

实现一个一般的编码函数, 单纯的通过PAT去匹配, 然后对各PAT分割出的部分应用Merge, 最后再通过self.vocab_reverse转化成整数序列完成编码

考虑到有可能有self.special_tokens, 所以要在上面这个函数上实现一个更一般的encode函数

文档要求我们以流式处理和惰性求值的方式来encode, 非常简单, 实现如下:

# tokenizer.py
    def encode_iterable(self,iterable:Iterable[str])->Iterable[int]:
        for text in iterable:
            token_ids = self.encode(text)

            for token_id in token_ids:
                yield token_id
python

解码部分更容易, 就是一个查表然后join的函数而已

# tokenizer.py
    def decode(self,ids:list[int])->str:
        byte_sequences = []
        for token_id in ids:
            if token_id in self.vocab.keys():
                byte_sequences.append(self.vocab[token_id])


        combined_bytes = b''.join(byte_sequences)

        text = combined_bytes.decode('utf-8',errors='replace')

        return text
python

实验文档还要求我们实现一个类的构造器, 其实很容易, 只要读取那些被持久化了的MergesVocab,然后把他转化成持久化前的数据类型就行了, 相当于是前面训练中那个save_vocab_merge的逆向

评测接口#

和之前一样, 修改adapters.py即可

# adapters.py
from cs336_basics import tokenizer

def get_tokenizer(
    vocab: dict[int, bytes],
    merges: list[tuple[bytes, bytes]],
    special_tokens: list[str] | None = None,
) -> Any:

    return tokenizer.tokenizer(vocab=vocab,merges=merges,special_tokens=special_tokens)
python

启动测试:

uv run pytest tests/test_tokenizer.py
bash

测试结果:

这个测试耗时在现有的复杂度下非常长, 而且也有一个内存限制测试是预计不通过的(测试文件里面写了这个测试本来就预计不通过), 在实验文档里有写可以通过Cpp或Rust实现来显著提升速度

To test your BPE training function against our provided tests, you will first need to implement the
test adapter at [adapters.run_train_bpe]. Then, run uv run pytest tests/test_train_bpe.py.
Your implementation should be able to pass all tests. Optionally (this could be a large time-investment),
you can implement the key parts of your training method using some systems language, for instance
C++ (consider cppyy for this) or Rust (using PyO3). If you do this, be aware of which operations
require copying vs reading directly from Python memory, and make sure to leave build instructions, or
make sure it builds using only pyproject.toml. Also note that the GPT-2 regex is not well-supported
in most regex engines and will be too slow in most that do. We have verified that Oniguruma is
reasonably fast and supports negative lookahead, but the regex package in Python is, if anything,
even faster.
plaintext

编码数据集#

只需要利用tokenizer类以及已有的Merges Vocab即可

只需要注意最后以np.uint16格式保存即可

Stanford CS336 Assignment 1(Part 1) - 手搓BPE
https://astro-pure.js.org/blog/cs336_assignment1_part1
Author Ziyu(Albert) Li 李子煜
Published at January 23, 2026
Comment seems to stuck. Try to refresh?✨