HWC
← Note

No date

MADIFF | Offline Multi-agent Learning with Diffusion Models

PaperRLDiffusion

快速解釋

MADIFF 把多智能體離線學習改成生成未來 joint trajectory 的問題:先用 diffusion model 規劃多個 agents 接下來會怎麼一起動,再用 inverse dynamics 把相鄰狀態轉回 actions。它的關鍵不是單純把 diffusion 套進 RL,而是把跨 agent attention 放進每一步 denoising,讓協調直接發生在生成過程中。

問題設定

  • Offline MARL 裡,TD learning 容易因離線資料造成 extrapolation error。
  • 若每個 agent 各自學獨立模型,容易失去 coordination;若把所有 agents 生硬串接成一個超大向量,又會浪費表示能力、破壞 agent permutation symmetry。
  • 因此需要一個既能在離線 joint trajectories 上學到協調,又能支援 centralized training / decentralized execution 的生成模型。

核心想法

  • MADIFF 不直接生成 actions,而是生成未來一段多 agent 的狀態軌跡;因為狀態序列通常比 action 序列更平滑、更好學。
  • 為了在 denoising 過程中交換 agent 之間的資訊,模型在每個 agent 的 U-Net decoder 前插入跨 agent attention,讓每個 agent 的 latent 表徵都能看見其他 agents 的關鍵訊息。
  • 訓練時使用帶 return 條件的 diffusion model;推論時用 classifier-free guidance 把採樣往高回報的 joint behaviors 拉。
  • 執行時可採 centralized control,也可在 decentralized setting 只給單一 agent 的局部觀測,透過模型同時預測 teammates 的未來行為,等於把 teammate modeling 內建在 diffusion policy 裡。

關鍵公式

diffusion 的基本去噪目標為

LDM(θ)=E[ϵϵθ(xk,k)22].L_{\mathrm{DM}}(\theta)= \mathbb{E}\left[ \left\lVert \epsilon-\epsilon_\theta(x_k,k) \right\rVert_2^2 \right].

MADIFF 生成的是未來狀態軌跡,再由 inverse dynamics 轉成 action:

τ^=[st,s^t+1,,s^t+H1],a^t=Iϕ(st,s^t+1).\hat{\tau}= [s_t,\hat{s}_{t+1},\dots,\hat{s}_{t+H-1}], \qquad \hat{a}_t = I_\phi(s_t,\hat{s}_{t+1}).

inverse dynamics 與條件 diffusion 的聯合訓練可寫成

L(θ,ϕ)=iE[aiIϕi(oi,oi)22]+E[ϵϵθ(τ^k,(1β)y(τ)+β,k)22].L(\theta,\phi)= \sum_i \mathbb{E} \left[ \left\lVert a^i-I_\phi^i(o^i,o^{\prime i}) \right\rVert_2^2 \right] + \mathbb{E} \left[ \left\lVert \epsilon-\epsilon_\theta\bigl(\hat{\tau}_k,(1-\beta)y(\tau)+\beta\varnothing,k\bigr) \right\rVert_2^2 \right].

classifier-free guidance 在推論時使用

ϵ^=ϵθ(τ^k,,k)+ω(ϵθ(τ^k,y(τ),k)ϵθ(τ^k,,k)).\hat{\epsilon} = \epsilon_\theta(\hat{\tau}_k,\varnothing,k) + \omega \left( \epsilon_\theta(\hat{\tau}_k,y(\tau),k) -\epsilon_\theta(\hat{\tau}_k,\varnothing,k) \right).

跨 agent attention 讓第 i 個 agent 能聚合其他 agents 的 latent:

qi=fquery(ci),ki=fkey(ci),vi=fvalue(ci),q^i=f_{\mathrm{query}}(c^i),\qquad k^i=f_{\mathrm{key}}(c^i),\qquad v^i=f_{\mathrm{value}}(c^i),
αij=exp(qikj/dk)pexp(qikp/dk),c^i=jαijvj.\alpha_{ij} = \frac{\exp\left(q^ik^j/\sqrt{d_k}\right)} {\sum_p \exp\left(q^ik^p/\sqrt{d_k}\right)}, \qquad \hat{c}^i=\sum_j \alpha_{ij}v^j.
  • 第二式表示策略不是直接輸出 action,而是先規劃下一段可行狀態,再反推 action。
  • 第三、四式把高 return 的條件引導和 agent 間資訊交換同時放進生成過程,因此 coordination 不是事後修補,而是逐步生成出來的。

模型結構

  1. 每個 agent 的 backbone:以 U-Net 為主體,沿時間維度做 1D convolution residual blocks 來生成個別 agent 的未來軌跡。
  2. 跨 agent attention:在每個 decoder block 前,把所有 agents 的 latent features 做 multi-head attention,交換協調訊息,同時保留 index-free 的表示方式。
  3. 條件輸入:可加入當前觀測、回報條件與空條件,支援 classifier-free guidance。
  4. Inverse dynamics:針對每個 agent 額外訓練 inverse dynamics model,把 o_t 與預測的 o_{t+1} 轉成 action。
  5. 兩種部署模式:centralized 版本一次輸出所有 agents 的 joint plan;decentralized 版本只靠單一 agent 的局部觀測規劃自身行為,並隱式預測 teammate trajectories。