No date
MBPO | When to Trust Your Model: Model-Based Policy Optimization
快速解釋
這篇論文要回答的核心問題是:在 model-based RL 裡,到底什麼時候應該相信模型?模型可以幫你便宜地生成額外資料,提升樣本效率;但模型也會有誤差,rollout 一長就會累積偏差,最後讓 policy 學到利用模型漏洞的行為。MBPO 的答案很直接:不要用模型做很長的 rollout,而是從真實 replay buffer 裡的狀態出發,只做很多條很短的 model rollout,再把這些合成資料拿去訓練一個強的 model-free learner。
這個設計把兩件事分開了。任務本身的 horizon 可以很長,但模型只需要負責很短的 horizon。作者把這件事概括成把 task horizon 和 model horizon disentangle。實作上,MBPO 用 probabilistic ensemble dynamics model 產生短軌跡,用 SAC 吃這些資料做 policy optimization。結果是:它通常比純 model-free 方法更省樣本,同時又比需要長期模型規劃的 model-based 方法穩定。
問題設定
論文考慮的是標準的強化學習問題。MDP 可以寫成:
其中,轉移動態是未知的,reward 函數也是未知的,目標是找到一個 policy,使得期望折扣回報最大:
model-based RL 的典型流程,是先學一個近似模型:
再利用這個模型來改善 policy。問題在於:如果 policy optimization 太依賴模型,policy 很可能會跑到模型不準的區域,並利用模型誤差得到虛假的高回報;如果完全不依賴模型,又失去 model-based 方法的資料效率。
所以真正的設計問題不是要不要用模型,而是模型要用多少、用在哪裡、用多長的 rollout。這篇論文的重點就是把這個問題拆開分析,最後得出一個很實用的結論:模型可以用,但要短。
核心想法
MBPO 的方法分成理論與實作兩層。
- 先從理論上分析 model bias 與 policy shift 的 tradeoff
作者先寫出一個 model-based policy optimization 的一般形式:在真實環境收集資料,訓練模型,然後在模型下優化 policy。理論上,如果能控制模型誤差與 policy 和資料收集 policy 之間的分佈偏移,就能得到 monotonic improvement 的下界。
但問題是,若直接考慮完整 rollout,理論 bound 會很悲觀,因為模型誤差會隨 rollout horizon 累積。這會讓最保守的結論變成:乾脆不要用模型,也就是最好的 rollout 長度等於 0。
- 用 branched rollout 取代從初始狀態開始的長 rollout
為了避免長期 rollout 的誤差爆炸,作者提出 branched rollout。做法是:不從初始狀態開始在模型裡模擬整條軌跡,而是先從真實 replay buffer 取一個狀態,然後只在模型中往前 rollout k 步。這樣一來,模型只負責局部短期預測,長期分佈則仍然由真實環境資料支撐。
這個設計的關鍵直覺是:如果每條 model rollout 都很短,那麼 compounding error 還來不及變嚴重;但如果你從很多真實狀態出發,各自產生很多短 rollout,累積起來仍然可以提供大量訓練資料。
- 不用悲觀 worst-case bound,而是看模型實際的 generalization
作者發現,理論上最悲觀的 bound 幾乎永遠不會鼓勵使用模型。但在實驗中,模型對鄰近 policy 分佈的泛化常常比 worst-case 假設好得多。因此作者改用 empirical model generalization 來估計模型誤差如何隨 policy shift 增加。
換句話說,作者不是只問模型在訓練分佈上準不準,而是問當 policy 稍微改變時,模型誤差會長多快。如果這個成長率夠小,那麼非零的 rollout 長度就有合理性。
- 實作上用 ensemble dynamics 加 SAC 再配 short rollout
MBPO 的最終版本很務實。它不用模型直接做長期 planning,也不用模型取代 value learning,而是把模型當成一個高效率的資料擴增器:
- 用真實環境資料訓練 probabilistic ensemble dynamics model。
- 從 replay buffer 隨機抽真實狀態作為 branch point。
- 用當前 policy 在模型裡 rollout 很短的 k 步,產生 synthetic transitions。
- 把這些 model data 加進 model buffer。
- 用 SAC 在這些資料上做大量 gradient updates。
這個流程的重點不是追求模型的長期規劃能力,而是利用模型在局部短期上還算可信的區域,快速補充 policy learning 所需的 transition data。
- 最重要的實驗發現是一階甚至單步 rollout 就很有用
論文很強的一點,是它不是只提出短 rollout 比較好這種抽象結論,而是直接做 ablation。結果顯示,單步 rollout 已經很有競爭力;即使模型本身可以做很長的預測,拿長 rollout 去做 policy optimization 反而通常比較差。這剛好和理論分析的方向一致:模型該用在它最可靠的短期區域,而不是拿去替代整個長期規劃。
關鍵公式
整篇論文的起點是標準 RL 目標:
若有一個學到的動態模型,則可定義 policy 在模型下的回報:
作者先寫出一個一般形式的保證:真實環境中的回報可以由模型中的回報扣掉一個誤差項來下界:
其中, 代表模型誤差, 代表 policy shift 所造成的分佈偏移。這個式子的意義很重要:如果模型很準、policy shift 很小,那麼在模型下的改善就比較可能轉化成真實環境中的改善。
接著作者把 rollout 方式改成 k-step branched rollout,得到更細的下界。若 是從真實資料分佈分支出去、在模型中 rollout k 步得到的回報,則有:
這條式子說明 rollout 長度 k 帶來兩種相反效應:
- rollout 太短,模型提供的額外遠期資訊太少。
- rollout 太長,誤差項裡和 k 相關的部分會變大,compounding error 開始主導。
不過如果直接照這個 pessimistic bound 來選,通常會得到最好的長度等於 0。因此作者引入更新後 policy 下的模型誤差 ,並用 empirical generalization 去近似它:
將這個更實際的誤差估計代入後,可得到新的下界:
這個版本的重要性在於:它終於允許最優 rollout 長度滿足
也就是說,只要模型在新 policy 附近的泛化夠好,短而非零的 model rollout 就是合理的。
在實作上,MBPO 的 dynamics model 是 bootstrap ensemble,每個模型輸出下一步狀態與 reward 的高斯分佈:
這讓模型同時表達 aleatoric uncertainty,而 ensemble 則額外提供 epistemic uncertainty。
policy learning 採用 SAC。論文中給出的 actor 目標可寫成:
這表示 MBPO 的核心不是發明新的 policy optimizer,而是把模型產生的短期合成資料穩定地餵給一個已經很強的 off-policy learner。
模型結構
MBPO 的實際系統可以理解成四個模組的組合。
- 真實環境資料庫
先和真實環境互動,把 transition 存進 environment replay buffer。這些資料同時用來訓練 dynamics model,也提供 branched rollout 的起始狀態。
- 機率式 ensemble dynamics model
作者使用 bootstrap ensemble 的 probabilistic neural networks。每個 member 都預測下一步狀態與 reward 的均值和對角協方差。做 model rollout 時,每一步可隨機選一個 ensemble member 來產生轉移,降低 policy 對單一模型誤差的利用風險。
- 短期 branched rollout 生成器
每次需要 model data 時,從真實 replay buffer 均勻抽樣一個狀態,然後用當前 policy 在模型中 rollout k 步,把產生的 transition 存進 model replay buffer。這裡最重要的設計就是:rollout 起點來自真實資料,rollout 長度刻意保持很短。
- SAC policy optimizer
最後用 SAC 在 model replay buffer 上做大量更新。因為 model data 可以大量生成,MBPO 每個真實環境步通常可以做比純 model-free RL 更多的 gradient updates。論文提到,這個數量大約可達每個環境樣本 20 到 40 次 policy gradient steps。
如果把整個訓練流程寫成演算法,大致如下:
- 收集真實環境資料,更新 environment buffer。
- 用 environment buffer 重新訓練 dynamics ensemble。
- 從 environment buffer 抽狀態,做大量 k-step branched rollout。
- 把 rollout 得到的 synthetic transitions 存進 model buffer。
- 用 model buffer 上的資料更新 SAC policy 與 critic。
- 重複上述流程。
這個架構和傳統 model-based planner 的差別非常大。MBPO 不要求模型在整個任務 horizon 上都準;它只要求模型在短期內夠準,足以提供有效的局部訓練訊號。實驗也證明這個選擇很重要:在 Hopper 任務上,作者發現從 1 線性增加到 15 的 rollout schedule 效果最好,但即使固定單步 rollout,也已經是一個非常強、很難被超越的 baseline。這正是 MBPO 最核心的訊息:模型值得信任,但只在短期。