與我們大多數從頭開始的實施一樣, 第 9.5 節旨在深入了解每個組件的工作原理。但是,當您每天使用 RNN 或編寫生產代碼時,您會希望更多地依賴于減少實現時間(通過為通用模型和函數提供庫代碼)和計算時間(通過優化這些庫實現)。本節將向您展示如何使用深度學習框架提供的高級 API 更有效地實現相同的語言模型。和以前一樣,我們首先加載時間機器數據集。
import torch from torch import nn from torch.nn import functional as F from d2l import torch as d2l
from mxnet import np, npx from mxnet.gluon import nn, rnn from d2l import mxnet as d2l npx.set_np()
from flax import linen as nn from jax import numpy as jnp from d2l import jax as d2l
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf from d2l import tensorflow as d2l
9.6.1. 定義模型
我們使用由高級 API 實現的 RNN 定義以下類。
class RNN(d2l.Module): #@save """The RNN model implemented with high-level APIs.""" def __init__(self, num_inputs, num_hiddens): super().__init__() self.save_hyperparameters() self.rnn = nn.RNN(num_inputs, num_hiddens) def forward(self, inputs, H=None): return self.rnn(inputs, H)
Specifically, to initialize the hidden state, we invoke the member method begin_state. This returns a list that contains an initial hidden state for each example in the minibatch, whose shape is (number of hidden layers, batch size, number of hidden units). For some models to be introduced later (e.g., long short-term memory), this list will also contain other information.
class RNN(d2l.Module): #@save """The RNN model implemented with high-level APIs.""" def __init__(self, num_hiddens): super().__init__() self.save_hyperparameters() self.rnn = rnn.RNN(num_hiddens) def forward(self, inputs, H=None): if H is None: H, = self.rnn.begin_state(inputs.shape[1], ctx=inputs.ctx) outputs, (H, ) = self.rnn(inputs, (H, )) return outputs, H
Flax does not provide an RNNCell for concise implementation of Vanilla RNNs as of today. There are more advanced variants of RNNs like LSTMs and GRUs which are available in the Flax linen API.
class RNN(nn.Module): #@save """The RNN model implemented with high-level APIs.""" num_hiddens: int @nn.compact def __call__(self, inputs, H=None): raise NotImplementedError
class RNN(d2l.Module): #@save """The RNN model implemented with high-level APIs.""" def __init__(self, num_hiddens): super().__init__() self.save_hyperparameters() self.rnn = tf.keras.layers.SimpleRNN( num_hiddens, return_sequences=True, return_state=True, time_major=True) def forward(self, inputs, H=None): outputs, H = self.rnn(inputs, H) return outputs, H
繼承自9.5 節RNNLMScratch中的類 ,下面的類定義了一個完整的基于 RNN 的語言模型。請注意,我們需要創建一個單獨的全連接輸出層。RNNLM
class RNNLM(d2l.RNNLMScratch): #@save """The RNN-based language model implemented with high-level APIs.""" def init_params(self): self.linear = nn.LazyLinear(self.vocab_size) def output_layer(self, hiddens): return self.linear(hiddens).swapaxes(0, 1)
class RNNLM(d2l.RNNLMScratch): #@save """The RNN-based language model implemented with high-level APIs.""" def init_params(self): self.linear = nn.Dense(self.vocab_size, flatten=False) self.initialize() def output_layer(self, hiddens): return self.linear(hiddens).swapaxes(0, 1)
class RNNLM(d2l.RNNLMScratch): #@save """The RNN-based language model implemented with high-level APIs.""" training: bool = True def setup(self): self.linear = nn.Dense(self.vocab_size) def output_layer(self, hiddens): return self.linear(hiddens).swapaxes(0, 1) def forward(self, X, state=None): embs = self.one_hot(X) rnn_outputs, _ = self.rnn(embs, state, self.training) return self.output_layer(rnn_outputs)
class RNNLM(d2l.RNNLMScratch): #@save """The RNN-based language model implemented with high-level APIs.""" def init_params(self): self.linear = tf.keras.layers.Dense(self.vocab_size) def output_layer(self, hiddens): return tf.transpose(self.linear(hiddens), (1, 0, 2))
9.6.2. 訓練和預測
在訓練模型之前,讓我們使用隨機權重初始化的模型進行預測。鑒于我們還沒有訓練網絡,它會產生無意義的預測。
data = d2l.TimeMachine(batch_size=1024, num_steps=32) rnn = RNN(num_inputs=len(data.vocab), num_hiddens=32) model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1) model.predict('it has', 20, data.vocab)
'it hasgggggggggggggggggggg'
data = d2l.TimeMachine(batch_size=1024, num_steps=32) rnn = RNN(num_hiddens=32) model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1) model.predict('it has', 20, data.vocab)
'it hasxlxlxlxlxlxlxlxlxlxl'
data = d2l.TimeMachine(batch_size=1024, num_steps=32) rnn = RNN(num_hiddens=32) model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1) model.predict('it has', 20, data.vocab)
'it hasnvjdtagwbcsxvcjwuyby'
接下來,我們利用高級 API 訓練我們的模型。
trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1) trainer.fit(model, data)
trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1) trainer.fit(model, data)
with d2l.try_gpu(): trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1) trainer.fit(model, data)
與第 9.5 節相比,該模型實現了相當的困惑度,但由于實現優化,運行速度更快。和以前一樣,我們可以在指定的前綴字符串之后生成預測標記。
model.predict('it has', 20, data.vocab, d2l.try_gpu())
'it has and the time trave '
model.predict('it has', 20, data.vocab, d2l.try_gpu())
'it has and the thi baid th'
model.predict('it has', 20, data.vocab)
'it has our in the time tim'
9.6.3. 概括
深度學習框架中的高級 API 提供標準 RNN 的實現。這些庫可幫助您避免浪費時間重新實現標準模型。此外,框架實施通常經過高度優化,與從頭開始實施相比,可顯著提高(計算)性能。
9.6.4. 練習
您能否使用高級 API 使 RNN 模型過擬合?
使用 RNN實現第 9.1 節的自回歸模型。
-
神經網絡
+關注
關注
42文章
4776瀏覽量
100952 -
pytorch
+關注
關注
2文章
808瀏覽量
13283
發布評論請先 登錄
相關推薦
評論