最近發現身邊的一些初學者朋友捧著各種pytorch指南一邊看一邊敲代碼,到最后反而變成了打字員。
敲完代碼一運行,出來結果和書上一對比,哦,是書上的結果,就翻到下一章。
半天就能把一本書都打完,但是合上書好像什么都不記得。有的甚至看了兩三遍,都搭不出一個簡單的網絡來,這種學習方式很不可取。
如果你剛好是這種情況,這篇文章應該能給你一些幫助。如果你已經是進階的水平了,就直接關掉頁面就好了。
pytorch的網絡搭建,比tensorflow簡單很多。格式很好理解。
如果你想做一個網絡,需要先定義一個Class,繼承 nn.Module(這個是必須的,所以先import torch.nn as nn,nn是一個工具箱,很好用),我們把class的名字就叫成Net.
Class Net (nn.Module):
這個Class里面主要寫兩個函數,一個是初始化的__init__函數,另一個是forward函數。我們隨便搭一個,如下:
def __init__(self):
super().__init__()
self.conv1=nn.Conv2d(1,6,5)
self.conv2=nn.Conv2d(6,16,5)
def forward(self, x):
x=F.max_pool2d(F.relu(self.conv1(x)),2)
x=F.max_pool2d(F.relu(self.conv2(x)),2)
return x
__init__里面就是定義卷積層,當然先得super()一下,給父類nn.Module初始化一下。
(Python的基礎知識)在這個里面主要就是定義卷積層的,比如第一層,我們叫它conv1,把它定義成輸入1通道,輸出6通道,卷積核5*5的的一個卷積層。conv2同理。
神經網絡“深度學習”其實主要就是學習卷積核里的參數,像別的不需要學習和改變的,就不用放進去。
比如激活函數relu(),你非要放進去也行,再給它起個名字叫myrelu,也是可以的。forward里面就是真正執行數據的流動。
比如上面的代碼,輸入的x先經過定義的conv1(這個名字是你自己起的),再經過激活函數F.relu()(這個就不是自己起的名字了,最開始應該先import torch.nn.functional as F,F.relu()是官方提供的函數。
當然如果你在__init__里面把relu定義成了我上面說的myrelu,那你這里直接第一句話就成了x=F.max_pool2d(myrelu(self.conv1(x)),2)。
下一步的F.max_pool2d池化也是一樣的,不多廢話了。在一系列流動以后,最后把x返回到外面去。
這個Net的Class定義主要要注意兩點。
第一:是注意前后輸出通道和輸入通道的一致性。不能第一個卷積層輸出4通道第二個輸入6通道,這樣就會報錯。
第二:它和我們常規的python的class還有一些不同,發現了沒有?我們該怎么用這個Net呢?
先定義一個Net的實例(畢竟Net只是一個類不能直接傳參數,output=Net(input)當然不行)
net=Net()
這樣我們就可以往里傳x了,假設你已經有一個要往神經網絡的輸入的數據“input“(這個input應該定義成tensor類型,怎么定義tensor那就自己去看看書了。)在傳入的時候,是:
output=net(input)
看之前的定義:
def __init__(self):
…… def forward(self, x): ……
有點奇怪。好像常規python一般向class里面傳入一個數據x,在class的定義里面,應該是把這個x作為形參傳入__init__函數里的,而在上面的定義中,x作為形參是傳入forward函數里面的。
其實也不矛盾,因為你定義net的時候,是net=Net(),并沒有往里面傳入參數。如果你想初始化的時候按需傳入,就把需要的傳入進去。
只是x是神經網絡的輸入,但是并非是初始化需要的,初始化一個網絡,必須要有輸入數據嗎?
未必吧。只是在傳入網絡的時候,會自動認為你這個x是喂給forward里面的。也就是說,先定義一個網絡的實例net=Net(), 這時調用output=net(input), 可以理解為等同于調用output=net.forward(input), 這兩者可以理解為一碼事。
在網絡定義好以后,就涉及到傳入參數,算誤差,反向傳播,更新權重…確實很容易記不住這些東西的格式和順序。
傳入的方式上面已經介紹了,相當于一次正向傳播,把一路上各層的輸入x都算出來了。
想讓神經網絡輸出的output跟你期望的ground truth差不多,那就是不斷減小二者間的差異,這個差異是你自己定義的,也就是目標函數(object function)或者就是損失函數。
如果損失函數loss趨近于0,那么自然就達到目的了。
損失函數loss基本上沒法達到0,但是希望能讓它達到最小值,那么就是希望它能按照梯度進行下降。
梯度下降的公式,大家應該都很熟悉,不熟悉的話,建議去看一下相關的理論。誰喜歡看公式呢?所以我這里不講。
只是你的輸入是由你來決定的,那神經網絡能學習和決定什么呢?
自然它只能決定每一層卷積層的權重。所以神經網絡只能不停修改權重,比如y=wx+b,x是你給的,它只能改變w,b讓最后的輸出y盡可能接近你希望的y值,這樣損失loss就越來越小。
如果loss對于輸入x的偏導數接近0了,不就意味著到達了一個極值嗎?
而l在你的loss計算方式已經給定的情況下,loss對于輸入x的偏導數的減小,其實只能通過更新參數卷積層參數W來實現(別的它決定不了啊,都是你輸入和提供的)。
所以,通過下述方式實現對W的更新:(注意這些編號,下面還要提)
【1】 先算loss對于輸入x的偏導,(當然網絡好幾層,這個x指的是每一層的輸入,而不是最開始的輸入input)
【2】 對【1】的結果再乘以一個步長(這樣就相當于是得到一個對參數W的修改量)
【3】 用W減掉這個修改量,完成一次對參數W的修改。
說的不太嚴謹,但是大致意思是這樣的。這個過程你可以手動實現,但是大規模神經網絡怎么手動實現?那是不可能的事情。所以我們要利用框架pytorch和工具箱torch.nn。
所以要定義損失函數,以MSEloss為例:
compute_loss=nn.MSELoss()
明顯它也是個類,不能直接傳入輸入數據,所以直接loss=nn.MSEloss(target,output)是不對的。需要把這個函數賦一個實例,叫成compute_loss。
之后就可以把你的神經網絡的輸出,和標準答案target傳入進去:
loss=compute_loss(target,output)
算出loss,下一步就是反向傳播:
loss.backward()
這一步其實就是把【1】給算完了,得到對參數W一步的更新量,算是一次反向傳播。
這里就注意了,loss.backward()是啥玩意?如果是自己的定義的loss(比如你就自己定義了個def loss(x,y):return y-x )這樣肯定直接backward會出錯。所以應當用nn里面提供的函數。
當然搞深度學習不可能只用官方提供的loss函數,所以如果你要想用自己的loss函數。
必須也把loss定義成上面Net的樣子(不然你的loss不能反向傳播,這點要注意,注:這點是以前寫的,很久之前的版本不行,現在都可以了,基本不太需要這樣了)。
也是繼承nn.Module,把傳入的參數放進forward里面,具體的loss在forward里面算,最后return loss。__init__()就空著,寫個super().__init__就行了。
在反向傳播之后,第【2】和第【3】怎么實現?就是通過優化器來實現。讓優化器來自動實現對網絡權重W的更新。
所以在Net定義完以后,需要寫一個優化器的定義(選SGD方式為例):
from torch import optimoptimizer=optim.SGD(net.parameters(),lr=0.001,momentum=0.9)
同樣,優化器也是一個類,先定義一個實例optimizer,然后之后會用。
注意在optimizer定義的時候,需要給SGD傳入了net的參數parameters,這樣之后優化器就掌握了對網絡參數的控制權,就能夠對它進行修改了。
傳入的時候把學習率lr也傳入了。
在每次迭代之前,先把optimizer里存的梯度清零一下(因為W已經更新過的“更新量”下一次就不需要用了)
optimizer.zero_grad()
在loss.backward()反向傳播以后,更新參數:
optimizer.step()
所以我們的順序是:
1.先定義網絡:寫網絡Net的Class,聲明網絡的實例net=Net(),
2.定義優化器
optimizer=optim.xxx(net.parameters(),lr=xxx),
3.再定義損失函數(自己寫class或者直接用官方的,compute_loss=nn.MSELoss()或者其他。
4.在定義完之后,開始一次一次的循環:
①先清空優化器里的梯度信息,optimizer.zero_grad();
②再將input傳入,output=net(input) ,正向傳播
③算損失,loss=compute_loss(target,output) ##這里target就是參考標準值GT,需要自己準備,和之前傳入的input一一對應
④誤差反向傳播,loss.backward()
⑤更新參數,optimizer.step()
這樣就實現了一個基本的神經網絡。大部分神經網絡的訓練都可以簡化為這個過程,無非是傳入的內容復雜,網絡定義復雜,損失函數復雜,等等等等。
說的有問題的地方感謝指正!
編輯:黃飛
評論
查看更多