HWC
← Note

No date

CQL | Conservative Q-Learning for Offline Reinforcement Learning

PaperRLOffline RL

快速解釋

這篇論文的核心問題是:offline RL 之所以困難,不只是因為沒有互動資料,而是因為標準 off-policy Q-learning 會把資料集外的 action 評得太高。當 policy improvement 去最大化這些被高估的 Q-value 時,policy 就會偏向 dataset support 之外的行為,最後整個 Bellman backup 會被錯誤的 optimistic value 帶走。

CQL 的做法不是像許多 prior offline RL 方法那樣直接對 policy 加 constraint,而是直接修改 Q-function 的學習目標,讓 Q 對資料集外 action 更保守。具體來說,CQL 在標準 Bellman error 之外,加上一個 conservative regularizer:它會傾向提高資料內 action 的相對價值,並壓低某個候選 action 分佈下的 Q-value。結果是,學到的 Q-function 對目前 policy 的期望值形成 lower bound,policy improvement 因而變得更安全。

問題設定

offline RL 的設定是:給定一個由 behavior policy 收集好的靜態資料集

D={(s,a,r,s)}\mathcal{D} = \{(s,a,r,s')\}

在不再與環境互動的前提下,學出一個高回報 policy。MDP 記為:

(S,A,T,r,γ)(\mathcal{S}, \mathcal{A}, T, r, \gamma)

其中,T(smids,a)T(s' mid s,a) 是轉移動態,r(s,a)r(s,a) 是 reward,gammagamma 是 discount factor。行為資料來自 behavior policy pibeta(amids)pi_{beta}(a mid s),資料分佈可寫成 state marginal 與 behavior policy 的乘積。

標準 off-policy RL 會維持一個 Q-function 與可選的 actor。對 actor-critic 而言,典型流程是交替做 policy evaluation 與 policy improvement。以單樣本 empirical Bellman backup 表示,可以寫成:

Q^k+1argminQE(s,a,s)D[(Q(s,a)B^πkQ^k(s,a))2]\hat{Q}_{k+1} \leftarrow \arg\min_Q \mathbb{E}_{(s,a,s') \sim \mathcal{D}} \left[ \left( Q(s,a) - \hat{\mathcal{B}}^{\pi_k} \hat{Q}_k(s,a) \right)^2 \right]

接著再用學到的 Q 去更新 policy。

問題在於:Bellman target 會查詢當前學到的 policy 所選出的 action,但 Q-function 真正被資料監督到的 action,只有 dataset 裡出現過的那些 action。因此,一旦函數近似在 OOD action 上產生虛高的 Q-value,policy improvement 就會把 actor 往這些 action 推,offline setting 又沒有真實環境可以把錯誤修回來,於是誤差會持續自我放大。

這篇論文的核心目標,就是讓 Q-function 變成保守的估計器:不要求每個 state-action pair 都是緊的 lower bound,但至少要讓 policy 真正會用到的那部分 value 被 lower-bound。

核心想法

CQL 的方法分成三個層次。

  1. 先做 conservative off-policy evaluation

作者先考慮 policy evaluation,而不是整個 RL loop。若目標 policy 是 pipi,那就希望學到一個保守的 QpiQ^{pi}。最直接的方法,是在 Bellman error 外,額外最小化某個 action 分佈下的 Q-value。這會把 Q 往下拉,降低過度樂觀估計。

但如果我們把所有 action 的 Q 都往下壓太多,就會得到過度悲觀的 pointwise lower bound。這雖然安全,卻太保守。作者指出,offline RL 真正需要 lower-bound 的不是每個 action 的值,而是 policy 自己的期望值。

  1. 對 policy 的 expected value 做更緊的 lower bound

為了得到更緊的 bound,CQL 不只在某個分佈 μ(as)\mu(a \mid s) 下最小化 Q,還會額外對資料分佈中的 action 做一個反向的最大化項。這樣做的效果是:

  1. 把 OOD 或高風險 action 的 Q 壓低。
  2. 讓 dataset 裡真正出現過的 action 不會被無差別一起壓扁。
  3. 使得 policy 在 learned Q 下的期望值,成為真實 policy value 的 lower bound。

也就是說,CQL 並不是單純做 pessimism,而是做「有方向的 pessimism」:把保守性主要放在資料外 action 上。

  1. 把這個 conservative Q-update 直接放進離線 RL loop

在完整 offline RL 中,作者不每次都做昂貴的完整 off-policy evaluation,而是把 conservative regularizer 直接嵌進一般的 Q-learning 或 actor-critic 更新。這形成了一個 family,稱為 CQL(R)。在這個 family 裡,可以用不同方式選擇額外的 action 分佈。

其中最實用的版本是 CQL(H)。它把 maximization over actions 轉成 log-sum-exp,等價於在每個 state 上對所有 action 的高 Q 值做 soft maximum。這個形式有兩個重要效果:

  1. 如果某些 OOD action 被函數近似錯誤地評得很高,它們會在 log-sum-exp 裡被放大,因而受到更強的懲罰。
  2. CQL 會擴大 in-distribution action 與 OOD action 之間的 Q-gap,也就是論文裡說的 gap-expanding behavior。

這是 CQL 真正和 policy-constraint 方法不同的地方。前者直接把 Q 函數修正成對 OOD action 不樂觀;後者則主要是限制 actor 不要走太遠。作者的觀點是,只限制 actor 還不夠,因為錯的 Q-function 本身仍然可能把 actor 往壞的方向推。

  1. CQL 隱式做了 safe policy improvement

論文更進一步證明,CQL 實際上等價於在 empirical MDP 上優化一個帶 penalty 的 RL 目標。這個 penalty 反映 learned policy 與 empirical behavior policy 之間的偏離程度。因此,CQL 雖然沒有顯式加 actor constraint,但它其實透過 conservative Q-function 隱含地做了保守 policy improvement。

關鍵公式

CQL 最基本的保守 policy evaluation 形式,是在 Bellman error 外加上一個 Q-value 懲罰。論文的 Equation 1 可寫成:

Q^k+1argminQαEsD,aμ(as)[Q(s,a)]+12E(s,a,s)D[Q(s,a)B^πQ^k(s,a)]2\hat{Q}_{k+1} \leftarrow \arg\min_Q \alpha \, \mathbb{E}_{s \sim \mathcal{D},\, a \sim \mu(a \mid s)} [Q(s,a)] + \frac{1}{2} \mathbb{E}_{(s,a,s') \sim \mathcal{D}} \left[ Q(s,a) - \hat{\mathcal{B}}^{\pi} \hat{Q}_k(s,a) \right]^2

這個版本會把某個 action 分佈下的 Q 壓低,因此得到 pointwise lower bound 的傾向。

為了讓 lower bound 更聚焦在 policy 的期望值,而不是所有 action,作者提出更重要的 Equation 2:

Q^k+1argminQα(EsD,aμ(as)[Q(s,a)]EsD,aπ^β(as)[Q(s,a)])+12E(s,a,s)D[Q(s,a)B^πQ^k(s,a)]2\hat{Q}_{k+1} \leftarrow \arg\min_Q \alpha \left( \mathbb{E}_{s \sim \mathcal{D},\, a \sim \mu(a \mid s)} [Q(s,a)] - \mathbb{E}_{s \sim \mathcal{D},\, a \sim \hat{\pi}_{\beta}(a \mid s)} [Q(s,a)] \right) + \frac{1}{2} \mathbb{E}_{(s,a,s') \sim \mathcal{D}} \left[ Q(s,a) - \hat{\mathcal{B}}^{\pi} \hat{Q}_k(s,a) \right]^2

這條式子的直覺非常重要。第一個期望把某個候選 action 分佈的 Q 往下拉,第二個期望則把資料中的 action 往上撐一些。若令 mu=pimu = pi,則作者證明 learned Q 對 policy 的期望值形成 tighter lower bound:

V^π(s)=Eaπ(as)[Q^π(s,a)]Vπ(s)\hat{V}^{\pi}(s) = \mathbb{E}_{a \sim \pi(a \mid s)}[\hat{Q}^{\pi}(s,a)] \le V^{\pi}(s)

也就是說,policy 在 learned Q 下看到的是保守值,因此拿這個 Q 去做 policy improvement 會更安全。

接著,作者把這個想法推廣成一個 family:

minQmaxμα(EsD,aμ(as)[Q(s,a)]EsD,aπ^β(as)[Q(s,a)])+12E(s,a,s)D[Q(s,a)B^πkQ^k(s,a)]2+R(μ)\min_Q \max_{\mu} \alpha \left( \mathbb{E}_{s \sim \mathcal{D},\, a \sim \mu(a \mid s)}[Q(s,a)] - \mathbb{E}_{s \sim \mathcal{D},\, a \sim \hat{\pi}_{\beta}(a \mid s)}[Q(s,a)] \right) + \frac{1}{2} \mathbb{E}_{(s,a,s') \sim \mathcal{D}} \left[ Q(s,a) - \hat{\mathcal{B}}^{\pi_k} \hat{Q}_k(s,a) \right]^2 + R(\mu)

這就是 CQL(R)。若把 regularizer 選成對某個 prior 的 KL,則可以得到一個 closed-form 的 action 分佈。特別地,當 prior 是 uniform action distribution 時,就會得到最常見的 CQL(H):

minQαEsD[logaexp(Q(s,a))Eaπ^β(as)[Q(s,a)]]+12E(s,a,s)D[Q(s,a)B^πkQ^k(s,a)]2\min_Q \alpha \, \mathbb{E}_{s \sim \mathcal{D}} \left[ \log \sum_a \exp(Q(s,a)) - \mathbb{E}_{a \sim \hat{\pi}_{\beta}(a \mid s)}[Q(s,a)] \right] + \frac{1}{2} \mathbb{E}_{(s,a,s') \sim \mathcal{D}} \left[ Q(s,a) - \hat{\mathcal{B}}^{\pi_k} \hat{Q}_k(s,a) \right]^2

這個 log-sum-exp 項可以視為 soft maximum,因此任何被錯誤高估的 action 都會被更強力地懲罰。

理論上,CQL 還對應到一個帶 penalty 的 empirical RL 目標。論文的 Theorem 3.5 表示,若用 Equation 2 的 fixed point 來做 policy optimization,等價於解:

πargmaxπJ(π,M^)α1γEsdM^π(s)[DCQL(π,π^β)(s)]\pi^* \leftarrow \arg\max_{\pi} J(\pi, \hat{M}) - \frac{\alpha}{1-\gamma} \mathbb{E}_{s \sim d^{\pi}_{\hat{M}}(s)} \left[ D_{\mathrm{CQL}}(\pi, \hat{\pi}_{\beta})(s) \right]

其中 penalty 定義為:

DCQL(π,π^β)(s)=aπ(as)(π(as)π^β(as)1)D_{\mathrm{CQL}}(\pi, \hat{\pi}_{\beta})(s) = \sum_a \pi(a \mid s) \left( \frac{\pi(a \mid s)}{\hat{\pi}_{\beta}(a \mid s)} - 1 \right)

這條式子說明:CQL 雖然沒有像 BEAR、BRAC 那樣直接對 actor 加距離約束,但它透過 conservative Q-update,隱含地在優化一個對 behavior policy 偏離有懲罰的目標。

模型結構

CQL 並不是一個全新網路架構,而是一個可以疊加在標準深度 RL 演算法上的 Q-learning framework。它的實作結構可以分成兩種。

  1. actor-critic 版本

在連續控制任務上,作者把 CQL 建在 SAC 上。整體流程是:

  1. 用 Bellman backup 加上 conservative regularizer 更新 Q-function。
  2. 再用 SAC 風格的 actor objective 更新 policy。

actor 更新可寫成:

ϕtϕt1+ηπEsD,aπϕ(s)[Qθ(s,a)logπϕ(as)]\phi_t \leftarrow \phi_{t-1} + \eta_{\pi} \, \mathbb{E}_{s \sim \mathcal{D},\, a \sim \pi_{\phi}(\cdot \mid s)} \left[ Q_{\theta}(s,a) - \log \pi_{\phi}(a \mid s) \right]

重要的是,CQL 不需要額外學一個 behavior policy model 來約束 actor,這點和很多 prior offline RL 方法不同。

  1. Q-learning 版本

在離散 action 任務上,作者把 CQL 建在 QR-DQN 上。此時沒有顯式 actor,而是直接把 conservative Q-objective 接在 value-based backup 上,用 Bellman optimality operator 取代 actor-critic 裡的 policy Bellman operator。

  1. CQL(H) 與 CQL(ρ)

論文裡實際上有兩個常用版本:

  1. CQL(H):使用 log-sum-exp 形式,效果通常較好,也最常被視為標準版 CQL。
  2. CQL(ρ):把 soft maximum 的 prior 換成前一個 policy,在高維 action 空間中有時更穩,因為 importance sampling 估計 log-sum-exp 在高維連續動作空間會有較大變異。
  3. 自動調整保守強度

對連續控制任務,作者還提出 Lagrange 版本,自動調整保守係數 alphaalpha。其形式為:

minQmaxα0α(Esdπβ(s)[logaexp(Q(s,a))Eaπβ(as)[Q(s,a)]]τ)+12E(s,a,s)D[Q(s,a)B^πkQ^k(s,a)]2\min_Q \max_{\alpha \ge 0} \alpha \left( \mathbb{E}_{s \sim d^{\pi_{\beta}}(s)} \left[ \log \sum_a \exp(Q(s,a)) - \mathbb{E}_{a \sim \pi_{\beta}(a \mid s)}[Q(s,a)] \right] - \tau \right) + \frac{1}{2} \mathbb{E}_{(s,a,s') \sim \mathcal{D}} \left[ Q(s,a) - \hat{\mathcal{B}}^{\pi_k} \hat{Q}_k(s,a) \right]^2

這個版本的意義是:如果 conservative gap 還不夠大,alphaalpha 就會被拉高;如果 gap 已經達到預設門檻 tautau,則 α\alpha 會下降。實驗裡這個 CQL-Lagrange 在 D4RL 尤其是 AntMaze 上更穩定。

  1. 整體訓練流程

如果把整體流程寫成最簡單的演算法,它就是:

  1. 初始化 Q-function,以及可選的 policy。
  2. 重複從靜態資料集抽 batch。
  3. 用 CQL objective 更新 Q-function。
  4. 若是 actor-critic 版本,再用 SAC 方式更新 actor。
  5. 反覆迭代直到收斂。

所以,CQL 的本質不是新的 policy parameterization,也不是新的 model-based 結構,而是把 offline RL 的核心問題重新放回 Q-function 本身:與其想辦法限制 actor 不要踩出資料分佈,不如先把 Q 練成對資料外 action 不樂觀的估計器。這也是它在複雜、多模態、混合來源資料集上通常比傳統 policy-constraint 方法更穩的原因。