ML/AI/SW Developer

Byte Pair Encoding

1. Byte Pair Encoding

  • 일반적으로 하나의 단어에 대해 하나의 embedding 벡터를 생성할 경우 학습하지 못한 단어를 맞이했을 경우 out-of-vocabulary(OOV)라는 치명적인 문제를 갖게 됩니다. 이럴 경우 보통, Unknown token으로 처리해주어 모델의 입력으로 넣게 되면서, 이러한 경우가 많아지면 전체적으로 모델의 성능이 저하될 수 있습니다. 하지만 모든 단어의 embedding 벡터를 만들기에는 필요한 embedding parameter의 수가 지나치게 많아집니다. 이러한 문제를 해결하기 위해 컴퓨터가 이해하는 단어를 표현하는 데에 데이터 압축 알고리즘 중 하나인 byte pair encoding 기법을 적용한 sub-word tokenizaiton이라는 개념이 나타났습니다.

2. Code

2.1 Vocab 준비

  • 먼저 단어들을 character 사이에 공백을 넣어 vocab 준비 및 WORD_END 붙이기
    • E.g. [‘abc’, ‘def’] $\rightarrow$ [‘a b c _’, ‘d e f _’]
  • idx2word는 subword를 모두 담을 리스트로, 먼저 알파벳들(여기서는)을 중복없이 담아 준다!
  • python Collections 라이브러이의 Counter를 써서 쉽게 빈도수 계산 가능!
# Subword 단어들을 저장할 list
idx2word = []
# 모든 단어들을 character 단위로 split하고 마지막에 WORD_END 붙이기
# character 단위로 split했을때 character들 subword에 넣어주기 (WORD_END 제외)
for idx, s in enumerate(corpus):
    corpus[idx] = ' '.join(list(s + WORD_END))
    idx2word += [_ for _ in corpus[idx].split() if _ != '_']
# 중복제거
idx2word = list(set(idx2word))

# counter 이용해서 vocab dict 만들기
vocab = Counter(corpus)

2.2 get_stats

  • 띄어진 단어들에서 pair들의 빈도수를 체크하고, dict로 넘겨주기
    • E.g [’s k y’, ‘s k i e s’] $\rightarrow$ { (s, k): 2, (k, y): 1, …}
def get_stats(vocab):
    # pair들의 빈도수를 저장하기 위한 dict
    pairs = defaultdict(int)
    # vocab에 있는 단어 및 빈도수를 기반으로 반복문 실행
    for word, freq in vocab.items():
        # word를 split / [sky_] -> [s, k, y, _]
        symbols = word.split()
        # 마지막에 word_end는 제외?
        for i in range(len(symbols)-1):
            # pair를 합쳐 보며 빈도수 저장
            pairs[symbols[i], symbols[i+1]] += freq
    return pairs

2.3 merge_vocab

  • get_stats로 만든 pair를 기반으로 vocab 업데이트
  • 가장 빈도수가 높은 pair를 하나로 합침
    • E.g. [’s k y’, ‘s k i e s’] $\rightarrow$ [‘sk y’, ‘sk i e s’]
def merge_vocab(pairs, vocab):
    # 새로운 vocab 저장할 dict 생성
    result = defaultdict(int)
    # pair들중 가장 빈도수가 높은 pair 찾기
    best_pair = max(pairs, key=pairs.get)
    # vocab에 있는 단어들 중 best_pair merge 해, 새로운 vocab 생성
    for word, freq in vocab.items():
        paired = word.replace(" ".join(best_pair), "".join(best_pair))
        result[paired] = vocab[word]
    return dict(result)

2.4 get_stats, merge_vocab 반복 적용

  • 현재는 10이라고 설정되어있지만 변경 가능
  • 5-1과 같이 최대 단어집 크기를 정해놓고 subword화
    • 6은 Special token의 개수
    • SPECIAL = [PAD, UNK, CLS, SEP, MSK, WORD_END]
      for i in range(10):
        # 1. character pairs 생성
        pairs = get_stats(vocab)  
        # 1-1. 더이상 페어 생성이 불가능 한 경우
        if not pairs: break
        # 2. pairs들의 빈도수를 기반으로 Vocab 업데이트
        vocab = merge_vocab(pairs, vocab)
        # 3. Vocab에서 unique word 얻기 (split)
        idx2word += get_vocab(vocab)
        # 4. 중복 제거 및 리스트화
        idx2word = list(set(idx2word))
        # 5. voca 길이 계산
        new_length = len(idx2word)
        # 5-1. 스페셜 토큰들을 제외한 개수 검사
        if new_length >= max_vocab_size - 6:
        break
      

2.5 get_vocab

  • vocab에 단어들 split해서 unique 토큰들 얻기
def get_vocab(vocab):
    # unique한 word를 저장할 list
    result = []
    # vocab에 있는 단어들을 꺼내오기
    for word, freq in vocab.items():
        # split해서 subword들 토큰화
        tokens = word.split()
        # 토큰들 result에 넣어 주기 WORD_END 제외
        result += [token for token in tokens if token != '_']
    return result

3. Reference