ML/AI/SW Developer

Triplet Loss with Pytorch

1. Triplet Loss?

  • meta learning의 대표적인 로스중 하나, 임베딩 벡터
    • 위 그림과 같이 triplet loss를 사용하려면 3가지의 데이터가 필요하다. 중심이 되는 Anchor와 Anchor와 다른 클래스에 속하는 Negative, 동일 클래스에 속하는 Positive가 그 3가지 이다.
    • Anchor와 Positive는 벡터 공간상에서 가까워 지도록, Anchor와 negative는 멀어지도록 gradient를 부여합니다.
  • 참고할 점!
    • negative sample을 잘 구성할 수록, 학습의 결과가 좋아진다고 합니다. 기본적으로 임베딩을 이용해 분류를 한다고 생각해보면, 딥러닝 모델이 헷갈리는 부분은 분명 유사한 샘플이지만 다른 클래스에 속하는 경우일 것입니다. 따라서, negative sample들을 anchor와 유사하게 구성할 수 록 더 좋은 방향으로 학습이 됩니다.
    • 또한 어떤 거리 함수로 판단할 것인지, margin을 얼마나 줄 것인지 또한 굉장히 중요한 것으로 계속 연구되고 있는 것으로 알고있습니다.
  • 참고

2. Pytorch 에서 활용하기

  • pytorch 1.10 버전에는 TripletMarginLoss 가 구현되어 있어 쉽게 활용이 가능합니다. 아래 코드는 pytorch가 제공하는 예제입니다.
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
anchor = torch.randn(100, 128, requires_grad=True)
positive = torch.randn(100, 128, requires_grad=True)
negative = torch.randn(100, 128, requires_grad=True)
output = triplet_loss(anchor, positive, negative)
output.backward()