ML/AI/SW Developer

Entity를 변형해 데이터 Augmentation 하기

0. Relation Extracion을 위한 data augmentation

  • 동의어로 치환
  • Entity 만 변경 등

1. entity 리스트 추출 하기

# 데이터 안에서 원하는 tpye의 entity 리스트 추출
    # PER, LOC, NOH, ORG, DAT, POH 등
def extract_entity_based_on_type(data, entity, type_, save_name):
    new_dic = {'name':[], 'type':[]}
    for dic in data[entity].values:
        dic = literal_eval(dic)

        if dic['type'] == type_:
            new_dic['name'].append(dic['word'])

    names = list(set(new_dic['name']))
    new_dic['name'] = names
    new_dic['type'] = [type_ for _ in range(len(names))]
    new_dic = pd.DataFrame(new_dic)
    new_dic.to_csv(save_name+'.csv', index=False)
    return new_dic

2. 문장에서 Entity 치환 및 새로운 좌표 계산하기

  • KLUE 데이터 형식 대로 변경하기 위해서 좌표까지 계산
# Sentence 변경
def change_sentence(text, sx, ex, sy, ey, first_word, second_word):
    # entity 위치만 제거 하고 새로운 entity 삽입
    sentence = text[:sx] + first_word + text[ex+1:sy] + second_word + text[ey+1:]
    # 새로운 좌표계산
    origin_length = ex - sx + 1 
    sy = sx + len(first_word) - 1  # 길이를 이용해 끝점 계산
    ex = ex + (len(first_word) - origin_length) # 길면 +, 짧으면 - 해서 시작위치 조절
    ey = ex + len(second_word) - 1 # 길이를 이용해 끝점 계산
    
    return sentence, sx, ex, sy, ey

3. 기존의 문장들에 적용

  • 새로운 데이터프레임을 만들어서 저장
  • 향후 학습에 사용가능!
def augment_data(target_df, sub_entity_list, obj_entity_list, target_class, times):
    new = {'id':[], 'sentence':[], 'subject_entity':[], 'object_entity':[], 'label':[], 'source':[]}
    
    # 원본 데이터수 * times 개 생성
    for time in range(times):
        for idx, (id_, sentence, subject_entity, object_entity, label, source) in enumerate(target_df.values):
            # 정보 추출
            subject_dict = literal_eval(subject_entity)
            object_dict = literal_eval(object_entity)

            # 랜덤 추출
            sub_idx = np.random.randint(0, sub_entity_list.shape[0])
            obj_idx = np.random.randint(0, obj_entity_list.shape[0])
            sub_word = sub_entity_list['name'].iloc[sub_idx]
            obj_word = obj_entity_list['name'].iloc[obj_idx]

            # Sentence 바꾸기
                # 먼저 나오는 entity 속성에 따라 입력 값, 반환 값 순서를 다르게!
            if subject_dict['start_idx'] > object_dict['start_idx']:
                new_sentence, sy, ey, sx, ex = change_sentence(sentence, object_dict['start_idx'], object_dict['end_idx'], subject_dict['start_idx'], subject_dict['end_idx'], obj_word, sub_word)
            else:
                new_sentence, sx, ex, sy, ey = change_sentence(sentence, subject_dict['start_idx'], subject_dict['end_idx'], object_dict['start_idx'], object_dict['end_idx'], sub_word, obj_word)

            # subject_entity//object_entity 새로운 정보 입력
            subject_dict['word'] = sub_word
            subject_dict['start_idx'] = sx
            subject_dict['end_idx'] = ex

            object_dict['word'] = obj_word
            object_dict['start_idx'] = sy
            object_dict['end_idx'] = ey

            new['id'].append(idx)
            new['sentence'].append(new_sentence)
            new['subject_entity'].append(str(subject_dict)) # string으로 변환해 입력
            new['object_entity'].append(str(object_dict)) # string으로 변환해 입력
            new['label'].append(label)
            new['source'].append(source) 

    # 데이터 프레임으로 변경
    new = pd.DataFrame(new)
    # 문장 기준 중복제거
    new = new.drop_duplicates('sentence')
    print("생성된 데이터 수:", new.shape)
    new.to_csv('new_'+target_class+'_members.csv', index=False)
    return new