我們繼續(xù)實(shí)現(xiàn) 15.1 節(jié)中定義的 skip-gram 模型。然后我們將在 PTB 數(shù)據(jù)集上使用負(fù)采樣來預(yù)訓(xùn)練 word2vec。首先,讓我們通過調(diào)用函數(shù)來獲取數(shù)據(jù)迭代器和這個(gè)數(shù)據(jù)集的詞匯表 ,這在第 15.3 節(jié)d2l.load_data_ptb
中有描述
import math
import torch
from torch import nn
from d2l import torch as d2l
batch_size, max_window_size, num_noise_words = 512, 5, 5
data_iter, vocab = d2l.load_data_ptb(batch_size, max_window_size,
num_noise_words)
15.4.1。Skip-Gram 模型
我們通過使用嵌入層和批量矩陣乘法來實(shí)現(xiàn) skip-gram 模型。首先,讓我們回顧一下嵌入層是如何工作的。
15.4.1.1。嵌入層
如第 10.7 節(jié)所述,嵌入層將標(biāo)記的索引映射到其特征向量。該層的權(quán)重是一個(gè)矩陣,其行數(shù)等于字典大小 ( input_dim
),列數(shù)等于每個(gè)標(biāo)記的向量維數(shù) ( output_dim
)。一個(gè)詞嵌入模型訓(xùn)練好之后,這個(gè)權(quán)重就是我們所需要的。
Parameter embedding_weight (torch.Size([20, 4]), dtype=torch.float32)
embed = nn.Embedding(input_dim=20, output_dim=4)
embed.initialize()
embed.weight
Parameter embedding0_weight (shape=(20, 4), dtype=float32)
嵌入層的輸入是標(biāo)記(單詞)的索引。對(duì)于任何令牌索引i,它的向量表示可以從ith嵌入層中權(quán)重矩陣的行。由于向量維度 ( output_dim
) 設(shè)置為 4,因此嵌入層返回形狀為 (2, 3, 4) 的向量,用于形狀為 (2, 3) 的標(biāo)記索引的小批量。
tensor([[[-0.6501, 1.3547, 0.7968, 0.3916],
[ 0.4739, -0.0944, 1.2308, 0.6457],
[ 0.4539, 1.5194, 0.4377, -1.5122]],
[[-0.7032, -0.1213, 0.2657, -0.6797],
[ 0.2930, -0.6564, 0.8960, -0.5637],
[-0.1815, 0.9487, 0.8482, 0.5486]]], grad_fn=<EmbeddingBackward0>)
array([[[ 0.01438687, 0.05011239, 0.00628365, 0.04861524],
[-0.01068833, 0.01729892, 0.02042518, -0.01618656],
[-0.00873779, -0.02834515, 0.05484822, -0.06206018]],
[[ 0.06491279, -0.03182812, -0.01631819, -0.00312688],
[ 0.0408415 , 0.04370362, 0.00404529, -0.0028032 ],
[ 0.00952624, -0.01501013, 0.05958354, 0.04705103]]])
15.4.1.2。定義前向傳播
在正向傳播中,skip-gram 模型的輸入包括形狀為(批大小,1)的中心詞索引和 形狀為(批大小,)center
的連接上下文和噪聲詞索引,其中定義在 第 15.3.5 節(jié). 這兩個(gè)變量首先通過嵌入層從標(biāo)記索引轉(zhuǎn)換為向量,然后它們的批量矩陣乘法(在第 11.3.2.2 節(jié)中描述)返回形狀為(批量大小,1, )的輸出 。輸出中的每個(gè)元素都是中心詞向量與上下文或噪聲詞向量的點(diǎn)積。contexts_and_negatives
max_len
max_len
max_len
skip_gram
讓我們?yōu)橐恍┦纠斎?/font>打印此函數(shù)的輸出形狀。
15.4.2。訓(xùn)練
在用負(fù)采樣訓(xùn)練skip-gram模型之前,我們先定義它的損失函數(shù)。
15.4.2.1。二元交叉熵?fù)p失
根據(jù)15.2.1節(jié)負(fù)采樣損失函數(shù)的定義,我們將使用二元交叉熵?fù)p失。
class SigmoidBCELoss(nn.Module):
# Binary cross-entropy loss with masking
def __init__(self):
super().__init__()
def forward(self, inputs, target, mask=None):
out = nn.functional.binary_cross_entropy_with_logits(
inputs, target, weight=mask, reduction="none")
return out.mean(dim=1)
loss = SigmoidBCELoss()
回想我們?cè)诘?15.3.5 節(jié)中對(duì)掩碼變量和標(biāo)簽變量的描述 。下面計(jì)算給定變量的二元交叉熵?fù)p失。
tensor([0.9352, 1.8462])
pred = np.array([[1.1, -2.2, 3.3, -4.4]] * 2)
label = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
mask = np.array([[1,
評(píng)論
查看更多