論文鏈接:https://arxiv.org/abs/2305.17476
代碼鏈接:
https://github.com/ML-GSAI/Understanding-GDA
概述
生成式數據擴增通過條件生成模型生成新樣本來擴展數據集,從而提高各種學習任務的分類性能。然而,很少有人從理論上研究生成數據增強的效果。為了填補這一空白,我們在這種非獨立同分布環境下構建了基于穩定性的通用泛化誤差界。基于通用的泛化界,我們進一步了探究了高斯混合模型和生成對抗網絡的學習情況。
在這兩種情況下,我們證明了,雖然生成式數據增強并不能享受更快的學習率,但當訓練集較小時,它可以在一個常數的水平上提高學習保證,這在發生過擬合時是非常重要的。最后,高斯混合模型的仿真結果和生成式對抗網絡的實驗結果都支持我們的理論結論。
主要的理論結果
2.1 符號與定義
讓 作為數據輸入空間, 作為標簽空間。定義 為 上的真實分布。給定集合 ,我們定義 為去掉第 個數據后剩下的集合, 為把第 個數據換成 后的集合。我們用 表示 total variation distance。
我們讓 為所有從 到 的所有可測函數, 為學習算法,為從數據集 中學到的映射。對于一個學到的映射 和損失函數,真實誤差 被定義為。相應的經驗的誤差 被定義為。
我們文章理論推導采用的是穩定性框架,我們稱算法 相對于損失函數 是一致 穩定的,如果
2.2 生成式數據增強
給定帶有 個 i.i.d. 樣本的 數據集,我們能訓練一個條件生成模型 ,并將學到的分布定義為 。基于訓練得到的條件生成模型,我們能生成一個新的具有 個 i.i.d. 樣本的數據集 。我們記增廣后的數據集 大小為 。我們可以在增廣后的數據集上學到映射 。為了理解生成式數據增強,我們關心泛化誤差 。據我們所知,這是第一個理解生成式數據增強泛化誤差的工作。2.3 一般情況
我們可以對于任意的生成器和一致 穩定的分類器,推得如下的泛化誤差:▲ general一般來說,我們比較關心泛化誤差界關于樣本數 的收斂率。將 看成超參數,并將后面兩項記為 generalization error w.r.t. mixed distribution,我們可以定義如下的“最有效的增強數量”:在這個設置下,并和沒有數據增強的情況進行對比(),我們可以得到如下的充分條件,它刻畫了生成式數據增強何時(不)能夠促進下游分類任務,這和生成模型學習分的能力息息相關:
▲ corollary
2.4 高斯混合模型為了驗證我們理論的正確性,我們先考慮了一個簡單的高斯混合模型的 setting。 混合高斯分布。我們考慮二分類任務 。我們假設真實分布滿足 and 。我們假設 的分布是已知的。 線性分類器。我們考慮一個被 參數化的分類器,預測函數為 。給定訓練集, 通過最小化負對數似然損失函數得到,即最小化學習算法將會推得 ,which satisfies 條件生成模型。我們考慮參數為 的條件生成模型,其中 以及 。給定訓練集,讓 為第 類的樣本量,條件生成模型學到
它們是 和 的無偏估計。我們可以從這個條件模型中進行采樣,即 ,,其中 。 我們在高斯混合模型的場景下具體計算 Theorem 3.1 中的各個項,可以推得
▲ GMM
- 當數據量 足夠時,即使我們采用“最有效的增強數量”,生成式數據增強也難以提高下游任務的分類性能。
- 當數據量 較小的,此時主導泛化誤差的是維度等其他項,此時進行生成式數據增強可以常數級降低泛化誤差,這意味著在過擬合的場景下,生成式數據增強是很有必要的。
2.5 生成對抗網絡
我們也考慮了深度學習的情況。我們假設生成模型為 MLP 生成對抗網絡,分類器為 層 MLP 或者 CNN。損失函數為二元交叉熵,優化算法為 SGD。我們假設損失函數平滑,并且第 層的神經網絡參數可以被 控制。我們可以推得如下的結論:▲ GAN
- 當數據量 足夠時,生成式數據增強也難以提高下游任務的分類性能,甚至會惡化。
- 當數據量 較小的,此時主導泛化誤差的是維度等其他項,此時進行生成式數據增強可以常數級降低泛化誤差,同樣地,這意味著在過擬合的場景下,生成式數據增強是很有必要的。
實驗
3.1 高斯混合模型模擬實驗
我們在混合高斯分布上驗證我們的理論,我們調整數據量 ,數據維度 以及 。實驗結果如下圖所示:
▲ simulation
- 觀察圖(a),我們可以發現當 相對于 足夠大的時候,生成式數據增強的引入并不能明顯改變泛化誤差。
- 觀察圖(d),我們可以發現當 固定時,真實的泛化誤差確實是 階的,且隨著增強數量 的增大,泛化誤差呈現常數級的降低。
- 另外 4 張圖,我們選取了兩種情況,驗證了我們的 bound 能在趨勢上一定程度上預測泛化誤差。
▲ deep
- 在沒有額外數據增強的時候, 較小,分類器陷入了嚴重的過擬合。此時,即使選取的 cDCGAN 很古早(bad GAN),生成式數據增強都能帶來明顯的提升。
- 在有額外數據增強的時候, 充足。此時,即使選取的 StyleGAN 很先進(SOTA GAN),生成式數據增強都難以帶來明顯的提升,在 50k 和 100k 增強的情況下甚至都造成了一致的損害。
-
我們也測試了一個 SOTA 的擴散模型 EDM,發現即使在有額外數據增強的時候,生成式數據增強也能提升分類效果。這意味著擴散模型學習分布的能力可能會優于 GAN。
-
物聯網
+關注
關注
2912文章
44889瀏覽量
375755
原文標題:NeurIPS 2023 | 如何從理論上研究生成式數據增強的效果?
文章出處:【微信號:tyutcsplab,微信公眾號:智能感知與物聯網技術研究所】歡迎添加關注!文章轉載請注明出處。
發布評論請先 登錄
相關推薦
評論