from collections import Counter, defaultdict

def train(self):
    """Train a BPE tokenizer."""
    # Count the frequency of each word across the corpus.
    for document in self.corpus:
        # Split each document in the corpus on whitespace.
        words = document.split()
        self.word_freq += Counter(words)
    # Initialize self.splits: every word starts as its characters
    # plus an end-of-word marker.
    for word in self.word_freq:
        self.splits[word] = list(word) + ["</w>"]
    # The greedy merge loop that consumes these structures is
    # sketched after get_pairs_freq below.
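Concretely, the initial split of a single word looks like this (a standalone illustration, independent of the class):

word = "hello"
split = list(word) + ["</w>"]
print(split)  # ['h', 'e', 'l', 'l', 'o', '</w>']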
def update_splits(self, lhs: str, rhs: str):
    """If lhs and rhs appear consecutively in a split, merge them."""
    for word, word_split in self.splits.items():
        new_split = []
        cursor = 0
        while cursor < len(word_split):
            if (
                word_split[cursor] == lhs
                and cursor + 1 < len(word_split)
                and word_split[cursor + 1] == rhs
            ):
                # Emit the merged token and skip past both halves.
                new_split.append(lhs + rhs)
                cursor += 2
            else:
                new_split.append(word_split[cursor])
                cursor += 1
        self.splits[word] = new_split
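To see what one merge does, here is the same scan run in isolation on the example split from above (a standalone sketch, not the method itself):

split = ["h", "e", "l", "l", "o", "</w>"]
lhs, rhs = "l", "l"
merged, cursor = [], 0
while cursor < len(split):
    if split[cursor] == lhs and cursor + 1 < len(split) and split[cursor + 1] == rhs:
        merged.append(lhs + rhs)  # the pair collapses into "ll"
        cursor += 2
    else:
        merged.append(split[cursor])
        cursor += 1
print(merged)  # ['h', 'e', 'll', 'o', '</w>']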
def get_pairs_freq(self) -> dict:
    """Compute the frequency of each adjacent token pair."""
    pairs_freq = defaultdict(int)
    for word, freq in self.word_freq.items():
        split = self.splits[word]
        # Count every adjacent pair, weighted by the word's frequency.
        for i in range(len(split) - 1):
            pairs_freq[(split[i], split[i + 1])] += freq
    return pairs_freq
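With both helpers in place, the part of train that actually learns the merges can be sketched as below. This is a minimal sketch, assuming the tokenizer records its rules in self.merges (which tokenize relies on) and accepts a merge budget; the num_merges parameter is a name introduced here, not from the original.

def train(self, num_merges: int = 100):
    ...  # word counting and split initialization as above
    for _ in range(num_merges):
        pairs_freq = self.get_pairs_freq()
        if not pairs_freq:
            break  # every word is a single token; nothing left to merge
        # Greedily pick the most frequent adjacent pair,
        lhs, rhs = max(pairs_freq, key=pairs_freq.get)
        # record it so tokenize() can replay it in order,
        self.merges.append((lhs, rhs))
        # and apply it to every cached split.
        self.update_splits(lhs, rhs)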
def tokenize(self, s: str) -> list[str]:
    # Start from character-level splits, then replay the learned
    # merges in the order they were recorded during training.
    splits = [list(t) + ["</w>"] for t in s.split()]
    for lhs, rhs in self.merges:
        for idx, split in enumerate(splits):
            new_split = []
            cursor = 0
            while cursor < len(split):
                if (
                    cursor + 1 < len(split)
                    and split[cursor] == lhs
                    and split[cursor + 1] == rhs
                ):
                    new_split.append(lhs + rhs)
                    cursor += 2
                else:
                    new_split.append(split[cursor])
                    cursor += 1
            # Merging must never change the underlying characters.
            assert "".join(new_split) == "".join(split)
            splits[idx] = new_split
    # Flatten the per-word token lists into one token sequence.
    return sum(splits, [])
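A hypothetical end-to-end run. The class name BPETokenizer and its constructor are assumptions; the original only shows the methods, which presumably live on a class that initializes corpus, word_freq (a Counter), splits, and merges.

tokenizer = BPETokenizer(corpus=["low lower lowest", "new newer newest"])
tokenizer.train()
print(tokenizer.tokenize("lowest newest"))
# e.g. ['low', 'est</w>', 'new', 'est</w>'] once enough merges are learned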