之前帶大家一起使用Keras訓練了一個GRU模型,并使用mnist的手寫字體數據集進行了驗證。本期小編將繼續帶來一篇擴展,即GRU模型的測試方法。盡管我們將其當作和CNN類似的方式,一次性傳給他固定長度的數據,但在具體實現上來說,還是另有門道的。讓我們慢慢講來。
首先回顧前面我們最終訓練并導出的測試模型:
注意紅色標注的位置,這就是一個典型的GRU節點:
模型的輸入是28*28,代表的含義是:時間步*特征維度,簡單來說,就是一次性送入模型多久的數據,即時間步實際上是一個時間單位。例如,我們想測試1s的數據進行檢測。那么可以將其提取出來10*28的特征向量,那么每一個特征代表的就是1s/10即100ms的音頻特征。了解了時間步,再回到模型本身,這里就是其中一種模型推理形式。即一次性將所有的數據都送進去,即1s對應的特征數據。然后計算得出一個結果。好處是:所見所得,和CNN類似,缺點是:必須等待1s數據,且循環time_step次。
那有沒有替代方案呢?當然,那就是小編要提到的另一種,我們在導出模型時候將time_step設置為1,并且設置stateful=True,同時將time_step=N的模型權重設置到新模型上。這里的stateful=True要注意,因為我們將之前連續的time_step拆成了獨立的,因此需要讓模型記住前一次的中間狀態。
可能這里大家有些疑問,為什么兩個模型time_step不同,權重竟然通用。這就要說GRU模型的特殊性了,我們剛才看到的被展開的小GRU節點,其實是權重共享的。也就是說,不管展開多少次,他們的權重不會變(這個讀者可以打開模型自行查看驗證)。因此,就可以用如下代碼生成新模型,并設置權重:
# 構建新模型 new_model = Sequential() new_model.add(GRU(128, batch_input_shape=(1, 1, 28), unroll=True, stateful=True)) new_model.add(Dense(10, activation='softmax')) new_model.set_weights(model.get_weights()) |
讓我們看看模型的樣子:
是不是看起來非常清爽,請注意右下角那個AssignVariable,這個就是為了保存當前狀態,在下一次推理可以直接使用上一次的狀態。需要注意的是,由于模型輸入變成了一個time step,即1*28,在送入模型前,要注意一下。后續處理部分,依舊是FullyConnected+Softmax的形式,其他沒有改變,照常即可。
至此,所有關于GRU模型的介紹以及使用就全部講完了。MCU端的部署要靠大家自行體驗了,因為模型本身實際上可以使用和CNN一樣的推理方案,只是內部結構不同而已,希望對大家有所幫助!
恩智浦致力于打造安全的連接和基礎設施解決方案,為智慧生活保駕護航。
-
mcu
+關注
關注
146文章
17263瀏覽量
351981 -
測試
+關注
關注
8文章
5360瀏覽量
126873 -
Gru
+關注
關注
0文章
12瀏覽量
7499 -
模型
+關注
關注
1文章
3279瀏覽量
48980 -
cnn
+關注
關注
3文章
353瀏覽量
22269
原文標題:深入GRU:解鎖模型測試新維度
文章出處:【微信號:NXP_SMART_HARDWARE,微信公眾號:恩智浦MCU加油站】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論