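# Dataset side (where alignments are defined): build the alignment index pairs.
# Each sample's alignment is reshaped into rows of 2 indices, shifted by that
# sample's offset, and kept only if check_alignment accepts it.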
alignments = [
    pairs + shift
    for pairs, shift, n_src, n_tgt in (
        # One candidate per sample, visited in sort_order: its alignment
        # reshaped into rows of 2 indices, plus its offset and lengths.
        (samples[sample_idx]['alignment'].view(-1, 2), shift, n_src, n_tgt)
        for sample_idx, shift, n_src, n_tgt in zip(
            sort_order, offsets, src_lengths, tgt_lengths
        )
    )
    # Keep only alignments that check_alignment accepts for these lengths;
    # each kept alignment is shifted by its sample's offset.
    if check_alignment(pairs, n_src, n_tgt)
]
# Criterion label_smoothed_cross_entropy_with_alignment: its forward() method,
# followed by the compute_alignment_loss() helper that forward() calls.
def forward(self, model, sample, reduce=True):
    """Compute the loss for the given sample.

    The label-smoothed loss comes from ``self.compute_loss``. When the
    sample carries non-None ``'alignments'`` and ``self.compute_alignment_loss``
    returns a value, ``self.alignment_lambda`` times that alignment loss is
    added to the returned loss. The logged ``'loss'`` entry excludes this
    alignment term; the term is logged separately as ``'alignment_loss'``.

    Returns a tuple with three elements:
    1) the loss
    2) the sample size, which is used as the denominator for the gradient
    3) logging outputs to display while training
    """
    net_output = model(**sample['net_input'])
    loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
    # Normalize per sentence or per token depending on self.args.sentence_avg.
    sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
    logging_output = {
        'loss': utils.item(loss.data) if reduce else loss.data,
        'nll_loss': utils.item(nll_loss.data) if reduce else nll_loss.data,
        'ntokens': sample['ntokens'],
        'nsentences': sample['target'].size(0),
        'sample_size': sample_size,
    }

    alignment_loss = None
    # Compute alignment loss only for training set and non dummy batches.
    if 'alignments' in sample and sample['alignments'] is not None:
        alignment_loss = self.compute_alignment_loss(sample, net_output)

    if alignment_loss is not None:
        logging_output['alignment_loss'] = utils.item(alignment_loss.data)
        # Out-of-place add on purpose: with reduce=False the logged 'loss'
        # is ``loss.data``, which shares storage with ``loss``. An in-place
        # ``+=`` would silently fold the alignment term into that logged
        # tensor as well.
        loss = loss + self.alignment_lambda * alignment_loss

    return loss, sample_size, logging_output
def compute_alignment_loss(self, sample, net_output):
    """Return the weighted negative log-likelihood of the reference alignments.

    ``net_output[1]['attn']`` is unpacked as a 3-D attention tensor
    (batch size, target size, source size) and flattened to
    (batch * target, source) rows.

    ``sample['alignments']`` is indexed as an (N, 2) integer tensor. Column 0
    is a source position; column 1 is a row of the flattened attention, i.e.
    a target position already offset by its batch index.
    ``sample['align_weights']`` holds one weight per pair; per the original
    note, that weight is 1 / frequency of the target index, for normalizing.

    Returns:
        A scalar tensor ``-sum_i(w_i * log(attn[tgt_i, src_i]))``, or ``None``
        when the sample contains no alignment pairs.
    """
    align = sample['alignments']
    if len(align) == 0:
        return None

    attn_prob = net_output[1]['attn']
    # batch size, target size, source size
    bsz, tgt_sz, src_sz = attn_prob.shape
    # reshape (not view): still a zero-copy view for contiguous attention,
    # but does not raise when the attention tensor is non-contiguous.
    attn = attn_prob.reshape(bsz * tgt_sz, src_sz)

    align_weights = sample['align_weights'].float()
    # Attention probability assigned to each (tgt row, src) reference pair.
    # A probability of exactly 0 gives log(0) = -inf, i.e. an inf/NaN loss.
    aligned_probs = attn[align[:, 1], align[:, 0]]
    return -(aligned_probs.log() * align_weights).sum()