BPE (Byte Pair Encoding) 分词
在 Transformer 模型中,文本的表示方式通常是通过将文本转换为整数序列来实现的。这个过程被称为 tokenization (分词)。常见的分词方法包括 word-level tokenization (基于单词的分词)、subword-level tokenization (基于子词的分词) 以及 byte-level tokenization (基于字节的分词)。在本篇文章中,我们将重点介绍 BPE (Byte Pair Encoding) 分词方法,这是一种常用的 subword-level tokenization 技术。
一个简单的 encoding 方式是类似 Unicode 的字符级别编码 (character-level encoding)。在 Python 中,可以直接使用内置的 ord 函数将字符转换为对应的整数编码:
ord('a') # 输出 97
ord('你') # 输出 20320
使用 chr 函数可以将整数编码转换回对应的字符:
chr(97) # 输出 'a'
使用 Python 的内置函数 encode 和 decode 可以实现字符串和字节序列之间的转换:
# 将字符串编码为字节序列
s = "你好!"
b = s.encode('utf-8') # 输出 b'\xe4\xbd\xa0\xe5\xa5\xbd\xef\xbc\x81'
list(b) # 输出 [228, 189, 160, 229, 165, 189, 239, 188, 129]
可见,通过 Unicode 编码后我们可以将任意字符串转换为整数序列。其中,不一定是每个字节(byte)对应一个字符,而有可能是多个字节对应一个字符 (如中文字符通常由 3 个字节表示),具体的编码方式取决于编码方式。然而我们依然可以通过这种方式将任意字符串转换为确定性的、范围有限的整数序列(0-255 之间)。在现实使用中 utf-8 编码是最常用的编码方式。utf-16 以及 utf-32 虽然能够表示更多的字符,但会导致编码后的字节序列更长,占用更多的存储空间。11 这里建立在一个以英文为主体的假设上,即大部分字符都在 ASCII 范围内 (0-127)。如果文本中包含大量非 ASCII 字符(如中文、日文等),也许 utf-16 甚至更优一些。
如此的 byte-level tokenization 解决了先前 word-level tokenization 中未登录词 (out-of-vocabulary, OOV) 的问题,因为任何字符串都可以被编码为字节序列。然而这种方式也有其缺点:它忽略了词语的语义信息,导致模型难以捕捉到词语之间的关系。此外,字节级别的编码会导致序列长度增加,从而增加模型的计算复杂度22 一般的 Transformer 模型的计算复杂度与序列长度的平方成正比 ()。我们希望找到能兼顾词语语义和处理未登录词的编码方式。由此引入的就是 subword-level tokenization 方法,其中最著名的就是 BPE (Byte Pair Encoding) 分词方法。我们选取一个适中的序数范围(介于 byte-level 的255个和 word-level 的数十万之间),并通过统计文本中出现频率较高的子词 (subword) 来构建词汇表,从而在一定程度上保留了词语的语义信息,同时也能处理未登录词的问题。BPE 的核心思想是通过迭代地合并文本中出现频率最高的字节对 (byte pair) 来构建子词单元。
BPE 的训练
在初始阶段,由于 BPE 基于 byte-level tokenization,因此我们首先需要将文本转换为字节序列。据此我们将文本表示为了一个 范围内的整数序列。如果我们按照最粗暴的方式来进行 BPE 分词,我们需要不断遍历整个文本来统计所有可能的字节对的频率,这在大规模文本上是非常低效的。为了解决这个问题,我们可以先进行一次初始的分词 (pre-tokenization),将文本划分为较大的单元(如单词或常见子词),然后在这些单元内进行字节对的统计和合并。这样可以显著减少需要处理的字节对数量,提高 BPE 训练的效率。例如,假设我们已经认定一个 pre-token 是 text,那么今后我们在统计 t 和 e 这对字节时,只需加上 text 中 t 和 e 出现的次数,而不需要每次都遍历整个文本。
在 GPT 系列的 pre-tokenization 中,使用了以下正则表达式来划分文本:
'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+
该正则表达式的含义如下:
(?:[sdmt]|ll|ve|re):匹配常见的英语缩写,如's(is/has)、'd(would/had)、'm(am)、'll(will)、've(have)、're(are)。?\p{L}+:匹配一个或多个字母(包括 Unicode 字母),前面可有一个空格。?\p{N}+:匹配一个或多个数字,前面可有一个空格。?[^\s\p{L}\p{N}]+:匹配一个或多个非空白、非字母、非数字的字符,前面可有一个空格。\s+(?!\S):匹配一个或多个空白字符,后面不跟非空白字符(即匹配行尾的空白)。\s+:匹配一个或多个空白字符。
通过这种方式,我们可以有效地将文本划分为较大的单元。相继地,我们训练 BPE 模型时,只需在这些单元内进行字节对的统计和合并,从而提高了训练效率。以下是一个初始暴力的 BPE 训练代码:
import regex as re
from collections import Counter
class naiveTokenizer():
def __init__(self, input_path, vocab_size):
self.input_path = input_path
self.vocab_size = vocab_size
with open(input_path, 'r') as f:
self.text = f.read()
self.merges = {}
self.vocab = {i: bytes([i]) for i in range(256)}
self.token_counts = {}
def pretokenize(self):
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
it = re.finditer(PAT, self.text)
self.token_counts = dict(Counter(tuple(m.group().encode('utf-8')) for m in it))
print(self.token_counts)
def count_pair(self):
counts = Counter()
for word, count in self.token_counts.items():
for i in range(len(word) - 1):
pair = (word[i], word[i+1])
counts[pair] += count
return counts
def _merge_tokens(self, pair, new_id):
new_token_counts = {}
for word, count in self.token_counts.items():
if len(word) < 2 or pair not in zip(word, word[1:]):
new_token_counts[word] = count
continue
new_word = []
i = 0
while i < len(word):
if i < len(word) - 1 and (word[i], word[i+1]) == pair:
new_word.append(new_id)
i += 2
else:
new_word.append(word[i])
i += 1
new_token_counts[tuple(new_word)] = count
self.token_counts = new_token_counts
def train(self):
self.pretokenize()
num_merges = self.vocab_size - 256
for i in range(num_merges):
counts = self.count_pair()
if not counts:
break
frequent_pair = max(counts.keys(), key=lambda p: (counts[p], p))
new_id = 256 + i
print(f"Merge {i+1}/{num_merges}: {self.vocab[frequent_pair[0]], self.vocab[frequent_pair[1]]} -> {new_id}")
self._merge_tokens(frequent_pair, new_id)
self.merges[frequent_pair] = new_id
self.vocab[new_id] = self.vocab[frequent_pair[0]] + self.vocab[frequent_pair[1]]
这份初始代码的思路很简单:首先对文本进行 pre-tokenization,然后统计所有可能的字节对的频率,选择出现频率最高的字节对进行合并,并更新词汇表和 token 计数。这个过程会重复进行,直到达到预定的词汇表大小为止。如果我们想要在较大的数据集上训练 BPE 模型,需要进行一些优化。
Pre-tokenization 并行化
在大规模文本上进行 BPE 训练时,pre-tokenization 可能成为瓶颈。为了提高效率,我们可以将文本划分为多个片段,并使用多线程或多进程来并行处理这些片段。处理得到的 token 计数可以在最后进行合并。需要注意的一点是,在划分文本时,可能会出现一些 token 被拆分的情况(例如一个单词被分割成两部分)。为了避免这种情况,我们在划分文本时需要确保每个片段的边界处不会拆分 token。具体实现时,一种简单的实现方式是,在划分文本时找到最近的特殊 token33 通常文本中会包含类似 [SEP]、<|endoftext|> 等类似的特殊标记来分割不同的段落或文本 作为边界,从而确保 token 不会被拆分。
下面是一个简单的多进程 pre-tokenization 实现:
def find_chunk_boundaries(
self,
file: BinaryIO,
desired_num_chunks: int,
split_special_token: bytes,
) -> list[int]:
"""
Divides a file into approximately equal-sized chunks, with boundaries
placed at special token positions to avoid splitting tokens.
"""
assert isinstance(split_special_token, bytes), "Must represent special token as a bytestring"
# Get total file size
file.seek(0, os.SEEK_END)
file_size: int = file.tell()
file.seek(0)
chunk_size: int = file_size // desired_num_chunks
# Initial uniform chunk boundary positions
chunk_boundaries: list[int] = [i * chunk_size for i in range(desired_num_chunks + 1)]
chunk_boundaries[-1] = file_size
# Adjust boundaries to split on special tokens
for boundary_idx in range(1, len(chunk_boundaries) - 1):
initial_position: int = chunk_boundaries[boundary_idx]
file.seek(initial_position)
while True:
mini_chunk: bytes = file.read(self.MINI_CHUNK_SIZE)
if mini_chunk == b"":
chunk_boundaries[boundary_idx] = file_size
break
found_at: int = mini_chunk.find(split_special_token)
if found_at != -1:
chunk_boundaries[boundary_idx] = initial_position + found_at
break
initial_position += self.MINI_CHUNK_SIZE
return sorted(set(chunk_boundaries))
def count_tokens(self, start: int, end: int) -> CounterType[Tuple[int, ...]]:
"""
Count token frequencies in a file range.
"""
special_token_pattern: str = "|".join(map(re.escape, self.special_tokens))
counts: CounterType[Tuple[int, ...]] = Counter()
with open(self.input_path, "rb") as f:
f.seek(start)
text: str = f.read(end - start).decode("utf-8", errors="ignore")
chunks: list[str] = re.split(special_token_pattern, text)
for chunk in chunks:
matches = re.finditer(self.TOKENIZATION_PATTERN, chunk)
counts.update(
Counter(tuple(match.group().encode("utf-8")) for match in matches)
)
return counts
def pretokenize(self) -> None:
"""
Tokenize the input file using parallel processing.
Splits the file into chunks, counts token frequencies in each chunk
using multiprocessing, and aggregates results.
"""
total_counts: CounterType[Tuple[int, ...]] = Counter()
with open(self.input_path, "rb") as f:
boundaries: list[int] = self.find_chunk_boundaries(
f, self.NUM_PROCESSES, self.SPLIT_SPECIAL_TOKEN
)
ranges: zip[Tuple[int, int]] = zip(boundaries[:-1], boundaries[1:])
with Pool(processes=self.NUM_PROCESSES) as pool:
results: list[CounterType[Tuple[int, ...]]] = pool.starmap(
self.count_tokens, ranges
)
for result in results:
total_counts.update(result)
self.token_counts = dict(total_counts)
在上述代码中,我们首先定义了 find_chunk_boundaries 方法,用于根据特殊 token 来划分文本片段。然后,count_tokens 方法用于在指定的文件范围内统计 token 频率。最后,pretokenize 方法使用多进程来并行处理文本片段,并将结果进行合并。
寻找最优字节对的加速
在训练过程中,对于每次合并操作,我们需要统计所有可能的字节对的频率,这在大规模文本上是非常耗时的。为了提高效率,我们可以使用更高效的数据结构来存储和更新字节对的频率。我们需要支持的操作有:修改一个字节对的频率、删除一个字节对以及获取当前频率最高的字节对。不难发现,这些操作可以通过使用堆 (heap) 数据结构来实现。具体来说,我们可以使用一个最大堆 (max-heap)44 也就是一个优先队列 来存储字节对及其频率,从而能够高效地获取频率最高的字节对。
在执行合并操作时,我们需要更新受影响的字节对的频率。具体来说,当我们合并一个字节对 (a, b) 时,所有包含该字节对的 token 都会发生变化,因此我们需要相应地更新这些 token 中其他字节对的频率。为了实现这一点,我们可以维护一个映射关系,记录每个字节对在文本中出现的位置(也就是 pre-token 中的索引)。这样,在合并操作后,我们可以快速定位受影响的字节对,并更新它们的频率。
一个简单的实现思路如下:
- 使用一个最大堆来存储字节对及其频率。
- 使用一个字典来记录每个字节对在文本中出现的位置。
- 在每次合并操作后,遍历受影响的 token,更新相关字节对的频率,并调整堆中的位置。
实现的代码如下所示:
def count_pairs(self) -> Dict[Tuple[int, int], int]:
"""
Count adjacent token pair frequencies.
Builds pair counts from token_counts and maintains the pair_to_word
mapping for efficient merge operations.
"""
counts: CounterType[Tuple[int, int]] = Counter()
self.pair_to_word = defaultdict(set)
for word, count in self.token_counts.items():
for idx in range(len(word) - 1):
pair: Tuple[int, int] = (word[idx], word[idx + 1])
counts[pair] += count
self.pair_to_word[pair].add(word)
return { pair: [-freq] + list(pair) for pair, freq in counts.items() }
def merge_tokens(
self,
pair: Tuple[int, int],
new_id: int,
counts: pqdict[Tuple[int, int], List[int]],
) -> None:
"""
Merge a token pair throughout the vocabulary.
For each word containing the pair, creates a new word with the pair
replaced by a new token ID, and updates all pair counts accordingly.
"""
affected_words: list[Tuple[int, ...]] = list(self.pair_to_word[pair])
pair_deltas: Dict[Tuple[int, int], int] = defaultdict(int)
for word in affected_words:
count: int = self.token_counts[word]
merged_word: list[int] = []
idx: int = 0
while idx < len(word):
if idx < len(word) - 1 and (word[idx], word[idx + 1]) == pair:
merged_word.append(new_id)
idx += 2
else:
merged_word.append(word[idx])
idx += 1
old_pairs = list(zip(word, word[1:]))
for old_pair in old_pairs:
pair_deltas[old_pair] += count # Frequency decreases, so negative freq increases
self.pair_to_word[old_pair].discard(word)
new_pairs = list(zip(merged_word, merged_word[1:]))
for new_pair in new_pairs:
pair_deltas[new_pair] -= count # Frequency increases, so negative freq decreases
self.pair_to_word[new_pair].add(tuple(merged_word))
self.token_counts[tuple(merged_word)] = count
del self.token_counts[word]
# Batch apply all pair count changes
for affected_pair, delta in pair_deltas.items():
if affected_pair not in counts:
if delta < 0:
counts[affected_pair] = [delta, affected_pair[0], affected_pair[1]]
else:
counts[affected_pair][0] += delta
if counts[affected_pair][0] >= 0:
del counts[affected_pair]
else:
counts.updateitem(affected_pair, counts[affected_pair])
Tokenizer 的实现
完成 BPE 训练后,我们可以使用训练得到的词汇表和合并规则来实现一个 BPE tokenizer。该 tokenizer 能够将输入文本转换为对应的 token 序列,并支持解码操作将 token 序列还原为文本。
编码 (Encoding)
编码过程包括以下步骤:
- 对输入文本进行 pre-tokenization,得到初始的 token 列表。
- 对每个 token 进行字节级别的编码,得到对应的字节序列。
- 迭代地应用 BPE 合并规则,直到无法再进行合并为止。
这其中需要注意特殊 token 的处理,例如 [UNK](未登录词)、[PAD](填充符)等,这些特殊 token 通常在 pre-tokenization 阶段就已经被识别出来,并直接映射到对应的 token ID。
解码 (Decoding)
解码过程相对简单,主要包括以下步骤:
- 将输入的 token ID 列表转换为对应的字节序列。
- 将字节序列解码为字符串,得到最终的文本输出。