HWC
← Note

No date

MBPO | When to Trust Your Model: Model-Based Policy Optimization

PaperRLWorld Model

快速解釋

這篇論文要回答的核心問題是:在 model-based RL 裡,到底什麼時候應該相信模型?模型可以幫你便宜地生成額外資料,提升樣本效率;但模型也會有誤差,rollout 一長就會累積偏差,最後讓 policy 學到利用模型漏洞的行為。MBPO 的答案很直接:不要用模型做很長的 rollout,而是從真實 replay buffer 裡的狀態出發,只做很多條很短的 model rollout,再把這些合成資料拿去訓練一個強的 model-free learner。

這個設計把兩件事分開了。任務本身的 horizon 可以很長,但模型只需要負責很短的 horizon。作者把這件事概括成把 task horizon 和 model horizon disentangle。實作上,MBPO 用 probabilistic ensemble dynamics model 產生短軌跡,用 SAC 吃這些資料做 policy optimization。結果是:它通常比純 model-free 方法更省樣本,同時又比需要長期模型規劃的 model-based 方法穩定。

問題設定

論文考慮的是標準的強化學習問題。MDP 可以寫成:

(S,A,p,r,γ,ρ0)(\mathcal{S}, \mathcal{A}, p, r, \gamma, \rho_0)

其中,轉移動態是未知的,reward 函數也是未知的,目標是找到一個 policy,使得期望折扣回報最大:

π=argmaxπη[π]\pi^* = \arg\max_{\pi} \eta[\pi]
η[π]=Eπ[t=0γtr(st,at)]\eta[\pi] = \mathbb{E}_{\pi}\left[\sum_{t=0}^{\infty} \gamma^t r(s_t, a_t)\right]

model-based RL 的典型流程,是先學一個近似模型:

pθ(s,rs,a)p_{\theta}(s', r \mid s, a)

再利用這個模型來改善 policy。問題在於:如果 policy optimization 太依賴模型,policy 很可能會跑到模型不準的區域,並利用模型誤差得到虛假的高回報;如果完全不依賴模型,又失去 model-based 方法的資料效率。

所以真正的設計問題不是要不要用模型,而是模型要用多少、用在哪裡、用多長的 rollout。這篇論文的重點就是把這個問題拆開分析,最後得出一個很實用的結論:模型可以用,但要短。

核心想法

MBPO 的方法分成理論與實作兩層。

  1. 先從理論上分析 model bias 與 policy shift 的 tradeoff

作者先寫出一個 model-based policy optimization 的一般形式:在真實環境收集資料,訓練模型,然後在模型下優化 policy。理論上,如果能控制模型誤差與 policy 和資料收集 policy 之間的分佈偏移,就能得到 monotonic improvement 的下界。

但問題是,若直接考慮完整 rollout,理論 bound 會很悲觀,因為模型誤差會隨 rollout horizon 累積。這會讓最保守的結論變成:乾脆不要用模型,也就是最好的 rollout 長度等於 0。

  1. 用 branched rollout 取代從初始狀態開始的長 rollout

為了避免長期 rollout 的誤差爆炸,作者提出 branched rollout。做法是:不從初始狀態開始在模型裡模擬整條軌跡,而是先從真實 replay buffer 取一個狀態,然後只在模型中往前 rollout k 步。這樣一來,模型只負責局部短期預測,長期分佈則仍然由真實環境資料支撐。

這個設計的關鍵直覺是:如果每條 model rollout 都很短,那麼 compounding error 還來不及變嚴重;但如果你從很多真實狀態出發,各自產生很多短 rollout,累積起來仍然可以提供大量訓練資料。

  1. 不用悲觀 worst-case bound,而是看模型實際的 generalization

作者發現,理論上最悲觀的 bound 幾乎永遠不會鼓勵使用模型。但在實驗中,模型對鄰近 policy 分佈的泛化常常比 worst-case 假設好得多。因此作者改用 empirical model generalization 來估計模型誤差如何隨 policy shift 增加。

換句話說,作者不是只問模型在訓練分佈上準不準,而是問當 policy 稍微改變時,模型誤差會長多快。如果這個成長率夠小,那麼非零的 rollout 長度就有合理性。

  1. 實作上用 ensemble dynamics 加 SAC 再配 short rollout

MBPO 的最終版本很務實。它不用模型直接做長期 planning,也不用模型取代 value learning,而是把模型當成一個高效率的資料擴增器:

  1. 用真實環境資料訓練 probabilistic ensemble dynamics model。
  2. 從 replay buffer 隨機抽真實狀態作為 branch point。
  3. 用當前 policy 在模型裡 rollout 很短的 k 步,產生 synthetic transitions。
  4. 把這些 model data 加進 model buffer。
  5. 用 SAC 在這些資料上做大量 gradient updates。

這個流程的重點不是追求模型的長期規劃能力,而是利用模型在局部短期上還算可信的區域,快速補充 policy learning 所需的 transition data。

  1. 最重要的實驗發現是一階甚至單步 rollout 就很有用

論文很強的一點,是它不是只提出短 rollout 比較好這種抽象結論,而是直接做 ablation。結果顯示,單步 rollout 已經很有競爭力;即使模型本身可以做很長的預測,拿長 rollout 去做 policy optimization 反而通常比較差。這剛好和理論分析的方向一致:模型該用在它最可靠的短期區域,而不是拿去替代整個長期規劃。

關鍵公式

整篇論文的起點是標準 RL 目標:

η[π]=Eπ[t=0γtr(st,at)]\eta[\pi] = \mathbb{E}_{\pi}\left[\sum_{t=0}^{\infty} \gamma^t r(s_t, a_t)\right]

若有一個學到的動態模型,則可定義 policy 在模型下的回報:

η^[π]\hat{\eta}[\pi]

作者先寫出一個一般形式的保證:真實環境中的回報可以由模型中的回報扣掉一個誤差項來下界:

η[π]η^[π]C(ϵm,ϵπ)\eta[\pi] \ge \hat{\eta}[\pi] - C(\epsilon_m, \epsilon_{\pi})

其中,epsilonmepsilon_m 代表模型誤差,epsilonpiepsilon_{pi} 代表 policy shift 所造成的分佈偏移。這個式子的意義很重要:如果模型很準、policy shift 很小,那麼在模型下的改善就比較可能轉化成真實環境中的改善。

接著作者把 rollout 方式改成 k-step branched rollout,得到更細的下界。若 ηbranch[π]\eta^{\mathrm{branch}}[\pi] 是從真實資料分佈分支出去、在模型中 rollout k 步得到的回報,則有:

η[π]ηbranch[π]2rmax[γk+1ϵπ(1γ)2+(γk+2)ϵπ1γ+k1γ(ϵm+2ϵπ)]\eta[\pi] \ge \eta^{\mathrm{branch}}[\pi] - 2 r_{\max} \left[ \frac{\gamma^{k+1} \epsilon_{\pi}}{(1-\gamma)^2} + \frac{(\gamma^k + 2)\epsilon_{\pi}}{1-\gamma} + \frac{k}{1-\gamma}(\epsilon_m + 2\epsilon_{\pi}) \right]

這條式子說明 rollout 長度 k 帶來兩種相反效應:

  1. rollout 太短,模型提供的額外遠期資訊太少。
  2. rollout 太長,誤差項裡和 k 相關的部分會變大,compounding error 開始主導。

不過如果直接照這個 pessimistic bound 來選,通常會得到最好的長度等於 0。因此作者引入更新後 policy 下的模型誤差 epsilonmepsilon_m',並用 empirical generalization 去近似它:

ϵ^m(ϵπ)ϵm+ϵπdϵmdϵπ\hat{\epsilon}_m'(\epsilon_{\pi}) \approx \epsilon_m + \epsilon_{\pi} \frac{d\epsilon_m'}{d\epsilon_{\pi}}

將這個更實際的誤差估計代入後,可得到新的下界:

η[π]ηbranch[π]2rmax[γk+1ϵπ(1γ)2+γkϵπ1γ+kϵm1γ]\eta[\pi] \ge \eta^{\mathrm{branch}}[\pi] - 2 r_{\max} \left[ \frac{\gamma^{k+1} \epsilon_{\pi}}{(1-\gamma)^2} + \frac{\gamma^k \epsilon_{\pi}}{1-\gamma} + \frac{k\epsilon_m'}{1-\gamma} \right]

這個版本的重要性在於:它終於允許最優 rollout 長度滿足

k>0k^* > 0

也就是說,只要模型在新 policy 附近的泛化夠好,短而非零的 model rollout 就是合理的。

在實作上,MBPO 的 dynamics model 是 bootstrap ensemble,每個模型輸出下一步狀態與 reward 的高斯分佈:

pθ(i)(st+1,rtst,at)=N(μθ(i)(st,at),Σθ(i)(st,at))p_{\theta}^{(i)}(s_{t+1}, r_t \mid s_t, a_t) = \mathcal{N}\bigl(\mu_{\theta}^{(i)}(s_t, a_t), \Sigma_{\theta}^{(i)}(s_t, a_t)\bigr)

這讓模型同時表達 aleatoric uncertainty,而 ensemble 則額外提供 epistemic uncertainty。

policy learning 採用 SAC。論文中給出的 actor 目標可寫成:

Jπ(ϕ,D)=EstD[DKL(πϕ(st)    exp{Qπ(st,)Vπ(st)})]J_{\pi}(\phi, \mathcal{D}) = \mathbb{E}_{s_t \sim \mathcal{D}} \left[ D_{\mathrm{KL}} \left( \pi_{\phi}(\cdot \mid s_t) \;\middle\|\; \exp\{Q^{\pi}(s_t, \cdot) - V^{\pi}(s_t)\} \right) \right]

這表示 MBPO 的核心不是發明新的 policy optimizer,而是把模型產生的短期合成資料穩定地餵給一個已經很強的 off-policy learner。

模型結構

MBPO 的實際系統可以理解成四個模組的組合。

  1. 真實環境資料庫

先和真實環境互動,把 transition 存進 environment replay buffer。這些資料同時用來訓練 dynamics model,也提供 branched rollout 的起始狀態。

  1. 機率式 ensemble dynamics model

作者使用 bootstrap ensemble 的 probabilistic neural networks。每個 member 都預測下一步狀態與 reward 的均值和對角協方差。做 model rollout 時,每一步可隨機選一個 ensemble member 來產生轉移,降低 policy 對單一模型誤差的利用風險。

  1. 短期 branched rollout 生成器

每次需要 model data 時,從真實 replay buffer 均勻抽樣一個狀態,然後用當前 policy 在模型中 rollout k 步,把產生的 transition 存進 model replay buffer。這裡最重要的設計就是:rollout 起點來自真實資料,rollout 長度刻意保持很短。

  1. SAC policy optimizer

最後用 SAC 在 model replay buffer 上做大量更新。因為 model data 可以大量生成,MBPO 每個真實環境步通常可以做比純 model-free RL 更多的 gradient updates。論文提到,這個數量大約可達每個環境樣本 20 到 40 次 policy gradient steps。

如果把整個訓練流程寫成演算法,大致如下:

  1. 收集真實環境資料,更新 environment buffer。
  2. 用 environment buffer 重新訓練 dynamics ensemble。
  3. 從 environment buffer 抽狀態,做大量 k-step branched rollout。
  4. 把 rollout 得到的 synthetic transitions 存進 model buffer。
  5. 用 model buffer 上的資料更新 SAC policy 與 critic。
  6. 重複上述流程。

這個架構和傳統 model-based planner 的差別非常大。MBPO 不要求模型在整個任務 horizon 上都準;它只要求模型在短期內夠準,足以提供有效的局部訓練訊號。實驗也證明這個選擇很重要:在 Hopper 任務上,作者發現從 1 線性增加到 15 的 rollout schedule 效果最好,但即使固定單步 rollout,也已經是一個非常強、很難被超越的 baseline。這正是 MBPO 最核心的訊息:模型值得信任,但只在短期。