最近有很多關于數據是否是新模型驅動 [1] [2] 的討論,無論結論如何,都無法改變我們在實際工作中獲取數據成本很高這一事實(人工費用、許可證費用、設備運行時間等方面)。
因此,在機器學習項目中,一個關鍵的問題是,為了達到比如分類器準確度等特定性能指標,我們需要多少訓練數據才夠。訓練數據多少的問題在相關文獻中也稱為樣本復雜度。
在這篇文章中,我們將從回歸分析開始到深度學習等領域,快速而廣泛地回顧目前關于訓練數據多少的經驗和相關的研究結果。具體來說,我們將:
說明回歸任務和計算機視覺任務訓練數據的經驗范圍;
給定統計檢驗的檢驗效能,討論如何確定樣本數量。這是一個統計學的話題,然而,由于它與確定機器學習訓練數據量密切相關,因此也將包含在本討論中;
展示統計理論學習的結果,說明是什么決定了訓練數據的多少;
給出下面問題的答案:隨著訓練數據的增加,模型性能是否會繼續改善?在深度學習的情況下又會如何?
提出一種在分類任務中確定訓練數據量的方法;
最后,我們將回答這個問題:增加訓練數據是處理數據不平衡的最佳方式嗎?
01
訓練數據量的經驗范圍
首先讓我們看一些廣泛使用的,用來確定訓練數據量的經驗方法,根據我們使用的模型類型:
回歸分析:根據 1/10 的經驗規則,每個預測因子 [3] 需要 10 個樣例。在 [4] 中討論了這種方法的其他版本,比如用 1/20 來處理回歸系數減小的問題,在 [5] 中提出了一個令人興奮的二元邏輯回歸變量。
具體地說,作者通過考慮預測變量的數量、總體樣本量以及正樣本量/總體樣本量的比例來估計訓練數據的多少。
計算機視覺:對于使用深度學習的圖像分類,經驗法則是每一個分類需要 1000 幅圖像,如果使用預訓練的模型 [6],這個需求可以顯著下降。
02
假設檢驗中樣本大小的確定
假設檢驗是數據科學家用來檢驗群體差異、確定新藥物療效等的工具之一??紤]到進行測試的能力,這里通常需要確定樣本大小。
讓我們來看看這個例子:一家科技巨頭搬到了 A 市,那里的房價大幅上漲。一位記者想知道,現在公寓的平均價格是多少。
如果給定公寓價格標準差為 60K,可接受的誤差范圍為 10K,他應該統計多少套公寓的價格然后進行平均,才能使結果有 95% 的置信度?
計算的公式如下:N 是他需要的樣本量,1.96 是 95% 置信度所對應的標準正態分布的個數:
樣本容量估計
根據上面的等式,記者需要考慮約 138 套公寓的價格即可。
上面的公式會根據具體的測試任務而變化,但它總是包括置信區間、可接受的誤差范圍和標準差度量。在[7]中可以找到關于這個主題的更好的討論。
03
訓練數據規模的統計學習理論
讓我們首先介紹一下著名的 Vapnik-Chevronenkis 維度 ( VC 維) [8]。VC 維是模型復雜度的度量,模型越復雜,VC 維越大。在下一段中,我們將介紹一個用 VC 表示訓練數據大小的公式。
首先,讓我們看一個經常用于展示 VC 維如何計算的例子:假設我們的分類器是二維平面上的一條直線,有 3 個點需要分類。
無論這 3 個點的正/負組合是什么(都是正的、2個正的、1個正的,等等),一條直線都可以正確地分類/區分這些正樣本和負樣本。
我們說線性分類器可以區分所有的點,因此,它的 VC 維至少是 3,又因為我們可以找到4個不能被直線準確區分的點的例子,所以我們說線性分類器的 VC 維正好是3。結果表明,訓練數據大小 N 是 VC 的函數 [8]:
從 VC 維估計訓練數據的大小
其中 d 為失效概率,epsilon 為學習誤差。因此,正如 [9] 所指出的,學習所需的數據量取決于模型的復雜度。一個明顯的例子是眾所周知的神經網絡對訓練數據的貪婪,因為它們非常復雜。
04
隨著訓練數據的增加,模型性能會繼續提高嗎?在深度學習的情況下又會怎樣?
學習曲線
上圖展示了在傳統機器學習 [10] 算法(回歸等)和深度學習 [11] 的情況下,機器學習算法的性能隨著數據量的增加而如何變化。
具體來說,對于傳統的機器學習算法,性能是按照冪律增長的,一段時間后趨于平穩。 文獻 [12]-[16],[18] 的研究展示了對于深度學習,隨著數據量的增加性能如何變化。
圖1顯示了當前大多數研究的共識:對于深度學習,根據冪次定律,性能會隨著數據量的增加而增加。
例如,在文獻 [13] 中,作者使用深度學習技術對3億幅圖像進行分類,他們發現隨著訓練數據的增加模型性能呈對數增長。
讓我們看看另一些在深度學習領域值得注意的,與上述矛盾的結果。具體來說,在文獻 [15] 中,作者使用卷積網絡來處理 1 億張 Flickr 圖片和標題的數據集。
對于訓練集的數據量,他們報告說,模型性能會隨著數據量的增加而增加,然而,在 5000 萬張圖片之后,它就停滯不前了。
在文獻[16]中,作者發現圖像分類準確度隨著訓練集的增大而增加,然而,模型的魯棒性在超過與模型特定相關的某一點后便開始下降。
05
在分類任務中確定訓練數據量的方法
眾所周知的學習曲線,通常是誤差與訓練數據量的關系圖。[17] 和 [18] 是了解機器學習中學習曲線以及它們如何隨著偏差或方差的增加而變化的參考資料。Python 在 scikit-learn [17] 也中提供了一個學習曲線的函數。
在分類任務中,我們通常使用一個稍微變化的學習曲線形式:分類準確度與訓練數據量的關系圖。
確定訓練數據量的方法很簡單:首先根據任務確定一個學習曲線形式,然后簡單地在圖上找到所需分類準確度對應的點。例如,在文獻 [19]、[20] 中,作者在醫學領域中使用了學習曲線法,并用冪律函數表示:
學習曲線方程
上式中 y 為分類準確度,x 為訓練數據,b1、b2 分別對應學習率和衰減率。參數的設置隨問題的不同而變化,可以用非線性回歸或加權非線性回歸對它們進行估計。
06
增加訓練數據是處理數據不平衡的最好方法嗎?
這個問題在文獻 [9] 中得到了解決。作者提出了一個有趣的觀點:在數據不平衡的情況下,準確性并不是衡量分類器性能的最佳指標。
原因很直觀:讓我們假設負樣本是占絕大多數,然后如果我們在大部分時間里都預測為負樣本,就可以達到很高的準確度。
相反,他們建議準確度和召回率(也稱為靈敏度)是衡量數據不平衡性能的最合適指標。除了上述明顯的準確度問題外,作者還認為,測量精度對不平衡區域的內在影響更大。
例如,在醫院的警報系統 [9] 中,高精確度意味著當警報響起時,病人很可能確實有問題。
選擇適當的性能測量方法,作者比較了在 imbalanced-learn [21] (Python scikit-learn 庫)中的不平衡校正方法和簡單的使用一個更大的訓練數據集。
具體地說,他們在一個 50,000 個樣本的藥物相關的數據集上,使用 imbalance-correction 中的K近鄰方法進行數據不平衡校正,這些不平衡校正技術包括欠采樣、過采樣和集成學習等,然后在與原數據集相近的 100 萬數據集上訓練了一個神經網絡。
作者重復實驗了 200 次,最終的結論簡單而深刻:在測量準確度和召回率方面,沒有任何一種不平衡校正技術可以與增加更多的訓練數據相媲美。
-
計算機視覺
+關注
關注
8文章
1698瀏覽量
46030 -
深度學習
+關注
關注
73文章
5507瀏覽量
121265
原文標題:深度學習,怎么知道你的訓練數據真的夠了?
文章出處:【微信號:cas-ciomp,微信公眾號:中科院長春光機所】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論