您的位置:首頁 > 軟件教程 > 教程 > 強化學(xué)習(xí)筆記之【ACE:Off-PolicyActor-CriticwithCausality-AwareEntropyRegularization】

強化學(xué)習(xí)筆記之【ACE:Off-PolicyActor-CriticwithCausality-AwareEntropyRegularization】

來源:好特整理 | 時間:2024-10-18 09:46:01 | 閱讀:167 |  標(biāo)簽: a T CTO AWA Ri rop Pyre S C ICY Causality AR   | 分享到:

2024年ICML文章,ACE:Off-PolicyActor-CriticwithCausality-AwareEntropyRegularization精讀

強化學(xué)習(xí)筆記之【ACE:Off-PolicyActor-CriticwithCausality-AwareEntropyRegularization】


前言:


該論文是清華項目組組內(nèi)博士師兄寫的文章,項目主頁為 ACE (ace-rl.github.io) ,于2024年7月發(fā)表在ICML期刊

因為最近組內(nèi)(其實只有我)需要從零開始做一個相關(guān)項目,前面的幾篇文章都是鋪墊

本文章為強化學(xué)習(xí)筆記第5篇

本文初編輯于2024.10.5,好像是這個時間,忘記了,前后寫了兩個多星期

CSDN主頁: https://blog.csdn.net/rvdgdsva

博客園主頁: https://www.cnblogs.com/hassle

博客園本文鏈接:


論文一覽

這篇強化學(xué)習(xí)論文主要介紹了一個名為 ACE 的算法,完整名稱為 Off-Policy Actor-Critic with Causality-Aware Entropy Regularization ,它通過引入因果關(guān)系分析和因果熵正則化來解決現(xiàn)有模型在不同動作維度上的不平等探索問題,旨在改進(jìn)強化學(xué)習(xí)【注釋1】中探索效率和樣本效率的問題,特別是在高維度連續(xù)控制任務(wù)中的表現(xiàn)。

【注釋1】: 強化學(xué)習(xí)入門這一篇就夠了


論文摘要

在policy【注釋2】學(xué)習(xí)過程中,不同原始行為的不同意義被先前的model-free RL 算法所忽視。利用這一見解,我們探索了不同行動維度和獎勵之間的因果關(guān)系,以評估訓(xùn)練過程中各種原始行為的重要性。我們引入了一個因果關(guān)系感知熵【注釋3】項(causality-aware entropy term),它可以有效地識別并優(yōu)先考慮具有高潛在影響的行為,以實現(xiàn)高效的探索。此外,為了防止過度關(guān)注特定的原始行為,我們分析了梯度休眠現(xiàn)象(gradientdormancyphenomenon),并引入了休眠引導(dǎo)的重置機(jī)制,以進(jìn)一步增強我們方法的有效性。與無模型RL基線相比,我們提出的算法 ACE :Off-policy A ctor-criticwith C ausality-aware E ntropyregularization。在跨越7個域的29種不同連續(xù)控制任務(wù)中顯示出實質(zhì)性的性能優(yōu)勢,這強調(diào)了我們方法的有效性、多功能性和高效的樣本效率。 基準(zhǔn)測試結(jié)果和視頻可在https://ace-rl.github.io/上獲得。

【注釋2】: 強化學(xué)習(xí)算法中on-policy和off-policy

【注釋3】: 最大熵 RL:從Soft Q-Learning到SAC - 知乎


論文主要貢獻(xiàn):

【1】 因果關(guān)系分析 :通過引入因果政策-獎勵結(jié)構(gòu)模型,評估不同動作維度(即原始行為)對獎勵的影響大。ǚQ為“因果權(quán)重”)。這些權(quán)重反映了每個動作維度在不同學(xué)習(xí)階段的相對重要性。

作出上述改進(jìn)的原因是:考慮一個簡單的例子,一個機(jī)械手最初應(yīng)該學(xué)習(xí)放下手臂并抓住物體,然后將注意力轉(zhuǎn)移到學(xué)習(xí)手臂朝著最終目標(biāo)的運動方向上。因此,在策略學(xué)習(xí)的不同階段強調(diào)對最重要的原始行為的探索是 至關(guān)重要的。在探索過程中刻意關(guān)注各種原始行為,可以加速智能體在每個階段對基本原始行為的學(xué)習(xí),從而提高掌握完整運動任務(wù)的效率。

此處可供學(xué)習(xí)的資料:

【2】 因果熵正則化 :在最大熵強化學(xué)習(xí)框架的基礎(chǔ)上(如SAC算法),加入了 因果加權(quán)的熵正則化項 。與傳統(tǒng)熵正則化不同,這一項根據(jù)各個原始行為的因果權(quán)重動態(tài)調(diào)整,強化對重要行為的探索,減少對不重要行為的探索。

作出上述改進(jìn)的原因是:論文引入了一個因果策略-獎勵結(jié)構(gòu)模型來計算行動空間上的因果權(quán)重(causal weights),因果權(quán)重會引導(dǎo)agent進(jìn)行更有效的探索, 鼓勵對因果權(quán)重較大的動作維度進(jìn)行探索,表明對獎勵的重要性更大,并減少對因果權(quán)重較小的行為維度的探 索。一般的最大熵目標(biāo)缺乏對不同學(xué)習(xí)階段原始行為之間區(qū)別的重要性的認(rèn)識,可能導(dǎo)致低效的探索。為了解決這一限制,論文引入了一個由因果權(quán)重加權(quán)的策略熵作為因果關(guān)系感知的熵最大化目標(biāo),有效地加強了對重要原始行為的探索,并導(dǎo)致了更有效的探索。

此處可供學(xué)習(xí)的資料:

【3】 梯度“休眠”現(xiàn)象(Gradient Dormancy) :論文觀察到,模型訓(xùn)練時有些梯度會在某些階段不活躍(即“休眠”)。為了防止模型過度關(guān)注某些原始行為,論文引入了 梯度休眠導(dǎo)向的重置機(jī)制 。該機(jī)制通過周期性地對模型進(jìn)行擾動(reset),避免模型陷入局部最優(yōu),促進(jìn)更廣泛的探索。

作出上述改進(jìn)的原因是:該機(jī)制通過一個由梯度休眠程度決定的因素間歇性地干擾智能體的神經(jīng)網(wǎng)絡(luò)。將因果關(guān)系感知探索與這種新穎的重置機(jī)制相結(jié)合,旨在促進(jìn)更高效、更有效的探索,最終提高智能體的整體性能。

通過在多個連續(xù)控制任務(wù)中的實驗,ACE 展示出了顯著優(yōu)于主流強化學(xué)習(xí)算法(如SAC、TD3)的表現(xiàn):

  • 29個不同的連續(xù)控制任務(wù) :包括 Meta-World(12個任務(wù))、DMControl(5個任務(wù))、Dexterous Hand(3個任務(wù))和其他稀疏獎勵任務(wù)(6個任務(wù))。
  • 實驗結(jié)果 表明,ACE 在所有任務(wù)中都達(dá)到了更好的樣本效率和更高的最終性能。例如,在復(fù)雜的稀疏獎勵場景中,ACE 憑借其因果權(quán)重引導(dǎo)的探索策略,顯著超越了 SAC 和 TD3 等現(xiàn)有算法。

論文中的對比實驗圖表顯示了 ACE 在多種任務(wù)下的顯著優(yōu)勢,尤其是在 稀疏獎勵和高維度任務(wù) 中,ACE 憑借其探索效率的提升,能更快達(dá)到最優(yōu)策略。


論文代碼框架

在ACE原論文的第21頁,這玩意兒應(yīng)該寫在正篇的,害的我看了好久的代碼去排流程

不過說實話這偽代碼有夠簡潔的,代碼多少有點糊成一坨了

這是一個強化學(xué)習(xí)(RL)算法的框架,具體是一個結(jié)合因果推斷(Causal Discovery)的離策略(Off-policy)Actor-Critic方法。下面是對每個模塊及其參數(shù)的說明:

1. 初始化模塊

  • Q網(wǎng)絡(luò) ( \(Q_\phi\) ) :用于估計動作價值,(\phi) 是權(quán)重參數(shù)。
  • 策略網(wǎng)絡(luò) ( $\pi_\theta $) :用于生成動作策略,(\theta) 是其權(quán)重。
  • 重放緩沖區(qū) ($ D$ ) :存儲環(huán)境交互的數(shù)據(jù),以便進(jìn)行采樣。
  • 局部緩沖區(qū) ( $D_c $) :存儲因果發(fā)現(xiàn)所需的局部數(shù)據(jù)。
  • 因果權(quán)重矩陣 ($ B_{a \rightarrow r|s} $) :用于捕捉動作與獎勵之間的因果關(guān)系。
  • 擾動因子 ( \(f\) ) :用于對策略進(jìn)行微小擾動,增加探索。

2. 因果發(fā)現(xiàn)模塊

  • 每 ( $$I$$ ) 步更新
    • 樣本采樣 :從局部緩沖區(qū) ( \(D_c\) ) 中抽樣 ( \(N_c\) ) 條轉(zhuǎn)移。
    • 更新因果權(quán)重矩陣 :調(diào)整 ($ B_{a \rightarrow r|s}$ ),用于反映當(dāng)前策略和獎勵之間的因果關(guān)系。

3. 策略優(yōu)化模塊

  • 每個梯度步驟
    • 樣本采樣 :從重放緩沖區(qū) ( \(D\) ) 中抽樣 ($ N$ ) 條轉(zhuǎn)移。
    • 計算因果意識熵 ( \(H_c(\pi(\cdot|s))\) ) :衡量在給定狀態(tài)下策略的隨機(jī)性和確定性,用于修改策略。
    • 目標(biāo) Q 值計算 :更新目標(biāo) Q 值,用于訓(xùn)練 Q 網(wǎng)絡(luò)。
    • 更新 Q 網(wǎng)絡(luò) :減少預(yù)測的 Q 值與目標(biāo) Q 值之間的誤差。
    • 更新策略網(wǎng)絡(luò) :最大化當(dāng)前狀態(tài)下的 Q 值,以提高收益。

4. 重置機(jī)制模塊

  • 每個重置間隔
    • 計算梯度主導(dǎo)度 ( $\beta_\gamma $) :用來量化策略更新的影響程度。
    • 初始化隨機(jī)網(wǎng)絡(luò) :為新的策略更新準(zhǔn)備初始權(quán)重 ( $\phi_i $)。
    • 軟重置策略和 Q 網(wǎng)絡(luò) :根據(jù)因果權(quán)重進(jìn)行平滑更新,幫助實現(xiàn)更穩(wěn)定的優(yōu)化。
    • 重置策略和 Q 優(yōu)化器 :在重置時清空狀態(tài),以便進(jìn)行新的學(xué)習(xí)過程。

論文源代碼主干

源代碼上千行呢,這里只是貼上main_casual里面的部分代碼,并且刪掉了很大一部分代碼以便理清程序脈絡(luò)

def train_loop(config, msg = "default"):
    # Agent
    agent = ACE_agent(env.observation_space.shape[0], env.action_space, config)

    memory = ReplayMemory(config.replay_size, config.seed)
    local_buffer = ReplayMemory(config.causal_sample_size, config.seed)

    for i_episode in itertools.count(1):
        done = False

        state = env.reset()
        while not done:
            if config.start_steps > total_numsteps:
                action = env.action_space.sample()  # Sample random action
            else:
                action = agent.select_action(state)  # Sample action from policy

            if len(memory) > config.batch_size:
                for i in range(config.updates_per_step):
                    #* Update parameters of causal weight
                    if (total_numsteps % config.causal_sample_interval == 0) and (len(local_buffer)>=config.causal_sample_size):
                        causal_weight, causal_computing_time = get_sa2r_weight(env, local_buffer, agent, sample_size=config.causal_sample_size, causal_method='DirectLiNGAM')
                        print("Current Causal Weight is: ",causal_weight)
                        
                    dormant_metrics = {}
                    # Update parameters of all the networks
                    critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha, q_sac, dormant_metrics = agent.update_parameters(memory, causal_weight,config.batch_size, updates)

                    updates += 1
            next_state, reward, done, info = env.step(action) # Step
            total_numsteps += 1
            episode_steps += 1
            episode_reward += reward

            #* Ignore the "done" signal if it comes from hitting the time horizon.
            if '_max_episode_steps' in dir(env):  
                mask = 1 if episode_steps == env._max_episode_steps else float(not done)
            elif 'max_path_length' in dir(env):
                mask = 1 if episode_steps == env.max_path_length else float(not done)
            else: 
                mask = 1 if episode_steps == 1000 else float(not done)

            memory.push(state, action, reward, next_state, mask) # Append transition to memory
            local_buffer.push(state, action, reward, next_state, mask) # Append transition to local_buffer
            state = next_state

        if total_numsteps > config.num_steps:
            break

        # test agent
        if i_episode % config.eval_interval == 0 and config.eval is True:
            eval_reward_list = []
            for _  in range(config.eval_episodes):
                state = env.reset()
                episode_reward = []
                done = False
                while not done:
                    action = agent.select_action(state, evaluate=True)
                    next_state, reward, done, info = env.step(action)
                    state = next_state
                    episode_reward.append(reward)
                eval_reward_list.append(sum(episode_reward))

            avg_reward = np.average(eval_reward_list)
          
    env.close() 

代碼流程解釋

  1. 初始化 :

    • 通過配置文件 config 設(shè)置環(huán)境和隨機(jī)種子。
    • 使用 ACE_agent 初始化強化學(xué)習(xí)智能體,該智能體會在后續(xù)過程中學(xué)習(xí)如何在環(huán)境中行動。
    • 創(chuàng)建存儲結(jié)果和檢查點的目錄,確保訓(xùn)練過程中的配置和因果權(quán)重會被記錄下來。
    • 初始化了兩個重放緩沖區(qū): memory 用于存儲所有的歷史數(shù)據(jù), local_buffer 則用于因果權(quán)重的更新。
  2. 主訓(xùn)練循環(huán) :

    • 采樣動作 :如果總步數(shù)較小,則從環(huán)境中隨機(jī)采樣動作,否則從策略中選擇動作。通過這種方式,確保早期探索和后期利用。

    • 更新因果權(quán)重 :在特定間隔內(nèi),從局部緩沖區(qū)中采樣數(shù)據(jù),通過 get_sa2r_weight 函數(shù)使用DirectLiNGAM算法計算從動作到獎勵的因果權(quán)重。這個權(quán)重會作為額外信息,幫助智能體優(yōu)化策略。

    • 更新網(wǎng)絡(luò)參數(shù) :當(dāng) memory 中的數(shù)據(jù)足夠多時,開始通過采樣更新Q網(wǎng)絡(luò)和策略網(wǎng)絡(luò),使用計算出的因果權(quán)重來修正損失函數(shù)。

    • 記錄與保存模型 :每隔一定的步數(shù),算法會測試當(dāng)前策略的性能,記錄并比較獎勵是否超過歷史最佳值,如果是,則保存模型的檢查點。

    • 使用 wandb 記錄訓(xùn)練過程中的指標(biāo),例如損失函數(shù)、獎勵和因果權(quán)重的計算時間,這些信息可以幫助調(diào)試和分析訓(xùn)練過程。


論文模塊代碼及實現(xiàn)

因果發(fā)現(xiàn)模塊

因果發(fā)現(xiàn)模塊 主要通過 get_sa2r_weight 函數(shù)實現(xiàn),并且與 DirectLiNGAM 模型結(jié)合,負(fù)責(zé)計算因果權(quán)重。具體代碼在訓(xùn)練循環(huán)中如下:

causal_weight, causal_computing_time = get_sa2r_weight(env, local_buffer, agent, sample_size=config.causal_sample_size, causal_method='DirectLiNGAM')

在這個代碼段, get_sa2r_weight 函數(shù)會基于當(dāng)前環(huán)境、樣本數(shù)據(jù)( local_buffer )和因果模型(這里使用的是 DirectLiNGAM ),計算與行動相關(guān)的因果權(quán)重( causal_weight )。這些權(quán)重會影響后續(xù)的策略優(yōu)化和參數(shù)更新。關(guān)鍵邏輯包括:

  1. 采樣間隔 :因果發(fā)現(xiàn)是在 total_numsteps % config.causal_sample_interval == 0 時觸發(fā),確保只在指定的步數(shù)間隔內(nèi)計算因果權(quán)重,避免每一步都進(jìn)行因果計算,減輕計算負(fù)擔(dān)。
  2. 局部緩沖區(qū) local_buffer 中存儲了足夠的樣本( config.causal_sample_size ),這些樣本用于因果關(guān)系的發(fā)現(xiàn)。
  3. 因果方法 DirectLiNGAM 是選擇的因果模型,用于從狀態(tài)、行動和獎勵之間推導(dǎo)出因果關(guān)系。

因果權(quán)重計算完成后,程序會將這些權(quán)重應(yīng)用到策略優(yōu)化中,并且記錄權(quán)重及計算時間等信息。

def get_sa2r_weight(env, memory, agent, sample_size=5000, causal_method='DirectLiNGAM'):
    ······
    return weight, model._running_time

這個代碼的核心是利用DirectLiNGAM模型計算給定狀態(tài)、動作和獎勵之間的因果權(quán)重。接下來,用LaTeX公式詳細(xì)表述計算因果權(quán)重的過程:

  1. 數(shù)據(jù)預(yù)處理
    將從 memory 中采樣的 states (狀態(tài))、 actions (動作)和 rewards (獎勵)進(jìn)行拼接,構(gòu)建輸入數(shù)據(jù)矩陣 \(X_{\text{ori}}\)

    其中, \(S\) 代表狀態(tài), \(A\) 代表動作, \(R\) 代表獎勵。接著,構(gòu)建數(shù)據(jù)框 \(X\) 來進(jìn)行因果分析。

  2. 因果模型擬合

    X_ori 轉(zhuǎn)換為 X 是為了利用 pandas 數(shù)據(jù)框的便利性和靈活性

    使用 DirectLiNGAM 模型對矩陣 \(X\) 進(jìn)行擬合,得到因果關(guān)系的鄰接矩陣 \(A_{\text{model}}\)

    該鄰接矩陣表示狀態(tài)、動作、獎勵之間的因果結(jié)構(gòu),特別是從動作到獎勵的影響關(guān)系。

  3. 提取動作對獎勵的因果權(quán)重
    通過鄰接矩陣提取動作對獎勵的因果權(quán)重 \(w_{\text{r}}\) ,該權(quán)重從鄰接矩陣的最后一行中選擇與動作對應(yīng)的元素:

    其中, \(d_s\) 是狀態(tài)的維度, \(d_a\) 是動作的維度。

  4. 因果權(quán)重的歸一化
    對因果權(quán)重 \(w_{\text{r}}\) 進(jìn)行Softmax歸一化,確保它們的總和為1:

  5. 調(diào)整權(quán)重的尺度
    最后,因果權(quán)重根據(jù)動作的數(shù)量進(jìn)行縮放:

最終輸出的權(quán)重 \(w\) 表示每個動作對獎勵的因果影響,經(jīng)過歸一化和縮放處理,可以用于進(jìn)一步的策略調(diào)整或分析。

策略優(yōu)化模塊

以下是對函數(shù)工作原理的逐步解釋:

策略優(yōu)化模塊 主要由 agent.update_parameters 函數(shù)實現(xiàn)。 agent.update_parameters 這個函數(shù)的主要目的是在強化學(xué)習(xí)中更新策略 ( policy ) 和價值網(wǎng)絡(luò)(critic)的參數(shù),以提升智能體的性能。這個函數(shù)實現(xiàn)了一個基于軟演員評論家(SAC, Soft Actor-Critic)的更新機(jī)制,并且加入了因果權(quán)重與"休眠"神經(jīng)元(dormant neurons)的處理,以提高模型的魯棒性和穩(wěn)定性。

critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha, q_sac, dormant_metrics = agent.update_parameters(memory, causal_weight, config.batch_size, updates)

通過 agent.update_parameters 函數(shù),程序會更新以下幾個部分:

  1. Critic網(wǎng)絡(luò)(價值網(wǎng)絡(luò)) critic_1_loss critic_2_loss 分別是兩個 Critic 網(wǎng)絡(luò)的損失,用于評估當(dāng)前策略的價值。
  2. Policy網(wǎng)絡(luò)(策略網(wǎng)絡(luò)) policy_loss 表示策略網(wǎng)絡(luò)的損失,用于優(yōu)化 agent 的行動選擇。
  3. Entropy損失 ent_loss 用來調(diào)節(jié)策略的隨機(jī)性,幫助 agent 在探索和利用之間找到平衡。
  4. Alpha :表示自適應(yīng)的熵系數(shù),用于調(diào)整探索與利用之間的權(quán)衡。

這些參數(shù)的更新在每次訓(xùn)練循環(huán)中被調(diào)用,并使用 wandb.log 記錄損失和其他相關(guān)的訓(xùn)練數(shù)據(jù)。

update_parameters ACE_agent 類中的一個關(guān)鍵函數(shù),用于根據(jù)經(jīng)驗回放緩沖區(qū)中的樣本數(shù)據(jù)來更新模型的參數(shù)。下面是對其工作原理的詳細(xì)解釋:

1. 采樣經(jīng)驗數(shù)據(jù)

首先,函數(shù)從 memory 中采樣一批樣本( state_batch action_batch 、 reward_batch next_state_batch 、 mask_batch ),其中包括狀態(tài)、動作、獎勵、下一個狀態(tài)以及掩碼,用于表示是否為終止?fàn)顟B(tài)。

state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)
  • state_batch :當(dāng)前的狀態(tài)。
  • action_batch :在當(dāng)前狀態(tài)下執(zhí)行的動作。
  • reward_batch :執(zhí)行該動作后獲得的獎勵。
  • next_state_batch :執(zhí)行動作后到達(dá)的下一個狀態(tài)。
  • mask_batch :掩碼,用于表示是否為終止?fàn)顟B(tài)(1 表示非終止,0 表示終止)。

2. 計算目標(biāo) Q 值

利用當(dāng)前策略(policy)網(wǎng)絡(luò),采樣下一個狀態(tài)的動作 next_state_action 和其對應(yīng)的概率分布對數(shù) next_state_log_pi 。然后利用目標(biāo) Q 網(wǎng)絡(luò) critic_target 估計下一時刻的最小 Q 值,并結(jié)合獎勵和折扣因子 \(\gamma\) 計算下一個 Q 值:

with torch.no_grad():
    next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch, causal_weight)
    qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
    min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
    next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)
  • 通過策略網(wǎng)絡(luò) self.policy 為下一個狀態(tài) next_state_batch 采樣動作 next_state_action 和相應(yīng)的策略熵 next_state_log_pi 。

  • 使用目標(biāo) Q 網(wǎng)絡(luò)計算 qf1_next_target qf2_next_target ,并取兩者的最小值來減少估計偏差。

  • 最終使用貝爾曼方程計算 next_q_value ,即當(dāng)前的獎勵加上折扣因子 \(\gamma\) 乘以下一個狀態(tài)的 Q 值。

  • 這里, \(\alpha\) 是熵項的權(quán)重,用于平衡探索和利用的權(quán)衡,而 mask_batch 是為了處理終止?fàn)顟B(tài)的情況。

    使用無偏估計來計算目標(biāo) Q 值。通過目標(biāo)網(wǎng)絡(luò) ( critic_target ) 計算出下一個狀態(tài)和動作的 Q 值,并使用獎勵和掩碼更新當(dāng)前 Q 值

3. 更新 Q 網(wǎng)絡(luò)

接著,使用當(dāng)前 Q 網(wǎng)絡(luò) critic 估計當(dāng)前狀態(tài)和動作下的 Q 值 \(Q_1\) \(Q_2\) ,并計算它們與目標(biāo) Q 值的均方誤差損失:

最終 Q 網(wǎng)絡(luò)的總損失是兩個 Q 網(wǎng)絡(luò)損失之和:

然后,通過反向傳播 qf_loss 來更新 Q 網(wǎng)絡(luò)的參數(shù)。

qf1, qf2 = self.critic(state_batch, action_batch)
qf1_loss = F.mse_loss(qf1, next_q_value)
qf2_loss = F.mse_loss(qf2, next_q_value)
qf_loss = qf1_loss + qf2_loss

self.critic_optim.zero_grad()
qf_loss.backward()
self.critic_optim.step()
  • qf1 qf2 是兩個 Q 網(wǎng)絡(luò)的輸出,用于減少正向估計偏差。
  • 損失函數(shù)是 Q 值的均方誤差(MSE), qf1_loss qf2_loss 分別計算兩個 Q 網(wǎng)絡(luò)的誤差,最后將兩者相加為總的 Q 損失 qf_loss 。
  • 通過 self.critic_optim 優(yōu)化器對損失進(jìn)行反向傳播和參數(shù)更新。

4. 策略網(wǎng)絡(luò)更新

每隔若干步(通過 target_update_interval 控制),開始更新策略網(wǎng)絡(luò) policy 。首先,重新采樣當(dāng)前狀態(tài)下的策略 \(\pi(a|s)\) ,并計算 Q 值和熵權(quán)重下的策略損失:

這個損失通過反向傳播更新策略網(wǎng)絡(luò)。

if updates % self.target_update_interval == 0:
    pi, log_pi, _ = self.policy.sample(state_batch, causal_weight)
    qf1_pi, qf2_pi = self.critic(state_batch, pi)
    min_qf_pi = torch.min(qf1_pi, qf2_pi)
    policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

    self.policy_optim.zero_grad()
    policy_loss.backward()
    self.policy_optim.step()
  • 通過策略網(wǎng)絡(luò)對當(dāng)前狀態(tài) state_batch 進(jìn)行采樣,得到動作 pi 及其對應(yīng)的策略熵 log_pi 。
  • 計算策略損失 policy_loss ,即 \(\alpha\) 倍的策略熵減去最小的 Q 值。
  • 通過 self.policy_optim 優(yōu)化器對策略損失進(jìn)行反向傳播和參數(shù)更新。

5. 自適應(yīng)熵調(diào)節(jié)

如果開啟了自動熵項調(diào)整( automatic_entropy_tuning ),則會進(jìn)一步更新熵項 \(\alpha\) 的損失:

并通過梯度下降更新 \(\alpha\) 。

如果 automatic_entropy_tuning 為真,則會更新熵項:

if self.automatic_entropy_tuning:
    alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
    self.alpha_optim.zero_grad()
    alpha_loss.backward()
    self.alpha_optim.step()
    self.alpha = self.log_alpha.exp()
    alpha_tlogs = self.alpha.clone()
else:
    alpha_loss = torch.tensor(0.).to(self.device)
    alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs
  • 通過計算 alpha_loss 更新 self.alpha ,調(diào)整策略的探索-利用平衡。

6. 返回值

  • qf1_loss , qf2_loss : 兩個 Q 網(wǎng)絡(luò)的損失
  • policy_loss : 策略網(wǎng)絡(luò)的損失
  • alpha_loss : 熵權(quán)重的損失
  • alpha_tlogs : 用于日志記錄的熵權(quán)重
  • next_q_value : 平均下一個 Q 值
  • dormant_metrics : 休眠神經(jīng)元的相關(guān)度量

重置機(jī)制模塊

重置機(jī)制模塊在代碼中主要體現(xiàn)在 update_parameters 函數(shù)中,并通過 梯度主導(dǎo)度 (dominant metrics) 和 擾動函數(shù) (perturbation functions) 實現(xiàn)對策略網(wǎng)絡(luò)和 Q 網(wǎng)絡(luò)的重置。

重置邏輯

函數(shù)根據(jù)設(shè)定的 reset_interval 判斷是否需要對策略網(wǎng)絡(luò)和 Q 網(wǎng)絡(luò)進(jìn)行擾動和重置。這里使用了"休眠"神經(jīng)元的概念,即一些在梯度更新中影響較小的神經(jīng)元,可能會被調(diào)整或重置。

函數(shù)計算了休眠度量 dormant_metrics 和因果權(quán)重差異 causal_diff ,通過擾動因子 perturb_factor 來決定是否對網(wǎng)絡(luò)進(jìn)行部分或全部的擾動與重置。

重置機(jī)制模塊的原理

重置機(jī)制主要由以下部分組成:

1. 計算梯度主導(dǎo)度 ( $\beta_\gamma $)

在更新策略時,計算 主導(dǎo)梯度 ,即某些特定神經(jīng)元或參數(shù)在更新中主導(dǎo)作用的比率。代碼中通過調(diào)用 cal_dormant_grad(self.policy, type='policy', percentage=0.05) 實現(xiàn)這一計算,代表提取出 5% 的主導(dǎo)梯度來作為判斷因子。

dormant_metrics = cal_dormant_grad(self.policy, type='policy', percentage=0.05)

根據(jù)主導(dǎo)度 ($ \beta_\gamma$ ) 和權(quán)重 ($ w$ ),可以得到因果效應(yīng)的差異。代碼里用 causal_diff 來表示因果差異:

2. 軟重置策略和 Q 網(wǎng)絡(luò)

軟重置機(jī)制通過平滑更新策略網(wǎng)絡(luò)和 Q 網(wǎng)絡(luò),避免過大的權(quán)重更新導(dǎo)致的網(wǎng)絡(luò)不穩(wěn)定。這在代碼中由 soft_update 實現(xiàn):

soft_update(self.critic_target, self.critic, self.tau)

具體來說,軟更新的公式為:

其中,( \(\tau\) ) 是一個較小的常數(shù),通常介于 ( [0, 1] ) 之間,確保目標(biāo)網(wǎng)絡(luò)的更新是緩慢的,以提高學(xué)習(xí)的穩(wěn)定性。

3. 策略和 Q 優(yōu)化器的重置
4. 重置機(jī)制模塊的應(yīng)用

每當(dāng)經(jīng)過一定的重置間隔時,判斷是否需要擾動策略和 Q 網(wǎng)絡(luò)。通過調(diào)用 perturb() dormant_perturb() 實現(xiàn)對網(wǎng)絡(luò)的擾動(perturbation)。擾動因子由梯度主導(dǎo)度和因果差異共同決定。

策略與 Q 網(wǎng)絡(luò)的擾動會在以下兩種情況下發(fā)生:

a. 重置間隔達(dá)成時

代碼中每當(dāng)更新次數(shù) updates 達(dá)到設(shè)定的重置間隔 self.reset_interval ,并且 updates > 5000 時,才會觸發(fā)策略與 Q 網(wǎng)絡(luò)的重置邏輯。這是為了確保擾動不是頻繁發(fā)生,而是在經(jīng)過一段較長的訓(xùn)練時間后進(jìn)行。

具體判斷條件:

if updates % self.reset_interval == 0 and updates > 5000:
b. 主導(dǎo)梯度或因果效應(yīng)差異滿足條件時

在達(dá)到了重置間隔后,首先會計算 梯度主導(dǎo)度 因果效應(yīng)的差異 。這可以通過計算因果差異 causal_diff 或梯度主導(dǎo)度 dormant_metrics['policy_grad_dormant_ratio'] 來決定是否需要擾動。

  • 梯度主導(dǎo)度 計算方式通過 cal_dormant_grad() 函數(shù)實現(xiàn),如果梯度主導(dǎo)度較低,意味著網(wǎng)絡(luò)中的某些神經(jīng)元更新幅度過小,則需要對網(wǎng)絡(luò)進(jìn)行擾動。

  • 因果效應(yīng)差異 通過計算 causal_diff = np.max(causal_weight) - np.min(causal_weight) 得到,如果差異過大,則可能需要重置。

然后根據(jù)這些值通過擾動因子 factor 進(jìn)行判斷:

factor = perturb_factor(dormant_metrics['policy_grad_dormant_ratio'])

如果擾動因子 ( \(\text{factor} < 1\) ),網(wǎng)絡(luò)會進(jìn)行擾動:

if factor < 1:
    if self.reset == 'reset' or self.reset == 'causal_reset':
        perturb(self.policy, self.policy_optim, factor)
        perturb(self.critic, self.critic_optim, factor)
        perturb(self.critic_target, self.critic_optim, factor)
c. 總結(jié)
  • 更新次數(shù)達(dá)到設(shè)定的重置間隔 ,且經(jīng)過了一定時間的訓(xùn)練( updates > 5000 )。
  • 梯度主導(dǎo)度 較低或 因果效應(yīng)差異 過大,導(dǎo)致計算出的擾動因子 ( $\text{factor} < 1 $)。

這兩種條件同時滿足時,策略和 Q 網(wǎng)絡(luò)將被擾動或重置。

擾動因子的計算

在這段代碼中, factor 是基于網(wǎng)絡(luò)中梯度主導(dǎo)度或者因果效應(yīng)差異計算出來的擾動因子。擾動因子通過函數(shù) perturb_factor() 進(jìn)行計算,該函數(shù)會根據(jù)神經(jīng)元的梯度主導(dǎo)度( dormant_ratio )或因果效應(yīng)差異( causal_diff )來調(diào)整 factor 的大小。

擾動因子(factor)

擾動因子 factor 的計算公式如下:

其中:

  • ( \(\text{dormant\_ratio}\) ) 是網(wǎng)絡(luò)中梯度主導(dǎo)度,即表示有多少神經(jīng)元的梯度變化較小(或者接近零),處于“休眠”狀態(tài)。

  • ( \(\text{min\_perturb\_factor}\) ) 是最小擾動因子值,代碼中設(shè)定為 0.2

  • ( \(\text{max\_perturb\_factor}\) ) 是最大擾動因子值,代碼中設(shè)定為 0.9

  • dormant_ratio :

    • 表示網(wǎng)絡(luò)中處于“休眠狀態(tài)”的梯度比例。這個比例通常通過計算神經(jīng)網(wǎng)絡(luò)中梯度幅度小于某個閾值的神經(jīng)元數(shù)量來獲得。 dormant_ratio 越大,表示越多神經(jīng)元的梯度變化很小,說明網(wǎng)絡(luò)更新不充分,需要擾動。
  • max_perturb_factor :

    • 最大擾動因子值,用來限制擾動因子的上限,代碼中設(shè)定為 0.9,意味著最大擾動幅度不會超過 90%。
  • min_perturb_factor :

    • 最小擾動因子值,用來限制擾動因子的下限,代碼中設(shè)定為 0.2,意味著即使休眠神經(jīng)元比例很低,擾動幅度也不會小于 20%。

在計算因果效應(yīng)的部分,擾動因子 factor 還會根據(jù)因果效應(yīng)差異 causal_diff 來調(diào)整。 causal_diff 是通過計算因果效應(yīng)的最大值與最小值的差異來獲得的:

計算出的 causal_diff 會影響 causal_factor ,并進(jìn)一步對 factor 進(jìn)行調(diào)整:

組合擾動因子的公式

最后,如果選擇了因果重置( causal_reset ),擾動因子將使用因果差異計算出的 causal_factor 進(jìn)行二次調(diào)整:

綜上所述, factor 的最終值是由梯度主導(dǎo)度或因果效應(yīng)差異來控制的,當(dāng)休眠神經(jīng)元比例較大或因果效應(yīng)差異較大時, factor 會減小,導(dǎo)致網(wǎng)絡(luò)進(jìn)行擾動。

評估代碼

這段代碼主要實現(xiàn)了在強化學(xué)習(xí)(RL)訓(xùn)練過程中,定期評估智能體(agent)的性能,并在某些條件下保存最佳模型的檢查點。我們可以分段解釋該代碼:

1. 定期評估條件

if i_episode % config.eval_interval == 0 and config.eval is True:

這部分代碼用于判斷是否應(yīng)該執(zhí)行智能體的評估。條件為:

  • i_episode % config.eval_interval == 0 :表示每隔 config.eval_interval 個訓(xùn)練回合( i_episode 是當(dāng)前回合數(shù))進(jìn)行一次評估。
  • config.eval is True :確保 eval 設(shè)置為 True ,也就是說,評估功能開啟。

如果滿足這兩個條件,代碼將開始執(zhí)行評估操作。

2. 初始化評估列表

eval_reward_list = []

用于存儲每個評估回合(episode)的累計獎勵,以便之后計算平均獎勵。

3. 進(jìn)行評估

for _ in range(config.eval_episodes):

評估階段將運行多個回合(由 config.eval_episodes 指定的回合數(shù)),以獲得智能體的表現(xiàn)。

3.1 回合初始化
state = env.reset()
episode_reward = []
done = False
  • env.reset() :重置環(huán)境,獲得初始狀態(tài) state
  • episode_reward :初始化一個列表,用于存儲當(dāng)前回合中智能體獲得的所有獎勵。
  • done = False :用 done 來跟蹤當(dāng)前回合是否結(jié)束。
3.2 執(zhí)行智能體動作
while not done:
    action = agent.select_action(state, evaluate=True)
    next_state, reward, done, info = env.step(action)
    state = next_state
    episode_reward.append(reward)
  • 動作選擇 agent.select_action(state, evaluate=True) 在評估模式下根據(jù)當(dāng)前狀態(tài) state 選擇動作。 evaluate=True 表示該選擇是在評估模式下,通常意味著探索行為被關(guān)閉(即不進(jìn)行隨機(jī)探索,而是選擇最優(yōu)動作)。

  • 環(huán)境反饋 next_state, reward, done, info = env.step(action) 通過執(zhí)行動作 action ,環(huán)境返回下一個狀態(tài) next_state ,當(dāng)前獎勵 reward ,回合是否結(jié)束的標(biāo)志 done ,以及附加信息 info 。

  • 狀態(tài)更新 :當(dāng)前狀態(tài)被更新為 next_state ,并將獲得的獎勵 reward 存儲在 episode_reward 列表中。

循環(huán)持續(xù),直到回合結(jié)束(即 done == True )。

3.3 存儲回合獎勵
eval_reward_list.append(sum(episode_reward))

當(dāng)前回合結(jié)束后,累計獎勵( sum(episode_reward) )被添加到 eval_reward_list ,用于后續(xù)計算平均獎勵。

4. 計算平均獎勵

avg_reward = np.average(eval_reward_list)

在所有評估回合結(jié)束后,計算 eval_reward_list 的平均值 avg_reward 。這是當(dāng)前評估階段智能體的表現(xiàn)指標(biāo)。

5. 保存最佳模型

if config.save_checkpoint:
    if avg_reward >= best_reward:
        best_reward = avg_reward
        agent.save_checkpoint(checkpoint_path, 'best')
  • 如果 config.save_checkpoint True ,則表示需要檢查是否保存模型。
  • 通過判斷 avg_reward 是否超過了之前的最佳獎勵 best_reward ,如果是,則更新 best_reward ,并保存當(dāng)前模型的檢查點。
agent.save_checkpoint(checkpoint_path, 'best')

這行代碼會將智能體的狀態(tài)保存到指定的路徑 checkpoint_path ,并標(biāo)記為 "best" ,表示這是性能最佳的模型。

論文復(fù)現(xiàn)結(jié)果

咳咳,可以發(fā)現(xiàn)程序只記錄了 0~1000 的數(shù)據(jù),從 1001 開始的每一個數(shù)據(jù)都顯示報錯所以被舍棄掉了。

后面重新下載了github代碼包,發(fā)生了同樣的報錯信息

報錯信息是:你在 X+1 輪次中嘗試記載 X 輪次中的信息,所以這個數(shù)據(jù)被舍棄掉了

大概是主程序哪里有問題吧,我自己也沒調(diào) bug

不過這個項目結(jié)題了,主要負(fù)責(zé)這個項目的博士師兄也畢業(yè)了,也不好說些什么(雖然我有他微信),至少論文里面的模塊挺有用的。ㄊ謩踊

小編推薦閱讀

好特網(wǎng)發(fā)布此文僅為傳遞信息,不代表好特網(wǎng)認(rèn)同期限觀點或證實其描述。

a 1.0
a 1.0
類型:休閑益智  運營狀態(tài):正式運營  語言:中文   

游戲攻略

游戲禮包

游戲視頻

游戲下載

游戲活動

《alittletotheleft》官網(wǎng)正版是一款備受歡迎的休閑益智整理游戲。玩家的任務(wù)是對日常生活中的各種雜亂物
AWA 1.40
AWA 1.40
類型:休閑益智  運營狀態(tài):未知  語言:中文   

游戲攻略

游戲禮包

游戲視頻

游戲下載

游戲活動

《AWA》安卓版是由開發(fā)商MentalLab研發(fā)的一款帶有奇幻色彩的神秘迷宮冒險游戲,華麗而精美的游戲界面,讓

相關(guān)視頻攻略

更多

掃二維碼進(jìn)入好特網(wǎng)手機(jī)版本!

掃二維碼進(jìn)入好特網(wǎng)微信公眾號!

本站所有軟件,都由網(wǎng)友上傳,如有侵犯你的版權(quán),請發(fā)郵件[email protected]

湘ICP備2022002427號-10 湘公網(wǎng)安備:43070202000427號© 2013~2024 haote.com 好特網(wǎng)