破解ChatGPT驚人耗電DeepMind新演算法訓練提效13倍,能耗暴降10倍
ChatGPT早已成為世界耗能大戶:一天用掉超50萬度電,相當於1.7萬個美國家庭的用電量!然而,大模型對能源的吞噬,遠不止於此。國際能源總署(IEA)預測,從2022年到2026年,資料中心的用電量將會翻倍。
隨著AI計算需求的膨脹,還需要用水來冷卻計算系統。研究稱,微軟用水量從2021年到22年飆升了34%,ChatGPT每處理5-50個提示就會消耗接近半公升水。
針對這種現狀,我們有更好的解決策略嗎?
最近,GoogleDeepMind研究團隊提出了一種加快AI訓練的新方法—多模態對比學習與聯合範例選擇(JEST),大大減少了所需的運算資源和時間。
JEST以13倍更少的迭代次數,以及10倍更少的計算量,超越了最先進的模型!
預訓練的參考模型,已經學習了什麼樣的資料是有「優質的」或「有用的」。然後透過模型,來引導資料選擇那些精心篩選的小型資料集。
這項發現揭示了,資料篩選水準可以作為評判Scaling Law的一個新維度。
網友激動表示,「我沒想到這麼快就會發生。模型能夠自主選擇訓練數據的能力是巨大的,因為它使訓練變得顯著更容易,你不再需要猜測什麼是高質量的訓練數據,你有一個能夠『理解』什麼樣的資料對自身學習最有價值的模型」。
前Google、蘋果軟體工程師稱讚道,這項研究非常令人印象深刻。
從「超級batch」篩選數據
無論是語言、視覺或多模態模型,資料品質是預訓練表現的重要驅動因素。例如Phi-3、Gemma 2等模型的成功讓我們看到了,更少、更高品質的數據有可能實現更強大的效能。
要篩選出高品質的數據,數據管道的建立就成為重要的工作。現有的方法大致可分為兩種:1)手動管理2)基於模型的資料管理,以正在訓練模型的特徵選擇高品質資料。
前者成本高且難以擴展,後者則可望為多模態LLM實現Scaling Law。
然而,現有方法忽略了一個事實。
如果僅在單一資料點的層面進行篩選,就沒有考慮到資料集以及batch的總體組成。畢竟,訓練資料是以batch為單位,資料點之間的依賴性不可忽視。
許多電腦視覺的研究都曾表明,hard negatives(表達空間中相近但標籤不同的樣本)相比可被平凡解的數據簇,能提供更有效的學習訊號。
那麼如何讓模型以batch為單位篩選資料呢?
論文提出的JEST演算法正是要解決這個問題,原理很好理解:就是直接從「超級batch」篩選出「子batch」。
技術介紹
用數學語言來描述這個問題,就是從大小為B的「超級batch」中提取出與學習最相關的子batch ℬ={,∈[1,…,]}⊂,過濾比率可以寫作=1−/ 。
先前的優先採樣(prioritized sampling)會使用基於模型的評分函數對每個資料點評分,再按比例採樣。 JEST則直接對整個子batch評分,再依照batch等級的分數取樣。
一種最直觀的啟發式方法就是在現有模型參數: hard(ℬ|)=ℓ(ℬ|) 中,直接選擇損失值最高的batch,這種方法可被稱之為“硬學習” (hard learner)。
這種方法具有丟棄瑣碎資料的理想屬性,已被證明適用於小型、乾淨的資料集;然而對於較大、較少管理的資料集往往弊大於利,因為它依舊會取樣到雜訊資料。
另一種方法常用於多模態,使用具有參數∗:^easy(ℬ|∗)=−ℓ(ℬ|∗) 的參考模型為預訓練模型取樣資料。但作者依舊否定了這個方案,因為它無法直接反映模型目前的狀態,可能過度依賴參考模型的選擇,而且不容易擴展。
最後,論文選擇借鏡ICML 2022年的一篇論文中提到的方法,將上述兩方面的評分結合起來:^learn(ℬ|,∗)=hard(ℬ|)+^easy(ℬ| ∗)=ℓ(ℬ|)−ℓ(ℬ|∗),並將此啟發式方法稱為「可學習性評分」(learnability score)。
其中,batch上的損失值ℓ(ℬ|)是各數據點總和,使用sigmoid對比損失函數計算(sigmoid-contrastive loss),因為相比softmax對比損失而言,它的擴展性更強。
由於batch上的對比損失可以分解為每個樣本的條件損失之和,因此可學習性評分可被分解為單一樣本可學習性分數(|,∗,ℬ)之和,寫作:
使用的順序採樣方法則受到了block Gibbs採樣的啟發。在第n次迭代、對第B_n個batch進行取樣時,依據下列機率公式對區塊{X_k}進行無替換取樣:
將X_k區塊加入B_n來更新目前採樣的batch,直到迭代數n=N時終止。演算法的整體流程如下圖所示:
實驗中發現,使用迭代數N=16且每次迭代時獨立取樣b/N=2048個樣本時,就足以恢復出學習性非常高的batch。
可學習性評分涉及使用參考模型為資料點評分,先前的方法慣常使用額外的小型模型,但這會增加每次迭代的計算成本,降低整體FLOP效率增益。
因此論文使用了線上模型近似的方法以及效率較高的FlexiViT架構,只使用降低解析度的32×32的patch來評估“超級batch”,與全解析度、patch大小為16×16的方法相比減少了72%的FLOP,以及67%的掛鐘時間(wall-clock 時間)。
此外,論文也提出了進行多解析度訓練的技巧。將每個batch隨機分成兩半,使用不同解析度編碼後再拼接起來,提升了評分流程和訓練的效率。
下圖詳細描述了全解析度JEST和多重解析度Flexi-JEST方法的偽代碼實作。
所有JEST實驗都在WebLI資料集上運行,包含經過寬鬆過濾的十億規模的英語圖像-文字對,參考模型的訓練則使用其中經過高品質過濾100M大小的子集(被稱為WebLI-curated) 。
在WebLI的基礎上,作者還額外從網路上抓取了6億個文字-圖像對並經過同樣強度的過濾,組成WebLI-curated++資料集訓練參考模型,拓展出JEST++/FlexiJEST++方法,來探索對數據管理的擴展。
論文所報告的平均性能包括4個多模態規範基準:ImageNet 0-Shot和10-Shot 分類以及COCO圖像到文本和文本到圖像的top-1檢索。
實驗結果
圖1中可以看到,使用JEST或FlexiJEST方法最明顯的優勢就是效率提升。
左圖中,相較於原有的SigLIP基準模型,JEST++可以在訓練資料量減少13.1×的情況下達到相同準確率。即使考慮到額外引入的評分成本,也有近10×的FLOP效率提升(中圖)。
右圖展現了JEST++/FlexiJEST++(綠色)與先前方法(灰色)的比較,相比CLIP、EVA-CLIP經典模型實現了計算成本和性能的雙重提升。
左圖和中圖的平均準確率由8個下游任務得出,右圖效能由ImageNet和COCO基準測試得出
產生可學習batch
研究人員首先評估了JEST在選擇可學習batch的效果。
為了直觀地理解此方法,作者們先將可學習性矩陣進行視覺化,也就是學習模型和參考模型之間,對batch中所有範例對的損失差異。
JEST就是依照範例子矩陣的可學習性總和比例進行取樣。
由於矩陣明顯非對角關係(圖2,左),獨立選擇顯然是次優的。
經過少量迭代(對應於用N=16塊填充batch),作者發現子batch的可學習性快速增加,達到了需要數千次迭代的暴力吉布斯採樣(Gibbs sampling )所提取batch的可學習性(圖2,中)。
對於0.5、0.8和0.9的過濾比例,他們從大小分別為65,536、163,840和327,680的超級batch中選擇32,768個範例的子batch。
在圖2右側,研究者也發現子batch的可學習性隨著更大的過濾比例而增加。
總之,JEST演算法是在訓練過程中選擇高度可學習batch的有效,且有效率的方法。
加速多模態學習
接下來,研究者使用JEST演算法選擇的可學習batch,檢驗訓練模型的效果。
所有實驗都使用在WebLI-curated上訓練的參考模型,這是一個ViT-B/16和Bert-B圖像-文字雙編碼器,30億訓練樣本,採用sigmoid對比損失函數。
圖3(左)顯示了在訓練過程中多個下游任務(ImageNet 0-Shot/10-Shot準確率和COCO圖像到文字/文字到圖像檢索)的平均效能。
結果也發現,JEST顯著加速了學習過程。
在使用50%、80%和90%的過濾比例時,分別只需20億、10億和6.7億訓練樣本就達到了30億均勻基準的最終性能。
在更大的過濾比例下,坐著觀察到類似於更大batch size時的訓練不穩定性,需要修改Adam優化器(β2 = 0.95)以穩定訓練,這表明JEST的資料篩選可以被視為增加了有效batch size。
在最終效能方面,當過濾90%的資料時,JEST也帶來了高達6%的顯著提升(圖3,中間,藍色曲線)。
值得注意的是,這種scaling行為這種表現提昇在獨立樣本選擇方法中,並沒有被觀察到。 (圖3,中間,橘色曲線)。
最後,研究者也評估JEST是否也改善了,除可學習性之外的其他優先標準。
圖3右側顯示了使用easy-reference優先選擇的模型在不同濾波比例下的表現。
與基於可學習性的優先選擇一致,JEST仍優於獨立樣本選擇,特別是在高過濾比例下(在這種情況下,獨立樣本選擇導致性能下降)。
優先選擇具有最高損失的數據產生了較小的收益,並且隨著過濾更多數據而更快地退化(圖10)。
由於基於可學習性的JEST產生了最佳的scaling行為,研究人員在後續實驗中保留了這個標準。
多解析度訓練和線上batch選擇之間的協同效應
隨著數據batch中被過濾的比例增加,基於可學習性評分的JEST變得更有效率。
然而,評分的成本會帶來顯著的提升:過濾超級batch 80%的資料會導致每次迭代的浮點運算量是IID訓練的4倍,或是在快取參考模型得分時是2.3倍。
儘管JEST在訓練迭代次數方面(以下簡稱「訓練效率」)顯著提高了效率,但額外的評分浮點運算降低了其相對於IID基準的計算效率(圖1,左vs右)。
因此,作者也研究了一種計算效率更高的變體,稱為Flexi-JEST,它使用多解析度訓練和低解析度評分,將總開銷降低到僅比基準高10%(圖4,左)。
這些近似方法對效能有什麼影響?
如預期的那樣,Flexi-JEST的每次迭代性能相對於JEST有所下降,但仍比IID有顯著的加速(圖1,左;圖4,中)。
然而,考慮到總浮點運算量的減少,每次迭代性能的下降是非常有利的:最好的Flexi-JEST模型與40B Siglip運行產生相同的平均性能,但浮點運算量減少了9.9倍,比全解析度JEST少2倍(圖1,右;圖4,中)。
這些實驗顯示了多解析度訓練和聯合範例選擇之間的協同效應,前者為加速後者提供了高效和準確的評分能力。
實驗結果,也指出了資料策劃策略的帕累托前沿(pareto front)。
如果以計算為代價來最大化訓練速度或訓練效率,全解析度JEST方法相對於可比較的IID訓練運行,可以產生高達13倍的加速。
實現強大數據品質引導
可學習性評分的核心是,一個在人類選擇的小型、精心篩選的資料集上,訓練的參考模型。
JEST的性能如何隨不同的篩選策略(在品質和數量之間權衡)而變化?
此外,JEST訓練的改進是否與參考模型的表現相關,還是這些指標是分離的?
理解質與量的權衡
研究人員探索了三種規模的資料篩選,每種都是原始WebLI資料集的一個子集:
– 弱篩選(十億級規模):使用圖像-文字對齊(ITA)過濾器。
– 中度篩選(3億級規模):使用ITA過濾器或文字品質(TQ)過濾器。
– 強篩選(1億級規模):結合使用TQ、ITA和額外的影像品質(aesthetic)濾鏡。
在整個過程中,作者將這個強篩選子集稱為「WebLI-curated」。
然後,他們在這四個WebLI子集上,各訓練10個epoch的標準SigLIP編碼器,並將它們用作在全WebLI資料集上進行JEST訓練的參考模型。
在不同的資料篩選方法中,參考模型的表現和JEST的表現似乎是解耦的(甚至可能是反相關的;圖5,左)。
雖然增加篩選(和減少資料集大小)會產生較弱的模型,但當它們被用作JEST預訓練的參考模型時,卻產生了相反的效果:
使用強篩選參考模型的JEST獲得了2.7%的改進,中度篩選獲得了1.5%的改進,弱篩選獲得了0.3%的改進。
擴展資料篩選
假設參考模型效能與JEST效能之間的普遍解耦,可能只是由資料篩選所施加的資料集大小限製成所造成的。
為了理解這種效果,研究人員在WebLI-curated上訓練了5個參考模型,同時改變所見的總樣本數(從2.5億到30億)。
在這種情況下,圖5(右)顯示了改進的參考模型與更好的JEST預訓練之間存在顯著的相關性。
這顯示「解耦」現象主要可以歸因於參考模型因篩選後資料集大小減少而導致的飽和。
此外,研究人員也注意到,當資料集達到飽和時,圖5(右)中的相關性開始崩解,即在10個epoch或看到10億個樣本之後。
這些結果表明,JEST可能會從進一步擴大參考資料集的資料篩選中獲益。
鑑於使用WebLI-curated++對資料進行擴充整理能顯著提升參考模型的效能,作者提出了是否有必要在原始WebLI資料集上進行預訓練的問題。
然而,在評估參考模型在不同資料集上的表現時,卻發現:雖然它在2個下游任務上的表現優於WebLI預訓練,但在其他6個任務上的表現,以及平均表現都明顯低於WebLI預訓練(表5)。
與現有數據比較
最後,論文應用JEST++在公開的LAION-2B資料集上進行預訓練,刪除了其中不安全的影像-文字對,但沒有進行其他的預先過濾。
這個資料規模相比的SOTA方法DBP減少了4×,但JEST++依舊遠遠超過了所有先前的離線資料管理方法。
簡化資料管理
之前提到過,用於預訓練的WebLI-curated是原始資料集WebLI過濾後得到的,以求篩選出高品質的圖像-文字對齊的資料。
如表3所示,這種離線資料管理流程對IID(獨立同分佈)訓練方法的效能至關重要,但JEST++則表現出了對預過濾流程的穩健性。即使沒有過濾,JEST++的效能也沒有明顯下滑,降低了模型對基礎資料集的要求。
結論和局限性
總體來說,JEST方法展現了「資料品質引導」(data quality bootstrapping)方法的巨大潛力,即使用小規模精選資料集來指導對更大的、未經管理的資料集的學習。
最近的研究表明,在下游任務未知時,靜態資料集的篩選會限制模型效能。這篇論文的結果則表明,相較於單獨選擇樣本的方法,在線建立batch能提高預訓練的效率。
無論是使用JEST參考模型對資料集進行預評分,還是透過可學習性評分來根據模型需求進行動態調整,都可以成為通用基礎資料集的更有效率的替代方案。
論文的最後,作者也提出了此方法的限制。雖然JEST同時實現了效能增益和訓練成本降低,但依舊依賴於小型、精心管理的參考資料集,它指定了未經管理的更大資料集中優先考慮的分佈。
因此,未來的工作可以探索一種方法,從指定的下游任務中如何推斷出參考資料集的組成和分佈。