From a TinyStories story collection to a Transformer model: it feels like stepping from the Stone Age into the modern world overnight.
CS336 Assignment 1#
This is the first lab I have done (and maybe the last) that ships with almost no skeleton code. The goal of this course is for learners to thoroughly understand how large language models work and to build one from "scratch".
Breaking the assignment down, there are roughly the following parts.
First, the engineering / architecture part:
bpe.py / tokenizer.py: Byte-pair encoding (BPE) tokenizer, implemented at the byte level.
transformer.py: Transformer language model (LM); implement each Transformer module and compose them into an instantiable Transformer class.
transformer.py: The cross-entropy loss function and the AdamW optimizer.
train_transformer.py: The training loop, with support for serializing and loading model and optimizer state.
Next, running training and experiments:
train_bpe_tinystories.py: Train a BPE tokenizer on the TinyStories dataset.
tokenizer_experiments.py: Encode and decode with the trained vocabulary and merge list.
tokenizer_experiments.py: Run your trained tokenizer on the dataset to convert it into a sequence of integer IDs.
train_transformer.py: Train a Transformer LM on the TinyStories dataset, using the tokenizer's output.
Assignment requirements#
The only PyTorch building blocks you are allowed to use are torch.nn.Parameter, torch.nn (containers), and torch.optim.Optimizer (as a base class).
The assignment is mainly about having learners implement the algorithms and the engineering scaffolding themselves; the low-level parallel computation can be left to the PyTorch developers.
Test framework#
There is no one-click test entry point handed to you. The test inputs and the output assertions are already implemented, but you have to wire your code up to the test interface yourself. The test logic lives in ./assignment1-basics/tests, and the adapters are in ./assignment1-basics/tests/adapters.py.
For example, once we implement the Linear layer in transformer.py, the corresponding test adapter is:
# Test For Implement of Linear Layer
from cs336_basics import transformer
def run_linear(
d_in: int,
d_out: int,
weights: Float[Tensor, " d_out d_in"],
in_features: Float[Tensor, " ... d_in"],
) -> Float[Tensor, " ... d_out"]:
"""
Given the weights of a Linear layer, compute the transformation of a batched input.
Args:
in_dim (int): The size of the input dimension
out_dim (int): The size of the output dimension
weights (Float[Tensor, "d_out d_in"]): The linear weights to use
        in_features (Float[Tensor, "... d_in"]): The input tensor to apply the function to
Returns:
Float[Tensor, "... d_out"]: The transformed output of your linear module.
"""
# raise NotImplementedError
linear_module = transformer.Linear(in_features=d_in,out_features=d_out)
linear_module.weight.data = weights
return linear_module(in_features)
Here transformer.Linear is the Linear class we implemented. With this adapter in place, run
uv run pytest -k test_linear
to execute the predefined tests.
Downloading the data#
There are four data files, two train and two valid, for the TinyStories and OpenWebText datasets:
mkdir -p data
cd data
wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-train.txt
wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStoriesV2-GPT4-valid.txt
wget https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_train.txt.gz
gunzip owt_train.txt.gz
wget https://huggingface.co/datasets/stanford-cs336/owt-sample/resolve/main/owt_valid.txt.gz
gunzip owt_valid.txt.gz
cd ..
If you have network problems, you can download from https://hf-mirror.com/ instead; just swap the host in the download URLs above.
Environment setup#
This lab uses uv as the package manager. I had never used it before, so I simply followed the commands in its documentation. First install uv:
pip install uv
When running a Python file, launch it through uv; missing dependencies are downloaded automatically:
uv run <python_file_path>
Note: configure a domestic (China) mirror for uv, just as you would for pip, otherwise installing large packages can take until tomorrow morning.
# The Tsinghua mirror is recommended
echo 'export UV_DEFAULT_INDEX="https://pypi.tuna.tsinghua.edu.cn/simple"' >> ~/.bashrc
# Or use the Aliyun mirror
# echo 'export UV_DEFAULT_INDEX="https://mirrors.aliyun.com/pypi/simple/"' >> ~/.bashrc
# Make the config take effect immediately
source ~/.bashrc
# Taken from https://zhuanlan.zhihu.com/p/1930714592423703026
BPE algorithm#
The handout gives a stylized example. Consider the following corpus:
low low low low low
lower lower widest widest widest
newest newest newest newest newest newest
We maintain a vocabulary (usually called Vocab in the code). It starts out with very little in it and grows during training.
If we split the corpus on whitespace, we get the frequency table:
{low: 5, lower: 2, widest: 3, newest: 6}
We will represent it with a more convenient data structure, e.g. dict[tuple[bytes], int], so the table becomes:
{(l,o,w): 5, (l,o,w,e,r): 2, ...}
Next we count all adjacent byte (char) pairs in this frequency table; this is pure bookkeeping, and the result is:
{lo: 7, ow: 7, we: 8, er: 2, wi: 3, id: 3, de: 3, es: 9, st: 9, ne: 6, ew: 6}
Find the most frequent pairs; in this example they are
{es: 9, st: 9}
Pick the lexicographically greater one, here st, and merge every adjacent 's' 't' into st, so the frequency table becomes:
{(l,o,w): 5, (l,o,w,e,r): 2, (w,i,d,e,st): 3, (n,e,w,e,st): 6}
Then add 'st' to the vocab.
Repeat the counting step; this time 'e' 'st' is the most frequent pair, so merge it.
If we keep merging until nothing is left to merge, the merges are:
['s t', 'e st', 'o w', 'l ow', 'w est', 'n e',
'ne west', 'w i', 'wi d', 'wid est', 'low e', 'lowe r']
Each entry records which two units (bytes) were merged at that step.
In practice we do not have to merge all the way to the end; we can cap the number of merges. If the example is capped at 6 merges, the merges are:
['s t', 'e st', 'o w', 'l ow', 'w est', 'n e']
In that case our vocab becomes
[<|endoftext|>, [...256 BYTE CHARS], st, est, ow, low, west, ne]
The first two kinds of entries are already there when the vocab is initialized; every merge then appends one new entry to the vocab.
With this vocab, the word newest is tokenized as 'ne' 'west'. In other words, merging makes tokenization progressively coarser: initially every word is split into single characters, and merges gradually glue some of them back together.
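To make the bookkeeping concrete, here is a minimal toy sketch (my own illustration, not part of the assignment's skeleton) that reproduces the first merge of this example:
from collections import defaultdict

# Toy frequency table: word (tuple of symbols) -> count
freq = {("l","o","w"): 5, ("l","o","w","e","r"): 2,
        ("w","i","d","e","s","t"): 3, ("n","e","w","e","s","t"): 6}

# Count adjacent pairs
pair_freq = defaultdict(int)
for word, count in freq.items():
    for a, b in zip(word, word[1:]):
        pair_freq[(a, b)] += count

# Most frequent pair, ties broken by the lexicographically greater pair
best = max(pair_freq, key=lambda p: (pair_freq[p], p))
print(best, pair_freq[best])  # ('s', 't') 9

# Apply the merge inside every word
def merge_word(word, pair):
    out, i = [], 0
    while i < len(word):
        if i + 1 < len(word) and (word[i], word[i + 1]) == pair:
            out.append(word[i] + word[i + 1])
            i += 2
        else:
            out.append(word[i])
            i += 1
    return tuple(out)

freq = {merge_word(w, best): c for w, c in freq.items()}
print(freq)  # ..., ('w','i','d','e','st'): 3, ('n','e','w','e','st'): 6
The real implementation below works on tuples of bytes instead of characters, but the counting and merging logic is the same.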
Parallelizing pre-tokenization#
Chunks and per-chunk processing#
First we pre-tokenize the raw text. Take the TinyStories dataset as an example:
u don't have to be scared of the loud dog, I'll protect you". The mole felt so safe with the little girl. She was very kind and the mole soon came to trust her. He leaned against her and she kept him safe. The mole had found his best friend.
<|endoftext|>
Once upon a time, in a warm and sunny place, there was a big pit. A little boy named Tom liked to play near the pit. One day, Tom lost his red ball. He was very sad.
<|endoftext|>
They went back to the living room and cleaned up their toys. They decided to build something together. They made a big house with a garden and a fence. They put their cars and dolls inside. They were happy and proud of their work.
Mommy and Daddy came to see their house. They praised them and gave them a treat. It was a lemon cake. It was sour, but they liked it. They learned that sharing is caring, and that family is sweet.
<|endoftext|>
Lucy and the little girl played together happily. In the end, they both learnt an important lesson: be peaceful, kind, and understanding when faced with a conflict. And that is why Lucy and the little girl became great friends.
<|endoftext|>
At the end of the day, Tom and Max were tired. They had played all day and had lots of fun. They said goodbye to each other and went to their homes. Before going to sleep, they both did another easy stretch. Tom knew that tomorrow would be another happy morning.
<|endoftext|>
Ignoring the actual content, notice that stories are separated by <|endoftext|>. We treat the content between two <|endoftext|> markers as a chunk; the first step is to extract the chunks and compute a frequency_dict for each one.
The course staff provide a ready-made helper in ./assignment1-basics/cs336_basics/pretokenization_example.py. It reads the raw corpus and returns the chunk boundaries, which roughly correspond to the positions of the <|endoftext|> markers mentioned above:
def find_chunk_boundaries(
file: BinaryIO,
desired_num_chunks: int,
split_special_token: bytes,
) -> list[int]:
Note that the split token (<|endoftext|> in this example) is not fixed; some corpus might well use <|Hello|> to separate documents. Anything works as long as it is passed in as bytes.
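Conceptually, the helper works roughly like the simplified sketch below (my own rewrite for illustration, not the course's exact code; it ignores the corner case of a token spanning two reads): start from evenly spaced guesses, then advance each guess to the next occurrence of the split token.
import os
from typing import BinaryIO

def find_chunk_boundaries_sketch(file: BinaryIO, desired_num_chunks: int,
                                 split_special_token: bytes) -> list[int]:
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    chunk_size = file_size // desired_num_chunks
    # Evenly spaced initial guesses: 0, chunk_size, 2*chunk_size, ..., file_size
    boundaries = [i * chunk_size for i in range(desired_num_chunks)] + [file_size]
    for bi in range(1, len(boundaries) - 1):
        pos = boundaries[bi]
        file.seek(pos)
        while True:
            buf = file.read(4096)
            if not buf:                      # reached EOF: clamp to the end of the file
                boundaries[bi] = file_size
                break
            found = buf.find(split_special_token)
            if found != -1:                  # snap the guess to the next special token
                boundaries[bi] = pos + found
                break
            pos += len(buf)
    return sorted(set(boundaries))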
Since we want parallelism, we first write the function that processes a single chunk:
def pretokenize_chunk(args):  # Process a single chunk
    # bpe.py
    start, end, input_path, split_special_token = args
    # start: byte offset where the chunk starts
    # end: byte offset where the chunk ends
    # input_path: path to the corpus file
    # split_special_token: the separators used inside the chunk
    with open(input_path, "rb") as f:  # open the file
        f.seek(start)  # move the read pointer to the byte offset given by start
        chunk_bytes = f.read(end - start).decode("utf-8", errors="ignore")
        # read the data between start and end
    split_pattern = "|".join(re.escape(token) for token in split_special_token)
    # this pattern matches any of the separators in split_special_token
    text_segments = re.split(f"({split_pattern})", chunk_bytes)
    # split the text; the result alternates like [text 1, separator 1, text 2, separator 2, ...]
    PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    # GPT-2 style pre-tokenization regex
    frequency_dict = defaultdict(int)  # frequency dict; defaultdict gives a default of 0 and avoids missing-key errors
    for segment in text_segments:  # iterate over the segments
        if segment not in split_special_token:  # skip separators, only process actual text
            for match in re.finditer(PAT, segment):  # find all matches inside this text segment
                pretoken = match.group()
                pretoken_bytes = pretoken.encode("utf-8")  # encode as UTF-8
                pretoken_bytes_tuple = tuple(bytes([b]) for b in pretoken_bytes)
                # turn it into a tuple of single bytes, e.g. (b'h', b'e', b'l', b'l', b'o') rather than b"hello"
                frequency_dict[pretoken_bytes_tuple] += 1
                # update the frequency
    return dict(frequency_dict)
Note that the processing hierarchy is: whole corpus -> one chunk -> one segment within the chunk -> each pre-token within the segment.
If split_special_token = ["<|endoftext|>"] and the text is:
Story 1 content here.<|endoftext|>Story 2 content here.<|endoftext|>
then after splitting, text_segments is:
[
    "Story 1 content here.",  # text segment
    "<|endoftext|>",          # separator (kept)
    "Story 2 content here.",  # text segment
    "<|endoftext|>"           # separator (kept)
]
Iterating over text_segments then produces something like the following (not exact, just the general shape; the details depend on how PAT splits each segment):
{
    (b'S', b't', b'o', b'r', b'y'): 2,                # "Story" appears twice
    (b' ',): 6,                                       # a single space appears 6 times
    (b'1',): 1,                                       # "1" appears once
    (b'2',): 1,                                       # "2" appears once
    (b'c', b'o', b'n', b't', b'e', b'n', b't'): 2,    # "content" appears twice
    (b'h', b'e', b'r', b'e'): 2,                      # "here" appears twice
    (b'.',): 2,                                       # "." appears twice
}
That completes the splitting inside a single chunk; next we need parallel code that maps every chunk onto this function.
Processing multiple chunks in parallel#
We use multiple processes to handle several chunks at once. Each chunk independently returns its own frequency_dict, and we then merge them into the frequency dict of the whole corpus.
def load_and_chunk_file(
    input_path: str,
    desired_num_chunks: int,
    split_special_token: list[str],
    Debug=False):
    # bpe.py
    with open(input_path, "rb") as f:  # open the file
        num_processes = 16  # number of parallel worker processes
        split_token_bytes = split_special_token[0].encode("utf-8")  # take the separator and encode it to bytes
        boundaries = find_chunk_boundaries(file=f, desired_num_chunks=desired_num_chunks, split_special_token=split_token_bytes)
        # the provided find_chunk_boundaries returns the separator positions in the raw corpus,
        # i.e. the positions of <|endoftext|> in this project
    with multiprocessing.Pool(processes=num_processes) as pool:  # process the chunks in parallel
        chunk_args = [(start, end, input_path, split_special_token) for start, end in zip(boundaries[:-1], boundaries[1:])]
        # one argument tuple per chunk; boundaries is a list[int], and zipping it against itself shifted by one
        # yields each (start, end) pair; input_path and split_special_token never change
        results = pool.map(pretokenize_chunk, chunk_args)
        # dispatch the work and collect the per-chunk results
    total_frequencies = merge_frequencies(results)
    # "concat" the results together
    return total_frequencies
Note that this "concat" is not a real concatenation: different chunks may contain the same word. For example, the frequency_dicts of chunk 1 and chunk 2 might both contain
{(b'S', b't', b'o', b'r', b'y'): 2}
Obviously the combined result should then be:
{(b'S', b't', b'o', b'r', b'y'): 4}
So we iterate over the per-chunk results and sum the counts of identical keys.
# bpe.py
def merge_frequencies(frequency_dict):  # sum the per-chunk frequency dicts into one
    total_frequencies = defaultdict(int)  # initialize the result
    for every_frequency_dict in frequency_dict:
        # iterate over each chunk's frequency_dict
        for pretoken_bytes, count in every_frequency_dict.items():
            total_frequencies[pretoken_bytes] += count
            # counts of the same pretoken_bytes across chunks are summed exactly once
    return total_frequencies
Initializing Vocab and Merges#
As the earlier example shows, training is essentially the process of updating two artifacts, Vocab and Merges, so we initialize them first.
# bpe.py
def initialize_vocab_and_merges(special_tokens):
    vocab = {}
    for i in range(256):
        vocab[i] = bytes([i])  # the 256 default single-byte tokens
    for special_token in special_tokens:
        vocab[len(vocab)] = special_token.encode("utf-8")  # the special tokens we add ourselves
    return vocab, []
As training proceeds, Vocab and Merges keep growing.
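A quick sanity check of the initialization (my own illustrative snippet, not part of the tests):
vocab, merges = initialize_vocab_and_merges(["<|endoftext|>"])
assert len(vocab) == 257          # 256 single bytes + 1 special token
assert vocab[256] == b"<|endoftext|>"
assert merges == []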
Initializing the pair counts#
So far we only have per-word counts, but as in the example we need to walk over each word, extract its adjacent pairs, and count those. For instance, if the input is
{(b'S', b't', b'o', b'r', b'y'): 4}
the output is roughly
{'St': 4, 'to': 4, 'or': 4, 'ry': 4}
Note that from here on we no longer work per chunk, so every result is a global update.
# bpe.py
def get_initial_pair_frequencies(frequency_dict, Debug=False):
    pair_freq = defaultdict(int)
    pair_to_tokens = defaultdict(set)
    # initialize the two tables
    for pretoken_bytes, count in frequency_dict.items():
        # walk the word-frequency dict, getting each "word" and its count
        for i in range(len(pretoken_bytes) - 1):  # all adjacent positions inside this word
            pair = ((pretoken_bytes[i],), (pretoken_bytes[i+1],))  # the adjacent pair
            pair_freq[pair] = pair_freq.get(pair, 0) + count  # update this pair's count
            pair_to_tokens[pair].add(pretoken_bytes)  # remember which words this pair occurs in
    return pair_freq, pair_to_tokens
Finding the most frequent pair#
Given a pair-frequency dict such as
{'St': 4, 'to': 4, 'or': 4, 'ry': 4}
we need a function that returns the pair with the highest count.
# bpe.py
def find_best_pair(pair_frequencies):
    if not pair_frequencies:
        return None
    best_pair = tuple()
    max_freq = -1
    for pair, freq in pair_frequencies.items():  # a single pass, tracking the pair with the largest freq
        if freq > max_freq:
            max_freq = freq
            best_pair = pair
        elif freq == max_freq:
            best_pair = max(best_pair, pair)  # break ties by the lexicographically greater pair
    return best_pair
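The tie-breaking rule matters for matching the reference outputs. A small illustration (hypothetical counts, using the same ((bytes,), (bytes,)) pair representation as above):
pair_frequencies = {
    ((b'e',), (b's',)): 9,
    ((b's',), (b't',)): 9,
    ((b'o',), (b'w',)): 7,
}
# Both ('e','s') and ('s','t') occur 9 times; the lexicographically
# greater pair wins, so ('s','t') is selected, as in the handout example.
assert find_best_pair(pair_frequencies) == ((b's',), (b't',))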
The merge operation#
This is, I think, the trickiest part of BPE to implement. First, recall what we have at this point:
frequency_dict: dict(word: frequency)
pair_frequencies: dict(pair: frequency)
pair_to_tokens: dict(pair: set(words)); this comes from get_initial_pair_frequencies and records which words contain a given pair
best_pair: tuple(byte 1, byte 2)
Another complication is that this step has a performance requirement; a naive implementation may fail the tests. The obvious idea is to scan frequency_dict, find every word that contains best_pair, and rewrite those words, but that does not pass the tests.
The trick is that, since we already have pair_to_tokens, we can look up exactly the words that contain best_pair and iterate only over those, instead of scanning every word.
# bpe.py
def merge_pair(frequency_dict, pair_frequencies, pair_to_tokens, best_pair, Debug=False):
    byte1_tuple, byte2_tuple = best_pair  # unpack the two sides, e.g. (b'S',) and (b't',)
    merged_byte = byte1_tuple[0] + byte2_tuple[0]  # the merged bytes object
    affected_tokens = pair_to_tokens.get(best_pair, set()).copy()
    # the words that contain this pair
    if best_pair in pair_to_tokens:
        # after the merge no word contains this pair any more, which is equivalent to the pair
        # belonging to no word, so remove it from pair_to_tokens
        del pair_to_tokens[best_pair]
    for pretoken in affected_tokens:  # iterate over the words that contain this pair
        count = frequency_dict[pretoken]  # the old word's count; note that counts move when we merge.
        # E.g. if the old word is (b'S', b't', b'o', b'r', b'y') and St gets merged, the old word no longer
        # exists and a new word (b'St', b'o', b'r', b'y') appears; this "new" word may already exist,
        # and its count must be increased by the old word's count.
        frequency_dict[pretoken] -= count  # decrease the old word's count
        if frequency_dict[pretoken] <= 0:
            del frequency_dict[pretoken]  # remove the old word
        new_pretoken_list = []  # build the new word
        i = 0
        while i < len(pretoken):  # something like a sliding window
            if (i < len(pretoken) - 1 and
                (pretoken[i],) == byte1_tuple and
                (pretoken[i+1],) == byte2_tuple):
                new_pretoken_list.append(merged_byte)  # if best_pair occurs here, append the merged bytes
                i += 2
            else:
                new_pretoken_list.append(pretoken[i])  # otherwise append the single byte
                i += 1
        new_pretoken_tuple = tuple(new_pretoken_list)  # the new word
        frequency_dict[new_pretoken_tuple] += count  # add the old word's count
We also have to update pair_frequencies and pair_to_tokens, because the old pairs are gone and the new word introduces new pairs. Continuing the function:
# bpe.py
        if new_pretoken_tuple != pretoken:  # if the new word differs from the old one (a merge really happened)
            for i in range(len(pretoken) - 1):
                old_pair = ((pretoken[i],), (pretoken[i+1],))
                pair_frequencies[old_pair] -= count
                if pair_frequencies[old_pair] <= 0:  # update pair_frequencies
                    del pair_frequencies[old_pair]
                if old_pair in pair_to_tokens:
                    pair_to_tokens[old_pair].discard(pretoken)
                    if not pair_to_tokens[old_pair]:  # update pair_to_tokens
                        del pair_to_tokens[old_pair]
            for i in range(len(new_pretoken_tuple) - 1):  # the new word introduces new pairs, so update both tables
                new_pair = ((new_pretoken_tuple[i],), (new_pretoken_tuple[i+1],))
                pair_frequencies[new_pair] += count
                pair_to_tokens[new_pair].add(new_pretoken_tuple)
    return frequency_dict, pair_frequencies, pair_to_tokens
The training main function#
Now for the training loop. The structure is simple and consists of the following steps.
Data preparation:
- Load the training corpus file
- Split the file into chunks on the separator (e.g. <|endoftext|>)
Pre-processing:
- Pre-tokenize each chunk and count the frequency of every pre-token
- Initialize the word-frequency table frequency_dict
Byte-pair statistics:
- Count the occurrences of all adjacent byte pairs
- Build the pair-frequency table and the pair-to-word mapping
while the training condition holds:
    Step 1: find the best merge pair: best_pair = most frequent byte pair in pair_frequencies
    Step 2: perform the merge: merge_pair(frequency_dict, pair_frequencies, pair_to_tokens, best_pair)
    Step 3: update the model state: append to the merges list and extend the vocab
BPE training flow
├── Initialization
│   ├── Load and chunk the file
│   ├── Initialize the word-frequency table
│   └── Initialize the byte-pair frequencies
├── Training iterations (loop)
│   ├── Pick the best byte pair
│   ├── Merge
│   │   ├── Update the word-frequency counts
│   │   ├── Rebuild the affected pre-tokens
│   │   └── Update the pair-to-word mapping
│   ├── Record the merge
│   └── Extend the vocabulary
└── Done
# bpe.py
def train_bpe(input_path: str, vocab_size: int, special_tokens: list[str], Debug=False):
    frequency_dict = load_and_chunk_file(input_path, desired_num_chunks=4, split_special_token=special_tokens)
    # build the raw word-frequency table
    vocab, merges = initialize_vocab_and_merges(special_tokens)
    # the initial vocab and an empty merges list
    pair_frequencies, pair_to_tokens = get_initial_pair_frequencies(frequency_dict)
    # the initial pair-frequency table
    while len(vocab) < vocab_size:
        # loop condition: the target vocab_size has not been reached yet
        best_pair = find_best_pair(pair_frequencies)  # get the best pair
        if not best_pair:
            break
        frequency_dict, pair_frequencies, pair_to_tokens = merge_pair(frequency_dict, pair_frequencies, pair_to_tokens, best_pair)
        # merge this best_pair
        merges.append((best_pair[0][0], best_pair[1][0]))  # update merges
        new_token = best_pair[0][0] + best_pair[1][0]  # update vocab
        vocab[len(vocab)] = new_token
    return vocab, merges
Test adapter#
This is straightforward: just fill in the skeleton code in adapters.py:
# adapters.py
from cs336_basics import bpe
def run_train_bpe(
input_path: str | os.PathLike,
vocab_size: int,
special_tokens: list[str],
**kwargs,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
input_path_str = str(input_path)
vocab,merges = bpe.train_bpe(input_path=input_path_str,vocab_size=vocab_size,special_tokens=special_tokens,Debug=False)
    return vocab, merges
No paths need to be specified by hand here; the test ships with its own demo files, trains on them, and compares against a reference to decide pass/fail. Start the test with:
root@autodl-container-8d994fbd73-e5baa69e:~/autodl-tmp/Stanford_CS336/assignment1-basics# uv run pytest tests/test_train_bpe.py
Mind the working directory when launching it. A passing run looks like this:
(base) root@autodl-container-8d994fbd73-e5baa69e:~/autodl-tmp/Stanford_CS336/assignment1-basics# uv run pytest tests/test_train_bpe.py
============================================================================ test session starts =============================================================================
platform linux -- Python 3.12.3, pytest-8.4.1, pluggy-1.6.0
rootdir: /root/autodl-tmp/Stanford_CS336/assignment1-basics
configfile: pyproject.toml
plugins: jaxtyping-0.3.2
collected 3 items
tests/test_train_bpe.py::test_train_bpe_speed PASSED
tests/test_train_bpe.py::test_train_bpe PASSED
tests/test_train_bpe.py::test_train_bpe_special_tokens PASSED
================================================================================================================= 3 passed in 3.04s =================================================================================================================
Training the BPE tokenizer on TinyStories and OpenWebText#
All we need here is a small script that calls the training procedure implemented above; the only thing to watch out for is the on-disk format of Merges and Vocab.
The trained vocab looks like:
vocab = {0: b'hello', 1: b'world', ...}
and needs to be converted into
vocab_unicode = {'hello': 0, 'world': 1, ...}
# train_bpe_tinystories.py
from cs336_basics.bpe import train_bpe
from loguru import logger
from tests.common import gpt2_bytes_to_unicode
import json
from pathlib import Path
def save_vocab_merge(vocab, merges, output_path='./../../TinyStories_Result'):
output_dir = Path(output_path)
output_dir.mkdir(exist_ok=True)
byte_to_unicode = gpt2_bytes_to_unicode()
# Vocab:{id:bytes} -> {unicode_string:id}
vocab_unicode = {}
for token_id,token_bytes in vocab.items():
unicode_chars = [byte_to_unicode[b] for b in token_bytes]
unicode_string = ''.join(unicode_chars)
vocab_unicode[unicode_string] = token_id
vocab_path = output_dir/"vocab.json"
with open(vocab_path, 'w', encoding='utf-8') as f:
json.dump(vocab_unicode, f, ensure_ascii=False, indent=2)
merge_path = output_dir/"merges.txt"
# Merge:[tuple(bytes,bytes)]
with open(merge_path,'w',encoding='utf-8') as f:
for merge_pair in merges:
token1_bytes, token2_bytes = merge_pair
token1_unicode = ''.join(byte_to_unicode[b] for b in token1_bytes)
token2_unicode = ''.join(byte_to_unicode[b] for b in token2_bytes)
f.write(f"{token1_unicode} {token2_unicode}\n")
    return vocab_path, merge_path
Essentially this just converts the raw byte format into a human-readable unicode format.
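For reference, gpt2_bytes_to_unicode (provided in tests/common.py) builds a reversible map from each of the 256 byte values to a printable unicode character. A sketch of the well-known GPT-2 construction looks roughly like this (illustrative, not the course file itself):
def gpt2_bytes_to_unicode_sketch() -> dict[int, str]:
    # Bytes that are already printable keep their own character.
    printable = list(range(ord("!"), ord("~") + 1)) + \
                list(range(ord("¡"), ord("¬") + 1)) + \
                list(range(ord("®"), ord("ÿ") + 1))
    chars = printable[:]
    n = 0
    for b in range(256):
        if b not in printable:
            # Non-printable bytes get mapped to code points above 255 (256, 257, ...),
            # which is why byte 0 shows up as 'Ā' in vocab.json.
            printable.append(b)
            chars.append(256 + n)
            n += 1
    return dict(zip(printable, [chr(c) for c in chars]))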
The training entry point:
# train_bpe_tinystories.py
def train_bpe_tinystories():
input_path = "./../../data/TinyStoriesV2-GPT4-train.txt"
special_tokens = ["<|endoftext|>"]
vocab_size = 10000
vocab, merges = train_bpe(input_path=input_path,vocab_size=vocab_size,special_tokens=special_tokens,Debug=False)
vocab_path, merges_path = save_vocab_merge(vocab, merges)
if __name__ == "__main__":
    train_bpe_tinystories()
For the other dataset only the paths change, so I will not repeat the explanation.
# train_bpe_expts_owt.py
from cs336_basics.bpe import train_bpe
from loguru import logger
from tests.common import gpt2_bytes_to_unicode
import json
from pathlib import Path
# uv run python train_bpe_expts_owt.py
def save_vocab_merge(vocab, merges, output_path='./../../OpenWebText_Result'):
output_dir = Path(output_path)
output_dir.mkdir(exist_ok=True)
byte_to_unicode = gpt2_bytes_to_unicode()
# Vocab:{id:bytes} -> {unicode_string:id}
vocab_unicode = {}
for token_id,token_bytes in vocab.items():
unicode_chars = [byte_to_unicode[b] for b in token_bytes]
unicode_string = ''.join(unicode_chars)
vocab_unicode[unicode_string] = token_id
vocab_path = output_dir/"vocab.json"
with open(vocab_path, 'w', encoding='utf-8') as f:
json.dump(vocab_unicode, f, ensure_ascii=False, indent=2)
merge_path = output_dir/"merges.txt"
# Merge:[tuple(bytes,bytes)]
with open(merge_path,'w',encoding='utf-8') as f:
for merge_pair in merges:
token1_bytes, token2_bytes = merge_pair
token1_unicode = ''.join(byte_to_unicode[b] for b in token1_bytes)
token2_unicode = ''.join(byte_to_unicode[b] for b in token2_bytes)
f.write(f"{token1_unicode} {token2_unicode}\n")
return vocab_path,merge_path
def train_bpe_expts_owt():
input_path = "./../../data/owt_train.txt"
special_tokens = ["<|endoftext|>"]
vocab_size = 32000
vocab, merges = train_bpe(input_path=input_path,vocab_size=vocab_size,special_tokens=special_tokens,Debug=False)
vocab_path, merges_path = save_vocab_merge(vocab, merges)
if __name__ == "__main__":
    train_bpe_expts_owt()
The resulting merges.txt and vocab.json look like:
Ġ t
h e
Ġ a
Ġ s
Ġ w
n d
Ġt he
e d
Ġ b
Ġt o
{
"Ā": 0,
"ā": 1,
"Ă": 2,
"ă": 3,
"Ą": 4,
"ą": 5,
"Ć": 6,
"ć": 7,
"Ĉ": 8,
"ĉ": 9,
"Ċ": 10,
}
The tokenizer: encoding and decoding#
Once Vocab and Merges have been trained, we can use them to encode and decode arbitrary text.
Encoding#
In essence: pre-tokenize the text, apply the learned Merges, and finally map the pieces to integer IDs. An example:
Vocab = {0: b' ', 1: b'a', 2: b'c', 3: b'e', 4: b'h', 5: b't', 6: b'th', 7: b' c', 8: b' a', 9: b'the', 10: b' at'}
Merges = [(b't', b'h'), (b' ', b'c'), (b' ', b'a'), (b'th', b'e'), (b' a', b't')]
Encoding 'the cat ate' proceeds as follows.
First, pre-tokenization splits the text, giving
['the', ' cat', ' ate']
The first token, 'the', is represented as
[b't', b'h', b'e']
After applying two of the Merges it becomes
[b'the']
which maps to the integer sequence [9]. The other two tokens are handled the same way, and the final integer sequence is:
[9, 7, 1, 5, 10, 3]
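The whole example can be checked with a few lines of standalone code (a toy sketch that mirrors the algorithm, not the Tokenizer class defined below):
Vocab = {0: b' ', 1: b'a', 2: b'c', 3: b'e', 4: b'h', 5: b't',
         6: b'th', 7: b' c', 8: b' a', 9: b'the', 10: b' at'}
Merges = [(b't', b'h'), (b' ', b'c'), (b' ', b'a'), (b'th', b'e'), (b' a', b't')]
vocab_reverse = {v: k for k, v in Vocab.items()}

def encode_word(word: str) -> list[int]:
    parts = [bytes([b]) for b in word.encode("utf-8")]
    for left, right in Merges:              # apply merges in training order
        i = 0
        while i < len(parts) - 1:
            if parts[i] == left and parts[i + 1] == right:
                parts[i:i + 2] = [left + right]
            else:
                i += 1
    return [vocab_reverse[p] for p in parts]

tokens = ['the', ' cat', ' ate']            # result of pre-tokenization
ids = [i for w in tokens for i in encode_word(w)]
print(ids)  # [9, 7, 1, 5, 10, 3]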
Decoding#
Decoding is comparatively simple: for each integer in the input sequence, look up the corresponding bytes in Vocab and concatenate them.
The Tokenizer class (following the interface in the handout)#
# tokenizer.py
class tokenizer():
    def __init__(self, vocab: dict[int, bytes], merges: list[tuple[bytes, bytes]], special_tokens: list[str] = None):
        self.vocab = vocab
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab_reverse = {token_bytes: token_id for token_id, token_bytes in self.vocab.items()}  # reverse map for fast lookup
        self.PAT = re.compile(r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
First we implement the method apply_bpe_merges: given a pre-token, it walks the Merges list and applies every merge that matches, returning the result.
# tokenizer.py
    def apply_bpe_merges(self, token_parts) -> list[bytes]:
        current_parts = token_parts.copy()
        for merge_pair in self.merges:  # walk the merge list in training order
            byte1, byte2 = merge_pair  # the pair to merge
            i = 0
            while i < len(current_parts) - 1:  # scan the token for this adjacent pair
                if current_parts[i] == byte1 and current_parts[i+1] == byte2:
                    merged = byte1 + byte2
                    current_parts[i] = merged
                    del current_parts[i+1]  # perform the merge and drop the second element
                else:
                    i += 1
        return current_parts
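For instance, with the toy Merges from the example above (assuming the tokenizer class is in scope), the method behaves like this:
toy_merges = [(b't', b'h'), (b' ', b'c'), (b' ', b'a'), (b'th', b'e'), (b' a', b't')]
tok = tokenizer(vocab={}, merges=toy_merges, special_tokens=None)
print(tok.apply_bpe_merges([b't', b'h', b'e']))        # [b'the']
print(tok.apply_bpe_merges([b' ', b'a', b't', b'e']))  # [b' at', b'e']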
Next, a plain encoding function: it splits the text with PAT, applies the Merges to each piece, and finally maps the pieces to integer IDs via self.vocab_reverse.
# tokenizer.py
    def encode_normal_part(self, text: str) -> list[int]:
        pre_tokens = []  # the pre-tokens this text splits into
        for match in re.finditer(self.PAT, text):
            pre_tokens.append(match.group())  # split according to PAT
        result = []
        for pre_token in pre_tokens:  # process each pre-token
            pre_token_bytes = pre_token.encode("utf-8")  # its byte sequence
            token_parts = [bytes([b]) for b in pre_token_bytes]  # list of single bytes, e.g. "Hello" -> b'Hello' -> [b'H', b'e', b'l', b'l', b'o']
            merged_parts = self.apply_bpe_merges(token_parts)  # apply the Merges
            for part in merged_parts:  # convert each merged piece to a token id, skipping pieces not in the vocab
                token_id = self.vocab_reverse.get(part, None)
                if token_id is not None:
                    result.append(token_id)
        return result
    # 1. pre-tokenize: ["Hello"]
    # 2. bytes: [b'H', b'e', b'l', b'l', b'o']
    # 3. BPE merges (assuming a 'l','l' merge exists): [b'H', b'e', b'll', b'o']
    # 4. ID lookup:
    #    - b'H'  -> say 72
    #    - b'e'  -> say 101
    #    - b'll' -> say 200
    #    - b'o'  -> say 111
    # 5. output: [72, 101, 200, 111]
Since there may be self.special_tokens, we build a more general encode function on top of this one.
    def encode(self, text: str) -> list[int]:
        if self.special_tokens:  # special-token mode
            sorted_specials = sorted(self.special_tokens, key=len, reverse=True)
            special_pattern = "|".join(re.escape(token) for token in sorted_specials)
            # split the text at the special tokens:
            # [normal text 1, special token 1, normal text 2, special token 2, ...]
            parts = re.split(f"({special_pattern})", text)
            result = []
            for part in parts:
                if not part:
                    continue
                elif part in self.special_tokens:  # special tokens skip the Merge logic and are looked up directly
                    special_bytes = part.encode("utf-8")
                    token_id = self.vocab_reverse.get(special_bytes, None)
                    if token_id is not None:
                        result.append(token_id)
                else:  # normal text goes through the previous path: apply Merges, then look up
                    part_result = self.encode_normal_part(part)
                    result.extend(part_result)
            return result
        else:  # no special tokens: use the plain path directly
            return self.encode_normal_part(text)
The handout asks us to encode in a streaming, lazily evaluated fashion, which is straightforward:
# tokenizer.py
def encode_iterable(self,iterable:Iterable[str])->Iterable[int]:
for text in iterable:
token_ids = self.encode(text)
for token_id in token_ids:
yield token_id
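This makes it possible to tokenize a large file without loading it into memory; a small usage sketch (file path hypothetical):
# assuming `tok` is a tokenizer instance built from a trained vocab and merges
with open("TinyStoriesV2-GPT4-valid.txt", "r", encoding="utf-8") as f:
    for token_id in tok.encode_iterable(f):  # the file is consumed lazily, line by line
        ...  # e.g. write the id to disk instead of keeping everything in memory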
The decode side is even easier; it is just a lookup followed by a join:
# tokenizer.py
def decode(self,ids:list[int])->str:
byte_sequences = []
for token_id in ids:
if token_id in self.vocab.keys():
byte_sequences.append(self.vocab[token_id])
combined_bytes = b''.join(byte_sequences)
text = combined_bytes.decode('utf-8',errors='replace')
        return text
The handout also asks for a constructor classmethod. This is easy: read the persisted Merges and Vocab and convert them back to their in-memory types, i.e. the inverse of save_vocab_merge from the training script. Note that this means inverting the GPT-2 byte-to-unicode mapping used when saving, not simply UTF-8 encoding the strings:
    # tokenizer.py (requires: from tests.common import gpt2_bytes_to_unicode)
    @classmethod
    def from_files(cls, vocab_filepath: str, merges_filepath: str, special_tokens: list[str] = None):
        # invert the byte -> unicode table that save_vocab_merge used
        unicode_to_byte = {u: b for b, u in gpt2_bytes_to_unicode().items()}
        def to_bytes(s: str) -> bytes:
            return bytes(unicode_to_byte[ch] for ch in s)
        with open(file=vocab_filepath, mode='r', encoding='utf-8') as f:
            vocab_unicode = json.load(f)
        vocab = {token_id: to_bytes(unicode_str) for unicode_str, token_id in vocab_unicode.items()}
        merges = []
        with open(file=merges_filepath, mode='r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    token1_str, token2_str = line.strip().split()
                    merges.append((to_bytes(token1_str), to_bytes(token2_str)))
        return cls(vocab, merges, special_tokens)
Test adapter#
As before, just edit adapters.py:
# adapters.py
from cs336_basics import tokenizer
def get_tokenizer(
vocab: dict[int, bytes],
merges: list[tuple[bytes, bytes]],
special_tokens: list[str] | None = None,
) -> Any:
    return tokenizer.tokenizer(vocab=vocab, merges=merges, special_tokens=special_tokens)
Start the tests with:
uv run pytest tests/test_tokenizer.py
Test results:
(base) zyli@lab:~/Stanford_CS336/assignment1-basics$ uv run pytest tests/test_tokenizer.py
======================================================================== test session starts =========================================================================
platform linux -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0
rootdir: /home/zyli/Stanford_CS336/assignment1-basics
configfile: pyproject.toml
plugins: jaxtyping-0.3.2
collected 25 items
tests/test_tokenizer.py::test_roundtrip_empty PASSED
tests/test_tokenizer.py::test_empty_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_single_character PASSED
tests/test_tokenizer.py::test_single_character_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_single_unicode_character PASSED
tests/test_tokenizer.py::test_single_unicode_character_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_ascii_string PASSED
tests/test_tokenizer.py::test_ascii_string_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_unicode_string PASSED
tests/test_tokenizer.py::test_unicode_string_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_unicode_string_with_special_tokens PASSED
tests/test_tokenizer.py::test_unicode_string_with_special_tokens_matches_tiktoken PASSED
tests/test_tokenizer.py::test_overlapping_special_tokens PASSED
tests/test_tokenizer.py::test_address_roundtrip PASSED
tests/test_tokenizer.py::test_address_matches_tiktoken PASSED
tests/test_tokenizer.py::test_german_roundtrip PASSED
tests/test_tokenizer.py::test_german_matches_tiktoken PASSED
tests/test_tokenizer.py::test_tinystories_sample_roundtrip PASSED
tests/test_tokenizer.py::test_tinystories_matches_tiktoken PASSED
tests/test_tokenizer.py::test_encode_special_token_trailing_newlines PASSED
tests/test_tokenizer.py::test_encode_special_token_double_newline_non_whitespace PASSED
tests/test_tokenizer.py::test_encode_iterable_tinystories_sample_roundtrip PASSED
tests/test_tokenizer.py::test_encode_iterable_tinystories_matches_tiktoken PASSED
tests/test_tokenizer.py::test_encode_iterable_memory_usage PASSED
tests/test_tokenizer.py::test_encode_memory_usage XFAIL (Tokenizer.encode is expected to take more memory than allotted (1MB).)
========================================================================== warnings summary ==========================================================================
tests/adapters.py:352
/home/zyli/Stanford_CS336/assignment1-basics/tests/adapters.py:352: SyntaxWarning: invalid escape sequence '\T'
rope_theta (float): The RoPE $\Theta$ parameter.
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================= 24 passed, 1 xfailed, 1 warning in 1696.82s (0:28:16) ========================================================
With the current implementation this test run takes very long, and one memory-limit test is expected to fail (the test file itself marks it as an expected failure). The handout notes that a C++ or Rust implementation can speed things up considerably:
To test your BPE training function against our provided tests, you will first need to implement the
test adapter at [adapters.run_train_bpe]. Then, run uv run pytest tests/test_train_bpe.py.
Your implementation should be able to pass all tests. Optionally (this could be a large time-investment),
you can implement the key parts of your training method using some systems language, for instance
C++ (consider cppyy for this) or Rust (using PyO3). If you do this, be aware of which operations
require copying vs reading directly from Python memory, and make sure to leave build instructions, or
make sure it builds using only pyproject.toml. Also note that the GPT-2 regex is not well-supported
in most regex engines and will be too slow in most that do. We have verified that Oniguruma is
reasonably fast and supports negative lookahead, but the regex package in Python is, if anything,
even faster.
Encoding the datasets#
All that is needed here is the tokenizer class together with the trained Merges and Vocab.
# tokenizer_experiment.py
import enum
from cs336_basics import tokenizer
from typing import IO, Any, BinaryIO
from tests import test_tokenizer
import random
import numpy as np
import multiprocessing as mp
from functools import partial
TinyStories_Vocab_Path = './../TinyStories_Result/vocab.json'
TinyStories_Merges_Path = './../TinyStories_Result/merges.txt'
OpenWebText_Vocab_Path = './../OpenWebText_Result/vocab.json'
OpenWebText_Merges_Path = './../OpenWebText_Result/merges.txt'
TinyStories_Datapath = './../data/TinyStoriesV2-GPT4-train.txt'
OpenWebText_Datapath = './../data/owt_train.txt'
TinyStories_Valid_Datapath = './../data/TinyStoriesV2-GPT4-valid.txt'
OpenWebText_Valid_Datapath = './../data/owt_valid.txt'
def sample_documents_from_file(filepath,num_samples=10):
documents = []
with open(filepath,'r',encoding='utf-8') as f:
content = f.read()
parts = content.split('<|endoftext|>')
for part in parts:
if part.strip():
documents.append(part+'<|endoftext|>')
if len(documents) <= num_samples:
return documents
return random.sample(documents,num_samples)
def all_documents_from_file(filepath):
documents = []
with open(filepath,'r',encoding='utf-8') as f:
content = f.read()
parts = content.split('<|endoftext|>')
for part in parts:
if part.strip():
documents.append(part+'<|endoftext|>')
return documents
def calculate_compression_ratio(text,tokenizer):
original_bytes = len(text.encode('utf-8'))
tokens = tokenizer.encode(text)
num_tokens = len(tokens)
compression_ratio = original_bytes / num_tokens if num_tokens > 0 else 0
return compression_ratio
def encode_text(text, tokenizer):
    tokens = tokenizer.encode(text)
    return tokens
def encode_entire_file(filepath, tokenizer):
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()
    # encode the whole file contents in one call
    tokens = tokenizer.encode(content)
    return tokens
def encode_documents_batch(docs_batch,tokenizer):
tokens = []
for doc in docs_batch:
doc_tokens = tokenizer.encode(doc)
tokens.extend(doc_tokens)
return tokens
def encode_documents_parallel(documents, tokenizer, num_processes=None):
    if num_processes is None:
        num_processes = min(mp.cpu_count(), 96)  # use at most 96 processes
    # split the documents into batches
    batch_size = max(1, len(documents) // num_processes)
    doc_batches = [documents[i:i + batch_size] for i in range(0, len(documents), batch_size)]
    # build the worker function (with the tokenizer argument fixed)
    encode_func = partial(encode_documents_batch, tokenizer=tokenizer)
    print(f"Using {len(doc_batches)} processes to encode {len(documents)} documents...")
    # run the batches in a process pool
    with mp.Pool(processes=len(doc_batches)) as pool:
        results = pool.map(encode_func, doc_batches)
    # merge the results of all batches
all_tokens = []
for batch_tokens in results:
all_tokens.extend(batch_tokens)
return all_tokens
if __name__ == "__main__":
print("\nLoading TinyStories tokenizer...")
tinystories_tokenizer = test_tokenizer.get_tokenizer_from_vocab_merges_path(
vocab_path=TinyStories_Vocab_Path,
merges_path=TinyStories_Merges_Path,
special_tokens=["<|endoftext|>"]
)
print("Loading OpenWebText tokenizer...")
openwebtext_tokenizer = test_tokenizer.get_tokenizer_from_vocab_merges_path(
vocab_path=OpenWebText_Vocab_Path,
merges_path=OpenWebText_Merges_Path,
special_tokens=["<|endoftext|>"]
)
print("\nEncoding all TinyStories Dataset")
print("Encoding TinyStories valid dataset...")
TinyStories_Valid_docs = all_documents_from_file(TinyStories_Valid_Datapath)
TinyStories_Valid_Encode = encode_documents_parallel(TinyStories_Valid_docs, tinystories_tokenizer)
print("Encoding TinyStories train dataset...")
TinyStories_Train_docs = all_documents_from_file(TinyStories_Datapath)
TinyStories_Train_Encode = encode_documents_parallel(TinyStories_Train_docs, tinystories_tokenizer)
print("\nEncoding all OpenWebText Dataset")
print("Encoding OpenWebText valid dataset...")
OpenWebText_Valid_docs = all_documents_from_file(OpenWebText_Valid_Datapath)
OpenWebText_Valid_Encode = encode_documents_parallel(OpenWebText_Valid_docs, openwebtext_tokenizer)
print("Encoding OpenWebText train dataset...")
OpenWebText_Train_docs = all_documents_from_file(OpenWebText_Datapath)
OpenWebText_Train_Encode = encode_documents_parallel(OpenWebText_Train_docs, openwebtext_tokenizer)
print("\nSaving encoded datasets as uint16 NumPy arrays...")
np.save('./../TinyStories_Result/train_tokens.npy', np.array(TinyStories_Train_Encode, dtype=np.uint16))
np.save('./../TinyStories_Result/valid_tokens.npy', np.array(TinyStories_Valid_Encode, dtype=np.uint16))
np.save('./../OpenWebText_Result/train_tokens.npy', np.array(OpenWebText_Train_Encode, dtype=np.uint16))
np.save('./../OpenWebText_Result/valid_tokens.npy', np.array(OpenWebText_Valid_Encode, dtype=np.uint16))
print("All datasets encoded and saved successfully!")
print(f"TinyStories train tokens: {len(TinyStories_Train_Encode)}")
print(f"TinyStories valid tokens: {len(TinyStories_Valid_Encode)}")
print(f"OpenWebText train tokens: {len(OpenWebText_Train_Encode)}")
print(f"OpenWebText valid tokens: {len(OpenWebText_Valid_Encode)}")
The only thing to note is that the encoded datasets are saved as np.uint16 at the end.
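uint16 works because both vocabularies fit in 16 bits (10000 and 32000 token ids are both below 65536), and it halves the disk and memory footprint compared with int32. A small sketch of how the training step can later consume these files (using np.memmap-style loading is my own choice here, not something prescribed by the script above):
import numpy as np

assert 32000 <= np.iinfo(np.uint16).max + 1   # 65536 possible ids, enough for both vocabs

# Load without reading the whole array into RAM; useful for the large OWT train split.
tokens = np.load('./../TinyStories_Result/train_tokens.npy', mmap_mode='r')
print(tokens.dtype, tokens.shape)             # uint16, (num_tokens,)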