PyTorch已為我們實現了大多數常用的非線性激活函數,我們可以像使用任何其他的層那樣使用它們。讓我們快速看一個在PyTorch中使用ReLU激活函數的例子:
在上面這個例子中,輸入是包含兩個正值、兩個負值的張量,對其調用ReLU函數,負值將取為0,正值則保持不變。
現在我們已經了解了構建神經網絡架構的大部分細節,我們來構建一個可用于解決真實問題的深度學習架構。上一章中,我們使用了簡單的方法,因而可以只關注深度學習算法如何工作。后面將不再使用這種方式構建架構,而是使用PyTorch中正常該用的方式構建。
1.PyTorch構建深度學習算法的方式
PyTorch中所有網絡都實現為類,創建PyTorch類的子類要調用nn.Module,并實現__init__和forward方法。在init方法中初始化層,這一點已在前一節講過。在forward方法中,把輸入數據傳給init方法中初始化的層,并返回最終的輸出。非線性函數經常被forward函數直接使用,init方法也會使用一些。下面的代碼片段展示了深度學習架構是如何用PyTrorch實現的:
如果你是Python新手,上述代碼可能會比較難懂,但它全部要做的就是繼承一個父類,并實現父類中的兩個方法。在Python中,我們通過將父類的名字作為參數傳入來創建子類。init方法相當于Python中的構造器,super方法用于將子類的參數傳給父類,我們的例子中父類就是nn.Module。
2.不同機器學習問題的模型架構
待解決的問題種類將基本決定我們將要使用的層,處理序列化數據問題的模型從線性層開始,一直到長短期記憶(LSTM)層。基于要解決的問題類別,最后一層是確定的。使用機器學習或深度學習算法解決的問題通常有三類,最后一層的情況通常如下。
?對于回歸問題,如預測T恤衫的銷售價格,最后使用的是有一個輸出的線性層,輸出值為連續的。
?將一張給定的圖片歸類為T恤衫或襯衫,用到的是sigmoid激活函數,因為它的輸出值不是接近1就是接近0,這種問題通常稱為二分類問題。
?對于多類別分類問題,如必須把給定的圖片歸類為T恤、牛仔褲、襯衫或連衣裙,網絡最后將使用softmax層。讓我們拋開數學原理來直觀理解softmax的作用。舉例來說,它從前一線性層獲取輸入,并輸出給定數量樣例上的概率。在我們的例子中,將訓練它預測每個圖片類別的4種概率。記住,所有概率相加的總和必然為1。
3.損失函數
一旦定義好了網絡架構,還剩下最重要的兩步。一步是評估網絡執行特定的回歸或分類任務時表現的優異程度,另一步是優化權重。
優化器(梯度下降)通常接受一個標量值,因而loss函數應生成一個標量值,并使其在訓練期間最小化。某些用例,如預測道路上障礙物的位置并判斷是否為行人,將需要兩個或更多損失函數。即使在這樣的場景下,我們也需要把損失組合成一個優化器可以最小化的標量。最后一章將詳細討論把多個損失值組合成一個標量的真實例子。
上一章中,我們定義了自己的loss函數。PyTorch提供了經常使用的loss函數的實現。我們看看回歸和分類問題的loss函數。
回歸問題經常使用的loss函數是均方誤差(MSE)。它和前面一章實現的loss函數相同。可以使用PyTorch中實現的loss函數,如下所示:
對于分類問題,我們使用交叉熵損失函數。在介紹交叉熵的數學原理之前,先了解下交叉熵損失函數做的事情。它計算用于預測概率的分類網絡的損失值,損失總和應為1,就像softmax層一樣。當預測概率相對正確概率發散時,交叉熵損失增加。例如,如果我們的分類算法對圖3.5為貓的預測概率值為0.1,而實際上這是只熊貓,那么交叉熵損失就會更高。如果預測的結果和真實標簽相近,那么交叉熵損失就會更低。
下面是用Python代碼實現這種場景的例子。
為了在分類問題中使用交叉熵損失,我們真的不需要擔心內部發生的事情——只要記住,預測差時損失值高,預測好時損失值低。PyTorch提供了loss函數的實現,可以按照如下方式使用。
PyTorch包含的其他一些loss函數如表3.1所示。
表3.1
L1 loss 通常作為正則化器使用;第4章將進一步講述
MSE loss 均方誤差損失,用于回歸問題的損失函數
Cross-entropy loss 交叉熵損失,用于二分類和多類別分類問題
NLL Loss 用于分類問題,允許用戶使用特定的權重處理不平衡數據集
NLL Loss2d 用于像素級分類,通常和圖像分割問題有關
4.優化網絡架構
計算出網絡的損失值后,需要優化權重以減少損失,并改善算法準確率。簡單起見,讓我們看看作為黑盒的優化器,它們接受損失函數和所有的學習參數,并微量調整來改善網絡性能。PyTorch提供了深度學習中經常用到的大多數優化器。如果大家想研究這些優化器內部的動作,了解其數學原理,強烈建議瀏覽以下博客:
PyTorch提供的一些常用的優化器如下:
?ADADELTA
?Adagrad
?Adam
?SparseAdam
?Adamax
?ASGD
?LBFGS
?RMSProp
?Rprop
?SGD
審核編輯 黃昊宇
-
非線性
+關注
關注
1文章
213瀏覽量
23114 -
函數
+關注
關注
3文章
4341瀏覽量
62806 -
pytorch
+關注
關注
2文章
808瀏覽量
13283
發布評論請先 登錄
相關推薦
評論