A comparison of Hidden Markov Models and Conditional Random Fields, two kinds of probabilistic graphical models.
import torch
from torch import nn
# BMES four-position tagging scheme (Begin / Middle / End / Single),
# plus <pad>/<unk> and the special START/END boundary tags used by the CRF.
PADDING = 0
B = 2
E = 3
S = 4
M = 5
START = 6
END = 7
# NOTE(review): LABEL_VOCAB has no entries for START (6) / END (7) — presumably
# intentional because they never appear in decoded output; confirm.
LABEL_VOCAB = {0: '<pad>', 1: '<unk>', 2: 'B', 3: 'E', 4: 'S', 5: 'M'}
NUM_TAGS = 8
# batch_size x max_len x num_tags : (3, 9, 8)
logits = torch.randn(3, 9, NUM_TAGS)
print(f"logits:\n {logits}\n")
# Gold tag paths, one row per sequence: START ... END, right-padded with 0.
tags = torch.tensor([[6, 4, 2, 5, 5, 3, 4, 7, 0],
[6, 4, 2, 3, 4, 4, 7, 0, 0],
[6, 2, 5, 3, 4, 7, 0, 0, 0]])
print(f"tags:\n {tags}\n")
# 1 marks real positions (including START/END), 0 marks padding.
mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 0, 0, 0]])
print(f"mask:\n {mask}")
logits:
tensor([[[-1.4452e-01, -1.9019e-01, 9.7182e-01, -1.5851e+00, -1.3361e+00,
-1.4078e+00, 5.9501e-01, 1.1116e+00],
[-1.2016e+00, 5.7204e-01, -1.7459e-01, -1.2101e+00, -1.5633e+00,
-1.5958e+00, 7.4246e-01, -2.2454e-01],
[-1.5298e-01, 7.3708e-01, 3.7166e-01, 2.4409e-01, -7.6002e-01,
-1.1487e+00, -3.5016e-01, 1.7178e-01],
[ 5.3908e-01, 1.7665e-01, 8.4331e-02, 1.2331e+00, -6.3707e-01,
3.2750e-01, -9.5666e-01, -1.0764e+00],
[-2.0275e-01, -4.7478e-01, -2.4096e-01, 3.4847e-01, 1.4107e+00,
-6.7662e-01, 1.1356e+00, -8.8798e-01],
[-5.2993e-01, -8.7384e-01, -7.1909e-01, -9.0088e-01, -1.0477e+00,
5.7400e-01, -8.9259e-02, -9.4986e-01],
[ 5.1900e-01, -1.3026e+00, -1.7043e+00, -2.8520e-01, -1.4247e+00,
5.4460e-02, -6.3961e-01, 1.3025e-01],
[ 1.2510e+00, 2.2883e-01, 3.4238e-01, -7.5308e-01, 2.5237e-03,
-3.4200e-01, -2.2455e-01, -1.0249e+00],
[ 4.8784e-01, -1.5034e+00, -1.6049e-01, 5.3555e-01, 2.2210e-03,
-5.5209e-01, 2.6669e-01, -1.2266e-01]],
[[ 4.8289e-01, 8.0970e-01, -2.6692e-01, -1.0974e+00, 2.5424e-01,
1.4737e-01, 7.8376e-01, -1.4346e+00],
[-2.6347e-01, -1.5178e+00, 1.5867e+00, 7.5319e-01, -2.9615e-01,
6.0084e-01, -5.5508e-01, -1.8708e-01],
[ 2.3153e+00, -1.7178e-01, -6.4674e-01, 1.1492e+00, 7.9638e-01,
5.4677e-01, 1.6118e-01, 2.1595e-01],
[ 1.6032e+00, 1.0249e+00, -1.1967e+00, 7.7639e-01, 1.0185e+00,
-9.3879e-01, -1.4424e+00, 1.3888e+00],
[-1.0615e-01, 4.5697e-01, 1.2877e-01, 1.3390e+00, -1.9707e+00,
-5.8355e-01, -2.2570e+00, -1.1907e+00],
[-6.4982e-01, 1.6911e-01, -3.0681e-01, -5.5137e-01, -6.9925e-01,
1.9890e-01, 4.8145e-01, 5.3715e-01],
[ 5.2184e-01, 5.8037e-01, 8.7149e-01, -2.0938e+00, -4.8395e-01,
-2.0589e+00, -2.3386e+00, 7.5332e-02],
[ 3.7932e-02, 1.3092e+00, -5.4694e-01, 1.6446e-01, -1.7779e+00,
-6.6966e-01, -8.4106e-01, 9.3973e-02],
[ 6.5991e-01, 6.3159e-01, 1.3538e+00, -2.7384e-01, 8.5952e-01,
-3.7105e-01, -1.3350e-01, -2.1770e+00]],
[[-5.1107e-01, -7.3291e-01, 8.8863e-01, 8.5765e-02, 2.0189e-01,
-6.3774e-01, 1.1234e-01, 1.0721e+00],
[-1.7006e-01, -2.5931e-01, 1.1782e+00, -2.6335e-01, 1.4209e+00,
4.8150e-01, 1.1860e+00, 7.5185e-01],
[-1.4374e+00, 1.0826e+00, 2.8144e-02, 6.7766e-01, -1.9165e-01,
1.2558e-01, 1.2113e+00, 5.0190e-02],
[ 3.9188e-01, 1.5474e+00, -1.5797e+00, 6.2201e-01, -3.8201e-01,
-2.2004e+00, -1.2397e+00, 1.3466e+00],
[-1.6803e+00, 1.5669e+00, 5.5389e-01, -8.0821e-01, 1.8370e+00,
-2.8107e-02, 8.2856e-01, 7.0409e-01],
[ 1.3986e+00, -1.1884e+00, 4.2444e-01, 1.3998e+00, -7.3532e-01,
-9.4360e-01, -1.4932e-01, -1.2199e+00],
[ 2.2620e-01, -3.7655e-01, -3.5929e-01, -1.8760e+00, -2.0527e+00,
1.6695e-01, 2.2314e-01, 2.8254e-01],
[ 1.6245e+00, 7.8641e-01, 1.4010e+00, -7.4482e-02, -1.4197e+00,
-1.3425e+00, -8.8786e-01, 9.7691e-01],
[ 3.7015e-01, 1.5858e+00, -4.2269e-01, 2.0528e+00, 3.1429e-01,
-8.5222e-01, -5.5130e-01, -9.6802e-02]]])
tags:
tensor([[6, 4, 2, 5, 5, 3, 4, 7, 0],
[6, 4, 2, 3, 4, 4, 7, 0, 0],
[6, 2, 5, 3, 4, 7, 0, 0, 0]])
mask:
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 0, 0, 0]])
# Switch to time-major layout (seq_len, batch, ...) so each time step t is a
# contiguous slice — the conventional layout for CRF forward recursions.
logits = logits.transpose(0, 1)
tags = tags.transpose(0, 1).long()  # long so tags can be used as an index tensor
mask = mask.transpose(0, 1).float()
# Learnable CRF transition scores: trans_matrix[i, j] = score of moving from tag i to tag j.
trans_matrix = nn.Parameter(torch.randn(NUM_TAGS, NUM_TAGS))
trans_matrix
Parameter containing:
tensor([[ 0.3130, 0.6805, -1.6291, 1.1787, 0.2590, 0.0393, 0.3733, -1.0978],
[ 0.0257, -0.7165, 0.1774, -0.4137, -0.9312, 1.1543, -1.7091, 0.1694],
[-0.5890, -1.3841, -0.7339, 1.5916, -0.6873, -1.2056, 1.3192, 0.1930],
[-1.0862, -0.3838, 0.3116, -1.0895, -0.5821, -1.2720, 0.2369, 0.5035],
[-1.1792, 0.2191, -0.5459, 1.8000, -0.0737, 1.6784, -0.8590, -0.3808],
[-0.1573, -1.7135, -0.2278, 1.8250, -0.4302, 2.0009, -1.1343, 0.4233],
[-0.3599, -0.0824, -0.4446, 0.4188, -0.7154, -0.1829, 0.0595, -1.1767],
[-1.2007, 2.5477, 0.0693, -0.9544, -0.8122, 0.1949, -0.0823, 0.3554]],
requires_grad=True)
2. Compute the score for the gold path.
# Dimensions after the transpose above: logits is (seq_len, batch_size, num_tags).
seq_len, batch_size, _ = logits.size()
print(f"seq_len: {seq_len}\nbatch_size: {batch_size}")
# Index vectors for the advanced-indexing lookups below.
batch_idx = torch.arange(batch_size, dtype=torch.long)
batch_idx
seq_idx = torch.arange(seq_len, dtype=torch.long)
seq_idx
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
2.1 transition probability score
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 0.],
[1., 0., 0.],
[0., 0., 0.]])
# Turn the 0./1. float mask into a boolean mask (True = real position).
mask = mask == 1
mask
tensor([[ True, True, True],
[ True, True, True],
[ True, True, True],
[ True, True, True],
[ True, True, True],
[ True, True, True],
[ True, True, False],
[ True, False, False],
[False, False, False]])
# Inverted mask: True exactly at padding positions.
flip_mask = ~mask
flip_mask
tensor([[False, False, False],
[False, False, False],
[False, False, False],
[False, False, False],
[False, False, False],
[False, False, False],
[False, False, True],
[False, True, True],
[ True, True, True]])
tensor([[6, 6, 6],
[4, 4, 2],
[2, 2, 5],
[5, 3, 3],
[5, 4, 4],
[3, 4, 7],
[4, 7, 0],
[7, 0, 0],
[0, 0, 0]])
tensor([[6, 6, 6],
[4, 4, 2],
[2, 2, 5],
[5, 3, 3],
[5, 4, 4],
[3, 4, 7],
[4, 7, 0],
[7, 0, 0]])
tensor([[4, 4, 2],
[2, 2, 5],
[5, 3, 3],
[5, 4, 4],
[3, 4, 7],
[4, 7, 0],
[7, 0, 0],
[0, 0, 0]])
Parameter containing:
tensor([[ 0.3130, 0.6805, -1.6291, 1.1787, 0.2590, 0.0393, 0.3733, -1.0978],
[ 0.0257, -0.7165, 0.1774, -0.4137, -0.9312, 1.1543, -1.7091, 0.1694],
[-0.5890, -1.3841, -0.7339, 1.5916, -0.6873, -1.2056, 1.3192, 0.1930],
[-1.0862, -0.3838, 0.3116, -1.0895, -0.5821, -1.2720, 0.2369, 0.5035],
[-1.1792, 0.2191, -0.5459, 1.8000, -0.0737, 1.6784, -0.8590, -0.3808],
[-0.1573, -1.7135, -0.2278, 1.8250, -0.4302, 2.0009, -1.1343, 0.4233],
[-0.3599, -0.0824, -0.4446, 0.4188, -0.7154, -0.1829, 0.0595, -1.1767],
[-1.2007, 2.5477, 0.0693, -0.9544, -0.8122, 0.1949, -0.0823, 0.3554]],
requires_grad=True)
tensor([[6, 6, 6],
[4, 4, 2],
[2, 2, 5],
[5, 3, 3],
[5, 4, 4],
[3, 4, 7],
[4, 7, 0],
[7, 0, 0],
[0, 0, 0]])
# Gold-path transition scores via paired 2-D fancy indexing: each tag is paired
# with its successor, pulling trans_matrix[tag_t, tag_{t+1}] for every step.
trans_score = trans_matrix[tags[:-1], tags[1:]]
trans_score
tensor([[-0.7154, -0.7154, -0.4446],
[-0.5459, -0.5459, -1.2056],
[-1.2056, 1.5916, 1.8250],
[ 2.0009, -0.5821, -0.5821],
[ 1.8250, -0.0737, -0.3808],
[-0.5821, -0.3808, -1.2007],
[-0.3808, -1.2007, 0.3130],
[-1.2007, 0.3130, 0.3130]], grad_fn=<IndexBackward>)
tensor([[False, False, False],
[False, False, False],
[False, False, False],
[False, False, False],
[False, False, False],
[False, False, False],
[False, False, True],
[False, True, True],
[ True, True, True]])
tensor([[False, False, False],
[False, False, False],
[False, False, False],
[False, False, False],
[False, False, False],
[False, False, True],
[False, True, True],
[ True, True, True]])
# The final transition-score matrix: zero every transition whose target
# position is padding, so padded steps contribute nothing to the path score.
trans_score = trans_score.masked_fill(flip_mask[1:], 0)
trans_score
tensor([[-0.7154, -0.7154, -0.4446],
[-0.5459, -0.5459, -1.2056],
[-1.2056, 1.5916, 1.8250],
[ 2.0009, -0.5821, -0.5821],
[ 1.8250, -0.0737, -0.3808],
[-0.5821, -0.3808, 0.0000],
[-0.3808, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000]], grad_fn=<MaskedFillBackward0>)
2.2 emission probability score
# emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0)
# emit_score
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8])
tensor([[0],
[1],
[2],
[3],
[4],
[5],
[6],
[7],
[8]])
tensor([[6, 6, 6],
[4, 4, 2],
[2, 2, 5],
[5, 3, 3],
[5, 4, 4],
[3, 4, 7],
[4, 7, 0],
[7, 0, 0],
[0, 0, 0]])
logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags]
tensor([[ 0.5950, 0.7838, 0.1123],
[-1.5633, -0.2962, 1.1782],
[ 0.3717, -0.6467, 0.1256],
[ 0.3275, 0.7764, 0.6220],
[-0.6766, -1.9707, 1.8370],
[-0.9009, -0.6992, -1.2199],
[-1.4247, 0.0753, 0.2262],
[-1.0249, 0.0379, 1.6245],
[ 0.4878, 0.6599, 0.3702]])
tensor([[[-1.4452e-01, -1.9019e-01, 9.7182e-01, -1.5851e+00, -1.3361e+00,
-1.4078e+00, 5.9501e-01, 1.1116e+00],
[ 4.8289e-01, 8.0970e-01, -2.6692e-01, -1.0974e+00, 2.5424e-01,
1.4737e-01, 7.8376e-01, -1.4346e+00],
[-5.1107e-01, -7.3291e-01, 8.8863e-01, 8.5765e-02, 2.0189e-01,
-6.3774e-01, 1.1234e-01, 1.0721e+00]],
[[-1.2016e+00, 5.7204e-01, -1.7459e-01, -1.2101e+00, -1.5633e+00,
-1.5958e+00, 7.4246e-01, -2.2454e-01],
[-2.6347e-01, -1.5178e+00, 1.5867e+00, 7.5319e-01, -2.9615e-01,
6.0084e-01, -5.5508e-01, -1.8708e-01],
[-1.7006e-01, -2.5931e-01, 1.1782e+00, -2.6335e-01, 1.4209e+00,
4.8150e-01, 1.1860e+00, 7.5185e-01]],
[[-1.5298e-01, 7.3708e-01, 3.7166e-01, 2.4409e-01, -7.6002e-01,
-1.1487e+00, -3.5016e-01, 1.7178e-01],
[ 2.3153e+00, -1.7178e-01, -6.4674e-01, 1.1492e+00, 7.9638e-01,
5.4677e-01, 1.6118e-01, 2.1595e-01],
[-1.4374e+00, 1.0826e+00, 2.8144e-02, 6.7766e-01, -1.9165e-01,
1.2558e-01, 1.2113e+00, 5.0190e-02]],
[[ 5.3908e-01, 1.7665e-01, 8.4331e-02, 1.2331e+00, -6.3707e-01,
3.2750e-01, -9.5666e-01, -1.0764e+00],
[ 1.6032e+00, 1.0249e+00, -1.1967e+00, 7.7639e-01, 1.0185e+00,
-9.3879e-01, -1.4424e+00, 1.3888e+00],
[ 3.9188e-01, 1.5474e+00, -1.5797e+00, 6.2201e-01, -3.8201e-01,
-2.2004e+00, -1.2397e+00, 1.3466e+00]],
[[-2.0275e-01, -4.7478e-01, -2.4096e-01, 3.4847e-01, 1.4107e+00,
-6.7662e-01, 1.1356e+00, -8.8798e-01],
[-1.0615e-01, 4.5697e-01, 1.2877e-01, 1.3390e+00, -1.9707e+00,
-5.8355e-01, -2.2570e+00, -1.1907e+00],
[-1.6803e+00, 1.5669e+00, 5.5389e-01, -8.0821e-01, 1.8370e+00,
-2.8107e-02, 8.2856e-01, 7.0409e-01]],
[[-5.2993e-01, -8.7384e-01, -7.1909e-01, -9.0088e-01, -1.0477e+00,
5.7400e-01, -8.9259e-02, -9.4986e-01],
[-6.4982e-01, 1.6911e-01, -3.0681e-01, -5.5137e-01, -6.9925e-01,
1.9890e-01, 4.8145e-01, 5.3715e-01],
[ 1.3986e+00, -1.1884e+00, 4.2444e-01, 1.3998e+00, -7.3532e-01,
-9.4360e-01, -1.4932e-01, -1.2199e+00]],
[[ 5.1900e-01, -1.3026e+00, -1.7043e+00, -2.8520e-01, -1.4247e+00,
5.4460e-02, -6.3961e-01, 1.3025e-01],
[ 5.2184e-01, 5.8037e-01, 8.7149e-01, -2.0938e+00, -4.8395e-01,
-2.0589e+00, -2.3386e+00, 7.5332e-02],
[ 2.2620e-01, -3.7655e-01, -3.5929e-01, -1.8760e+00, -2.0527e+00,
1.6695e-01, 2.2314e-01, 2.8254e-01]],
[[ 1.2510e+00, 2.2883e-01, 3.4238e-01, -7.5308e-01, 2.5237e-03,
-3.4200e-01, -2.2455e-01, -1.0249e+00],
[ 3.7932e-02, 1.3092e+00, -5.4694e-01, 1.6446e-01, -1.7779e+00,
-6.6966e-01, -8.4106e-01, 9.3973e-02],
[ 1.6245e+00, 7.8641e-01, 1.4010e+00, -7.4482e-02, -1.4197e+00,
-1.3425e+00, -8.8786e-01, 9.7691e-01]],
[[ 4.8784e-01, -1.5034e+00, -1.6049e-01, 5.3555e-01, 2.2210e-03,
-5.5209e-01, 2.6669e-01, -1.2266e-01],
[ 6.5991e-01, 6.3159e-01, 1.3538e+00, -2.7384e-01, 8.5952e-01,
-3.7105e-01, -1.3350e-01, -2.1770e+00],
[ 3.7015e-01, 1.5858e+00, -4.2269e-01, 2.0528e+00, 3.1429e-01,
-8.5222e-01, -5.5130e-01, -9.6802e-02]]])
Share on:
Twitter
❄ Facebook
❄ Email